diff --git a/src/holt59/aoc/2024/day10.py b/src/holt59/aoc/2024/day10.py index 90ef37f..2f8a4e7 100644 --- a/src/holt59/aoc/2024/day10.py +++ b/src/holt59/aoc/2024/day10.py @@ -1,3 +1,4 @@ +import itertools as it from typing import Any, Iterator from ..base import BaseSolver @@ -7,40 +8,28 @@ def process( grid: list[list[int]], current: tuple[int, int] ) -> set[tuple[tuple[int, int], ...]]: row, col = current + value = grid[row][col] + 1 if grid[row][col] == 9: return {((row, col),)} - result: set[tuple[tuple[int, int], ...]] = set() - - for di, dj in ((-1, 0), (0, 1), (1, 0), (0, -1)): - i, j = row + di, col + dj - - if i < 0 or i >= len(grid) or j < 0 or j >= len(grid[row]): - continue - - if grid[i][j] != grid[row][col] + 1: - continue - - for path in process(grid, (i, j)): - result.add(((row, col),) + path) - - return result + return { + ((row, col),) + path + for i, j in ((row - 1, col), (row, col + 1), (row + 1, col), (row, col - 1)) + if 0 <= i < len(grid) and 0 <= j < len(grid[i]) and grid[i][j] == value + for path in process(grid, (i, j)) + } class Solver(BaseSolver): def solve(self, input: str) -> Iterator[Any]: grid = [[int(col) for col in row] for row in input.splitlines()] - n_rows, n_cols = len(grid), len(grid[0]) - heads = [ - (i, j) for i in range(n_rows) for j in range(n_cols) if grid[i][j] == 0 - ] + paths = { + (i, j): process(grid, (i, j)) + for i, j in it.product(range(len(grid)), range(len(grid[0]))) + if grid[i][j] == 0 + } - # for head in heads: - # print(head, process(grid, head)) - - paths = {head: process(grid, head) for head in heads} - - yield sum(len({path[-1] for path in paths[head]}) for head in heads) - yield sum(len(paths[head]) for head in heads) + yield sum(len({path[-1] for path in paths[head]}) for head in paths) + yield sum(len(paths_of) for paths_of in paths.values())