advent-of-code/src/holt59/aoc/2024/day20.py

96 lines
3.0 KiB
Python
Raw Normal View History

2024-12-20 08:19:12 +00:00
import itertools
from collections import Counter
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeAlias
2024-12-01 09:26:02 +00:00
from ..base import BaseSolver
2024-12-20 08:19:12 +00:00
from ..tools.graphs import dijkstra, make_neighbors_grid_fn
Node: TypeAlias = tuple[int, int]
def make_neighbors_fn(grid: list[str], cheat_length: int):
n_rows, n_cols = len(grid), len(grid[0])
def _fn(node: Node):
row, col = node
return (
((row_n, col_n), abs(row_n - row) + abs(col_n - col))
for row_d in range(-cheat_length, cheat_length + 1)
for col_d in range(
-cheat_length + abs(row_d), cheat_length - abs(row_d) + 1
)
if 0 <= (row_n := row + row_d) < n_rows
and 0 <= (col_n := col + col_d) < n_cols
and grid[row_n][col_n] != "#"
)
return _fn
2024-12-01 09:26:02 +00:00
class Solver(BaseSolver):
    """Solver for day 20 of 2024: count the cheats that save enough time."""

    def find_cheats(
        self,
        path: Sequence[Node],
        cost: float,
        costs_to_target: dict[Node, float],
        neighbors_fn: Callable[[Node], Iterable[tuple[Node, float]]],
    ) -> dict[tuple[Node, Node], float]:
        """Find every cheat along ``path`` that beats the regular cost.

        For each node of ``path`` and each ``(reach_node, reach_cost)`` given
        by ``neighbors_fn``, the cost of the cheated route is the index of
        the node in ``path`` plus ``reach_cost`` plus the known cost from
        ``reach_node`` to the target.

        NOTE(review): using the index as the cost already spent assumes that
        ``path`` starts at the start node and that each step costs 1.

        Args:
            path: Nodes of the regular route, in travel order.
            cost: Cost of the regular route.
            costs_to_target: Cost from each known node to the target; nodes
                missing from this mapping are ignored.
            neighbors_fn: Maps a node to the ``(node, cost)`` pairs reachable
                from it with a single cheat.

        Returns:
            A mapping from ``(cheat start, cheat end)`` to the saving
            ``cost - cheated cost``, holding only strictly positive savings.
        """
        cheats: dict[tuple[Node, Node], float] = {}

        for i_node, node in enumerate(self.progress.wrap(path)):
            for reach_node, reach_cost in neighbors_fn(node):
                remaining = costs_to_target.get(reach_node)
                if remaining is None:
                    # No known cost from this cell to the target, so the
                    # cheat cannot improve on the regular route.
                    continue

                n_cost = i_node + reach_cost + remaining
                if n_cost < cost:
                    cheats[node, reach_node] = cost - n_cost

        return cheats

    def solve(self, input: str) -> Iterator[Any]:
        """Yield, for cheat lengths 2 then 20, the number of good cheats.

        The grid is read from ``input`` (one row per line); ``"S"`` marks the
        start, ``"E"`` the target and ``"#"`` the walls.  A cheat is counted
        when its saving reaches the threshold (100, or 50 for grids of at
        most 20 rows).

        Args:
            input: The raw puzzle text.

        Yields:
            One count per cheat length, in the order ``(2, 20)``.
        """
        grid = input.splitlines()
        n_rows, n_cols = len(grid), len(grid[0])

        start = next(
            (i, j) for i in range(n_rows) for j in range(n_cols) if grid[i][j] == "S"
        )
        target = next(
            (i, j) for i in range(n_rows) for j in range(n_cols) if grid[i][j] == "E"
        )

        # The search is rooted at the target, so one run provides the cost
        # from every visited cell to the target.
        # NOTE(review): ``reachable`` is used below as a mapping
        # node -> (path, cost); confirm against ``dijkstra``.
        reachable = dijkstra(
            target,
            None,
            make_neighbors_grid_fn(
                n_rows,
                n_cols,
                excluded=(
                    (i, j)
                    for i in range(n_rows)
                    for j in range(n_cols)
                    if grid[i][j] == "#"
                ),
            ),
        )

        # note: path is inverted here
        path, cost = reachable[start]
        costs_to_target = {k: c for k, (_, c) in reachable.items()}

        self.logger.info(f"found path from start to target with cost {cost}")

        # Neither value depends on the cheat length: compute them once.
        forward_path = list(reversed(path))
        # Small grids use a lower threshold (presumably the puzzle example).
        target_saving = 100 if len(grid) > 20 else 50

        for cheat_length in (2, 20):
            cheats = self.find_cheats(
                forward_path,
                cost,
                costs_to_target,
                make_neighbors_fn(grid, cheat_length),
            )

            for saving, count in sorted(Counter(cheats.values()).items()):
                self.logger.debug(
                    f"There are {count} cheats that save {saving} picoseconds."
                )

            yield sum(saving >= target_saving for saving in cheats.values())