from typing import Any, Iterator

from ..base import BaseSolver


def process(
    grid: list[list[int]], current: tuple[int, int]
) -> set[tuple[tuple[int, int], ...]]:
    """Return every hiking trail that starts at ``current`` and ends on a 9.

    A trail is a sequence of orthogonally adjacent cells whose heights increase
    by exactly 1 at each step. Each trail is returned as a tuple of
    ``(row, col)`` coordinates, beginning with ``current`` and ending on a cell
    of height 9. If ``current`` itself has height 9, the single one-cell trail
    is returned. All trails are enumerated explicitly, so the cost grows with
    the number of trails.
    """
    row, col = current
    if grid[row][col] == 9:
        return {((row, col),)}

    result: set[tuple[tuple[int, int], ...]] = set()
    for di, dj in ((-1, 0), (0, 1), (1, 0), (0, -1)):
        i, j = row + di, col + dj
        # Bounds-check against the neighbour's own row (grid[i], not
        # grid[row]) so rows of differing lengths cannot cause an IndexError.
        if not (0 <= i < len(grid) and 0 <= j < len(grid[i])):
            continue
        if grid[i][j] != grid[row][col] + 1:
            continue
        # Heights strictly increase along a trail, so the recursion depth is
        # bounded by the number of distinct height levels.
        for path in process(grid, (i, j)):
            result.add(((row, col),) + path)
    return result


def _parse_grid(text: str) -> list[list[int]]:
    """Parse ``text`` into a grid of heights, one row per non-blank line.

    Digit characters become their integer height. Any other character (for
    example a ``.`` placeholder tile) becomes -1. No trail can step onto such
    a cell, because ``process`` only moves to cells of height ``current + 1``
    and starts from heads of height 0.
    """
    return [
        [int(ch) if ch in "0123456789" else -1 for ch in line]
        for line in text.splitlines()
        if line.strip()
    ]


class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Yield the part-1 and part-2 answers for the trail map in ``input``.

        Part 1 sums, over all trailheads (height-0 cells), the number of
        distinct 9-height cells reachable from that head. Part 2 sums, over
        all trailheads, the number of distinct trails starting at that head.
        Empty input yields 0 for both parts.
        """
        grid = _parse_grid(input)
        # Enumerate each row by its own length, so rows need not be equal width.
        heads = [
            (i, j)
            for i, line in enumerate(grid)
            for j, height in enumerate(line)
            if height == 0
        ]
        paths = {head: process(grid, head) for head in heads}
        # Part 1: distinct trail endpoints (the last cell of each trail).
        yield sum(len({path[-1] for path in trails}) for trails in paths.values())
        # Part 2: distinct trails.
        yield sum(len(trails) for trails in paths.values())