import itertools as it
from typing import Any, Iterator, Optional

from ..base import BaseSolver

# Height assigned to any non-digit map cell (e.g. a '.' placeholder). It is
# never 0, so such a cell is never a trailhead, and it is never equal to
# `height + 1` for a real height, so no trail ever steps onto it.
_IMPASSABLE = -1

Path = tuple[tuple[int, int], ...]


def process(
    grid: list[list[int]],
    current: tuple[int, int],
    memo: Optional[dict[tuple[int, int], set[Path]]] = None,
) -> set[Path]:
    """Return every path from ``current`` that climbs by exactly 1 per step up to a 9.

    A path is a tuple of ``(row, col)`` positions that starts at ``current``.
    Each subsequent position is an orthogonal neighbour whose value is one
    greater than the previous one, and the last position holds the value 9.
    If ``current`` itself holds 9, the only path is ``((row, col),)``.

    Args:
        grid: Rows of integer heights. Rows may differ in length; neighbours are
            bounds-checked against their own row.
        current: ``(row, col)`` of the starting cell. It must be inside ``grid``.
        memo: Optional cache mapping a position to its path set. When several
            calls are made on the same ``grid``, passing one shared dict avoids
            recomputing the suffix paths of cells reachable from multiple starts.
            It must not be reused across different grids. Defaults to a fresh
            cache for this call.

    Returns:
        The set of all such paths (empty if no 9 is reachable).

    Raises:
        IndexError: If ``current`` lies outside ``grid``.
    """
    if memo is None:
        memo = {}
    row, col = current
    # Normalise to a tuple so the key and the path entries are hashable even if
    # the caller passed some other 2-sequence.
    start = (row, col)
    cached = memo.get(start)
    if cached is not None:
        return cached

    height = grid[row][col]
    if height == 9:
        result = {(start,)}
    else:
        wanted = height + 1
        neighbours = ((row - 1, col), (row, col + 1), (row + 1, col), (row, col - 1))
        result = {
            (start,) + path
            for i, j in neighbours
            # Check the bounds before indexing; the per-row length check keeps
            # ragged grids safe.
            if 0 <= i < len(grid) and 0 <= j < len(grid[i]) and grid[i][j] == wanted
            for path in process(grid, (i, j), memo)
        }
    memo[start] = result
    return result


class Solver(BaseSolver):
    """Solve the trail puzzle for a grid of single-digit heights, one row per line."""

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the two answers for ``input``.

        Every cell of height 0 is a starting point. The first value yielded is,
        summed over all starting points, the number of distinct height-9 cells
        reachable by a path from ``process``. The second is the total number of
        distinct such paths over all starting points.

        Non-digit characters are treated as impassable cells. Empty input yields
        ``0`` twice.
        """
        grid = [
            [int(char) if char.isdecimal() else _IMPASSABLE for char in line]
            for line in input.splitlines()
        ]
        # Enumerate each row on its own, so that empty or ragged input neither
        # crashes nor hides trailheads beyond the width of the first row.
        trailheads = [
            (i, j)
            for i, row in enumerate(grid)
            for j, height in enumerate(row)
            if height == 0
        ]
        # One cache shared by all trailheads: trails frequently overlap.
        memo: dict[tuple[int, int], set[Path]] = {}
        paths = {head: process(grid, head, memo) for head in trailheads}

        yield sum(len({trail[-1] for trail in trails}) for trails in paths.values())
        yield sum(len(trails) for trails in paths.values())