import itertools
from collections import Counter
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeAlias

from ..base import BaseSolver
from ..tools.graphs import dijkstra, make_neighbors_grid_fn

Node: TypeAlias = tuple[int, int]


def make_neighbors_fn(
    grid: list[str], cheat_length: int
) -> Callable[[Node], Iterator[tuple[Node, int]]]:
    """Build a neighbor function for cheats of at most `cheat_length` steps.

    The returned function maps a node to every in-bounds, non-wall cell of
    `grid` whose Manhattan distance from the node is at most `cheat_length`.
    The node itself is included, at distance 0. Each result is paired with
    that distance. Cells between the two endpoints are not inspected; only
    the landing cell must not be a wall (`"#"`).
    """
    n_rows, n_cols = len(grid), len(grid[0])

    def _fn(node: Node) -> Iterator[tuple[Node, int]]:
        row, col = node
        return (
            ((row_n, col_n), abs(row_d) + abs(col_d))
            for row_d in range(-cheat_length, cheat_length + 1)
            # The column span shrinks as |row_d| grows, which keeps
            # |row_d| + |col_d| <= cheat_length (a Manhattan "diamond").
            for col_d in range(
                -cheat_length + abs(row_d), cheat_length - abs(row_d) + 1
            )
            if 0 <= (row_n := row + row_d) < n_rows
            and 0 <= (col_n := col + col_d) < n_cols
            and grid[row_n][col_n] != "#"
        )

    return _fn


def _find_char(grid: list[str], char: str) -> Node:
    """Return the (row, col) of the first occurrence of `char` in `grid`.

    Rows are scanned top to bottom and each row left to right.

    Raises:
        ValueError: if `char` does not appear anywhere in `grid`.
    """
    for i, line in enumerate(grid):
        j = line.find(char)
        if j >= 0:
            return i, j
    raise ValueError(f"character {char!r} not found in grid")


class Solver(BaseSolver):
    """Count shortcut ("cheat") routes through a walled grid from S to E."""

    def find_cheats(
        self,
        path: Sequence[Node],
        cost: float,
        costs_to_target: dict[Node, float],
        neighbors_fn: Callable[[Node], Iterable[tuple[Node, float]]],
    ) -> dict[tuple[Node, Node], float]:
        """Find every cheat that beats the baseline `cost`.

        For the node at index `i` of `path`, each `(reach_node, reach_cost)`
        produced by `neighbors_fn` is a candidate cheat. The route using it
        costs `i + reach_cost + costs_to_target[reach_node]`; nodes missing
        from `costs_to_target` count as unreachable (infinite cost). Using
        the index `i` as cost-so-far assumes every step along `path` costs 1
        and that `path[0]` is the start.

        Args:
            path: The baseline route, ordered from start to target.
            cost: Total cost of the baseline route.
            costs_to_target: Cost from each node to the target.
            neighbors_fn: Yields `(node, cost)` pairs reachable by a cheat.

        Returns:
            A mapping from `(cheat_start, cheat_end)` to the saving
            `cost - cheat_route_cost`. Only strictly positive savings are
            included.
        """
        unreachable = float("inf")
        cheats: dict[tuple[Node, Node], float] = {}
        for i_node, node in enumerate(self.progress.wrap(path)):
            for reach_node, reach_cost in neighbors_fn(node):
                n_cost = (
                    i_node
                    + reach_cost
                    + costs_to_target.get(reach_node, unreachable)
                )
                if n_cost < cost:
                    cheats[node, reach_node] = cost - n_cost
        return cheats

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the number of worthwhile cheats for cheat lengths 2 and 20.

        A cheat is counted when its saving is at least `target_saving`. That
        threshold is 100 for grids taller than 20 rows and 50 otherwise
        (presumably so the small example inputs produce non-zero counts).

        Raises:
            ValueError: if the grid has no `S` or no `E` cell.
        """
        grid = input.splitlines()
        n_rows, n_cols = len(grid), len(grid[0])

        start = _find_char(grid, "S")
        target = _find_char(grid, "E")

        # Dijkstra is seeded at the target, so the recorded costs are
        # (presumably) costs *to* the target for every reachable node.
        reachable = dijkstra(
            target,
            None,
            make_neighbors_grid_fn(
                n_rows,
                n_cols,
                excluded=(
                    (i, j)
                    for i in range(n_rows)
                    for j in range(n_cols)
                    if grid[i][j] == "#"
                ),
            ),
        )

        # note: path is inverted here (it runs target -> start), hence the
        # reversal below before walking it from the start.
        path, cost = reachable[start]
        costs_to_target = {k: c for k, (_, c) in reachable.items()}
        self.logger.info(f"found path from start to target with cost {cost}")

        # Both values are identical for every cheat length; compute them once.
        forward_path = list(reversed(path))
        target_saving = 100 if len(grid) > 20 else 50

        for cheat_length in (2, 20):
            cheats = self.find_cheats(
                forward_path,
                cost,
                costs_to_target,
                make_neighbors_fn(grid, cheat_length),
            )
            for saving, count in sorted(Counter(cheats.values()).items()):
                self.logger.debug(
                    f"There are {count} cheats that save {saving} picoseconds."
                )
            yield sum(saving >= target_saving for saving in cheats.values())