96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
import itertools
|
|
from collections import Counter
|
|
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeAlias
|
|
|
|
from ..base import BaseSolver
|
|
from ..tools.graphs import dijkstra, make_neighbors_grid_fn
|
|
|
|
# A grid position as (row, col); make_neighbors_fn unpacks nodes as `row, col`.
Node: TypeAlias = tuple[int, int]
|
|
|
|
|
|
def make_neighbors_fn(grid: list[str], cheat_length: int):
    """Build a neighbour function for cheats of up to ``cheat_length`` steps.

    The returned callable takes a ``(row, col)`` node and returns a lazy
    iterator of ``((row, col), distance)`` pairs. It covers every in-bounds
    cell whose Manhattan distance from the node is at most ``cheat_length``
    and which is not a wall (``"#"``). Each cell is paired with that
    distance. Walls lying between the two cells are ignored. The node itself,
    if it is not a wall, appears with distance 0.

    Args:
        grid: rows of the map; the column bound is taken from ``grid[0]``.
        cheat_length: maximum Manhattan distance a cheat may cover.
    """
    height = len(grid)
    width = len(grid[0])

    def _fn(node: Node):
        origin_row, origin_col = node
        # Built eagerly, at call time; the cells themselves are produced lazily.
        row_offsets = range(-cheat_length, cheat_length + 1)

        def _candidates():
            for d_row in row_offsets:
                # Remaining Manhattan budget for the column offset on this row.
                budget = cheat_length - abs(d_row)
                for d_col in range(-budget, budget + 1):
                    r = origin_row + d_row
                    c = origin_col + d_col
                    if not 0 <= r < height:
                        continue
                    if not 0 <= c < width:
                        continue
                    if grid[r][c] == "#":
                        continue
                    yield (r, c), abs(d_row) + abs(d_col)

        return _candidates()

    return _fn
|
|
|
|
|
|
class Solver(BaseSolver):
    """Count "cheats" that shorten the route from ``S`` to ``E`` in a grid.

    A cheat jumps from a cell on the reference route to any non-wall cell
    within a Manhattan radius, ignoring walls in between (see
    ``make_neighbors_fn``). ``solve`` answers for radius 2, then radius 20.
    """

    @staticmethod
    def _find_cell(grid: list[str], char: str) -> Node:
        """Return the first ``(row, col)`` holding *char*, scanning row-major.

        Columns are bounded by ``len(grid[0])``, as elsewhere in the solver.

        Raises:
            ValueError: if no cell holds *char*. A bare ``next()`` here would
                otherwise surface as an opaque ``RuntimeError`` inside the
                ``solve`` generator (PEP 479).
        """
        n_rows, n_cols = len(grid), len(grid[0])
        found = next(
            (
                (i, j)
                for i in range(n_rows)
                for j in range(n_cols)
                if grid[i][j] == char
            ),
            None,
        )
        if found is None:
            raise ValueError(f"no {char!r} cell found in the grid")
        return found

    def find_cheats(
        self,
        path: Sequence[Node],
        cost: float,
        costs_to_target: dict[Node, float],
        neighbors_fn: Callable[[Node], Iterable[tuple[Node, float]]],
    ):
        """Map every improving cheat to the time it saves.

        Args:
            path: nodes of the reference route, in order. The index of a node
                is used as its distance from the start (assumes ``path[0]`` is
                the start node; TODO confirm against ``dijkstra``'s path format).
            cost: total cost of the reference route.
            costs_to_target: cost from each reachable node to the target;
                landing cells missing from it are treated as unreachable.
            neighbors_fn: yields ``(landing_node, cheat_cost)`` pairs for the
                cells a cheat starting at a node can land on.

        Returns:
            Dict keyed by ``(cheat_start, cheat_end)``. Each value is the saving:
            ``cost`` minus the cost of the route that uses the cheat. Only cheats
            with a strictly positive saving are included.
        """
        unreachable = float("inf")  # hoisted out of the hot inner loop
        cheats: dict[tuple[Node, Node], float] = {}

        for distance_from_start, node in enumerate(self.progress.wrap(path)):
            for reach_node, reach_cost in neighbors_fn(node):
                n_cost = (
                    distance_from_start
                    + reach_cost
                    + costs_to_target.get(reach_node, unreachable)
                )
                if n_cost < cost:
                    cheats[node, reach_node] = cost - n_cost

        return cheats

    def solve(self, input: str) -> Iterator[Any]:
        """Yield how many cheats save at least a threshold, per cheat length.

        Yields two counts: one for cheats of length up to 2, then one for
        cheats of length up to 20. The threshold is 100 for grids with more
        than 20 rows and 50 otherwise (presumably to tell the small example
        apart from the real input).

        Raises:
            ValueError: if the grid has no ``S`` or no ``E`` cell.
        """
        grid = input.splitlines()
        n_rows, n_cols = len(grid), len(grid[0])

        start = self._find_cell(grid, "S")
        target = self._find_cell(grid, "E")

        # Search from the target, so that one run yields each reached cell's
        # cost to the target (used as costs_to_target below) and the route to
        # the start. Presumably `None` means no early-exit goal.
        reachable = dijkstra(
            target,
            None,
            make_neighbors_grid_fn(
                n_rows,
                n_cols,
                excluded=(
                    (i, j)
                    for i in range(n_rows)
                    for j in range(n_cols)
                    if grid[i][j] == "#"
                ),
            ),
        )

        # note: path is inverted here (it was found from the target side)
        path, cost = reachable[start]
        costs_to_target = {k: c for k, (_, c) in reachable.items()}

        self.logger.info(f"found path from start to target with cost {cost}")

        # Loop-invariant: both cheat lengths walk the same start-first route
        # and use the same saving threshold.
        forward_path = list(reversed(path))
        target_saving = 100 if len(grid) > 20 else 50

        for cheat_length in (2, 20):
            cheats = self.find_cheats(
                forward_path,
                cost,
                costs_to_target,
                make_neighbors_fn(grid, cheat_length),
            )

            for saving, count in sorted(Counter(cheats.values()).items()):
                self.logger.debug(
                    f"There are {count} cheats that save {saving} picoseconds."
                )

            yield sum(saving >= target_saving for saving in cheats.values())
|