from math import prod
from typing import Any, Iterator

from ..base import BaseSolver


def neighbors(
    point: tuple[int, int], n_rows: int, n_cols: int
) -> Iterator[tuple[int, int]]:
    """Yield the in-bounds orthogonal neighbours of ``point``.

    Args:
        point: ``(row, col)`` coordinates of the cell.
        n_rows: Number of rows in the grid.
        n_cols: Number of columns in the grid.

    Yields:
        ``(row, col)`` pairs, in the order up, down, left, right. Only
        neighbours with ``0 <= row < n_rows`` and ``0 <= col < n_cols`` are
        yielded.
    """
    i, j = point
    for di, dj in ((-1, 0), (+1, 0), (0, -1), (0, +1)):
        ni, nj = i + di, j + dj
        if 0 <= ni < n_rows and 0 <= nj < n_cols:
            yield (ni, nj)


def basin(values: list[list[int]], start: tuple[int, int]) -> set[tuple[int, int]]:
    """Return every cell reachable from ``start`` without crossing a 9.

    This is an iterative (explicit-stack) flood fill over orthogonal
    neighbours. Cells whose value is 9 act as walls and are never included,
    so if ``start`` itself is 9 the result is empty.

    Args:
        values: Rectangular grid of integers; ``values[0]`` defines the width.
        start: ``(row, col)`` cell to begin the fill from.

    Returns:
        The set of ``(row, col)`` cells in the region; empty if ``values`` is
        empty.
    """
    if not values:
        return set()
    n_rows, n_cols = len(values), len(values[0])
    visited: set[tuple[int, int]] = set()
    stack = [start]
    while stack:
        cell = stack.pop()
        i, j = cell
        if cell in visited or values[i][j] == 9:
            continue
        visited.add(cell)
        # Skip already-visited cells up front so the stack does not fill with
        # duplicates; the check after pop() still guards cells pushed twice
        # before being visited.
        stack.extend(
            nxt for nxt in neighbors(cell, n_rows, n_cols) if nxt not in visited
        )
    return visited


class Solver(BaseSolver):
    """Solve a digit height-map puzzle: low points and the basins around them."""

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the two answers for the height map in ``input``.

        ``input`` holds one grid row per line, each character a single decimal
        digit. Whitespace around each line is ignored, and blank lines (such as
        a trailing newline) are skipped.

        Yields:
            1. The sum of ``height + 1`` over all low points, i.e. cells that
               are strictly lower than every in-bounds orthogonal neighbour.
            2. The product of the sizes of the (up to) three largest basins
               grown from those low points. Returns 1 when there are no low
               points, since ``prod`` of an empty sequence is 1.

        Raises:
            ValueError: If the non-blank rows differ in length, or if a
                character is not a digit (raised by ``int``).
        """
        rows = [line.strip() for line in input.splitlines()]
        values = [[int(c) for c in row] for row in rows if row]
        n_rows = len(values)
        # Guard against empty input instead of failing on values[0].
        n_cols = len(values[0]) if values else 0
        if any(len(row) != n_cols for row in values):
            raise ValueError("all rows of the height map must have the same length")

        low_points = [
            (i, j)
            for i in range(n_rows)
            for j in range(n_cols)
            if all(
                values[ni][nj] > values[i][j]
                for ni, nj in neighbors((i, j), n_rows, n_cols)
            )
        ]
        yield sum(values[i][j] + 1 for i, j in low_points)

        basin_sizes = sorted(
            (len(basin(values, point)) for point in low_points), reverse=True
        )
        yield prod(basin_sizes[:3])