"""Height-map puzzle solver: reads a digit grid from stdin and prints two answers.

Part 1 sums ``height + 1`` over every low point, i.e. every cell strictly
lower than all of its orthogonal (up/down/left/right) neighbours.
Part 2 multiplies the sizes of the three largest basins. A basin is the set
of cells reachable from a low point by orthogonal steps without entering a
cell of height 9.
"""
import sys
from collections.abc import Iterator
from math import prod

Point = tuple[int, int]
Grid = list[list[int]]

# Height that acts as a wall separating basins.
WALL = 9

# Orthogonal step offsets: up, down, left, right.
_STEPS = ((-1, 0), (1, 0), (0, -1), (0, 1))


def parse_grid(text: str) -> Grid:
    """Parse *text* into a grid of single-digit heights, one row per line.

    Blank lines (such as a trailing empty line) and whitespace around each
    line are ignored.

    Raises:
        ValueError: if there are no rows, if a row holds a non-digit
            character, or if the rows are not all the same length.
    """
    grid = [
        [int(c) for c in line.strip()]
        for line in text.splitlines()
        if line.strip()
    ]
    if not grid:
        raise ValueError("height map is empty")
    # Every later step indexes with a single column count, so reject ragged
    # input here instead of failing with an IndexError deep inside a search.
    if any(len(row) != len(grid[0]) for row in grid):
        raise ValueError("height map rows must all have the same length")
    return grid


def neighbors(point: Point, n_rows: int, n_cols: int) -> Iterator[Point]:
    """Yield the orthogonal neighbours of *point* inside an n_rows x n_cols grid."""
    i, j = point
    for di, dj in _STEPS:
        ni, nj = i + di, j + dj
        if 0 <= ni < n_rows and 0 <= nj < n_cols:
            yield ni, nj


def low_points(grid: Grid) -> list[Point]:
    """Return every cell strictly lower than all of its orthogonal neighbours.

    *grid* must be non-empty and rectangular, as produced by ``parse_grid``.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    return [
        (i, j)
        for i, row in enumerate(grid)
        for j, height in enumerate(row)
        if all(
            grid[ni][nj] > height
            for ni, nj in neighbors((i, j), n_rows, n_cols)
        )
    ]


def basin(start: Point, grid: Grid) -> set[Point]:
    """Return the cells reachable from *start* without entering a WALL cell.

    The result is empty when *start* itself is a WALL cell.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    visited: set[Point] = set()
    # Iterative depth-first search; visit order does not affect the result.
    stack = [start]
    while stack:
        i, j = stack.pop()
        if (i, j) in visited or grid[i][j] == WALL:
            continue
        visited.add((i, j))
        stack.extend(neighbors((i, j), n_rows, n_cols))
    return visited


def part_1(grid: Grid) -> int:
    """Return the sum of ``height + 1`` over all low points of *grid*."""
    return sum(grid[i][j] + 1 for i, j in low_points(grid))


def part_2(grid: Grid) -> int:
    """Return the product of the sizes of the three largest basins of *grid*.

    With fewer than three low points, the available sizes are multiplied.
    With none, the result is 1 (the empty product).
    """
    sizes = sorted(len(basin(point, grid)) for point in low_points(grid))
    return prod(sizes[-3:])


def main() -> None:
    """Read the height map from stdin and print both answers."""
    grid = parse_grid(sys.stdin.read())
    print(f"answer 1 is {part_1(grid)}")
    print(f"answer 2 is {part_2(grid)}")


if __name__ == "__main__":
    main()