# Co-authored-by: Mikael CAPELLE <mikael.capelle@thalesaleniaspace.com>
# Co-authored-by: Mikaël Capelle <capelle.mikael@gmail.com>
# Reviewed-on: #3
from math import prod
|
|
from typing import Any, Iterator
|
|
|
|
from ..base import BaseSolver
|
|
|
|
|
|
def neighbors(point: tuple[int, int], n_rows: int, n_cols: int):
    """
    Yield the in-bounds orthogonal neighbours of a grid cell.

    Candidates are produced in a fixed order: up, down, left, right. Any
    candidate that falls outside the ``n_rows`` x ``n_cols`` grid is skipped.

    Args:
        point: (row, column) of the cell whose neighbours are wanted.
        n_rows: Number of rows in the grid.
        n_cols: Number of columns in the grid.

    Yields:
        (row, column) tuples for each neighbouring cell inside the grid.
    """
    row, col = point
    candidates = (
        (row - 1, col),
        (row + 1, col),
        (row, col - 1),
        (row, col + 1),
    )
    for r, c in candidates:
        inside_rows = 0 <= r < n_rows
        inside_cols = 0 <= c < n_cols
        if inside_rows and inside_cols:
            yield (r, c)
|
|
|
|
|
|
def basin(values: list[list[int]], start: tuple[int, int]) -> set[tuple[int, int]]:
    """
    Flood-fill the region of the grid reachable from ``start``.

    Exploration spreads to orthogonal neighbours (as produced by
    ``neighbors``). It is stopped by the grid edges and by any cell whose value
    is 9. Such cells are never part of the result, so if ``start`` itself holds
    a 9 the returned set is empty.

    Args:
        values: Grid indexed as ``values[row][col]``. The width is taken from
            the first row, so rows are assumed to all have the same length.
        start: (row, column) of the cell to start filling from.

    Returns:
        The set of (row, column) cells reached by the fill.
    """
    height, width = len(values), len(values[0])
    seen: set[tuple[int, int]] = set()
    pending = [start]  # used as a LIFO stack, i.e. depth-first exploration

    while pending:
        row, col = pending.pop()
        cell = (row, col)
        # Only expand cells not yet reached and not acting as a wall (9).
        if cell not in seen and values[row][col] != 9:
            seen.add(cell)
            pending.extend(neighbors(cell, height, width))

    return seen
|
|
|
|
|
|
class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """
        Compute both answers for the digit grid given in ``input``.

        Each line of ``input`` becomes one grid row and each character one
        integer cell. A low point is a cell strictly smaller than every
        in-bounds orthogonal neighbour.

        Yields:
            First, the sum of ``value + 1`` over all low points.
            Second, the product of the three largest basin sizes (or of fewer,
            if there are fewer low points). Each basin is flood-filled from
            one low point.
        """
        grid = [list(map(int, line)) for line in input.splitlines()]
        n_rows, n_cols = len(grid), len(grid[0])

        def is_low_point(row: int, col: int) -> bool:
            # True when every in-bounds neighbour is strictly higher.
            here = grid[row][col]
            return all(
                grid[r][c] > here
                for r, c in neighbors((row, col), n_rows, n_cols)
            )

        low_points = [
            (row, col)
            for row in range(n_rows)
            for col in range(n_cols)
            if is_low_point(row, col)
        ]

        yield sum(grid[row][col] + 1 for row, col in low_points)

        # Ascending order, so the last three entries are the largest basins.
        basin_sizes = sorted(len(basin(grid, point)) for point in low_points)
        yield prod(basin_sizes[-3:])
|