This commit is contained in:
@@ -1,7 +1,95 @@
|
||||
from typing import Any, Iterator
|
||||
import itertools
|
||||
from collections import Counter
|
||||
from typing import Any, Callable, Iterable, Iterator, Sequence, TypeAlias
|
||||
|
||||
from ..base import BaseSolver
|
||||
from ..tools.graphs import dijkstra, make_neighbors_grid_fn
|
||||
|
||||
Node: TypeAlias = tuple[int, int]
|
||||
|
||||
|
||||
def make_neighbors_fn(grid: list[str], cheat_length: int):
|
||||
n_rows, n_cols = len(grid), len(grid[0])
|
||||
|
||||
def _fn(node: Node):
|
||||
row, col = node
|
||||
return (
|
||||
((row_n, col_n), abs(row_n - row) + abs(col_n - col))
|
||||
for row_d in range(-cheat_length, cheat_length + 1)
|
||||
for col_d in range(
|
||||
-cheat_length + abs(row_d), cheat_length - abs(row_d) + 1
|
||||
)
|
||||
if 0 <= (row_n := row + row_d) < n_rows
|
||||
and 0 <= (col_n := col + col_d) < n_cols
|
||||
and grid[row_n][col_n] != "#"
|
||||
)
|
||||
|
||||
return _fn
|
||||
|
||||
|
||||
class Solver(BaseSolver):
|
||||
def solve(self, input: str) -> Iterator[Any]: ...
|
||||
def find_cheats(
|
||||
self,
|
||||
path: Sequence[Node],
|
||||
cost: float,
|
||||
costs_to_target: dict[Node, float],
|
||||
neighbors_fn: Callable[[Node], Iterable[tuple[Node, float]]],
|
||||
):
|
||||
cheats: dict[tuple[tuple[int, int], tuple[int, int]], float] = {}
|
||||
|
||||
for i_node, node in enumerate(self.progress.wrap(path)):
|
||||
for reach_node, reach_cost in neighbors_fn(node):
|
||||
n_cost = (
|
||||
i_node + reach_cost + costs_to_target.get(reach_node, float("inf"))
|
||||
)
|
||||
if n_cost < cost:
|
||||
cheats[node, reach_node] = cost - n_cost
|
||||
|
||||
return cheats
|
||||
|
||||
def solve(self, input: str) -> Iterator[Any]:
|
||||
grid = input.splitlines()
|
||||
n_rows, n_cols = len(grid), len(grid[0])
|
||||
start = next(
|
||||
(i, j) for i in range(n_rows) for j in range(n_cols) if grid[i][j] == "S"
|
||||
)
|
||||
target = next(
|
||||
(i, j) for i in range(n_rows) for j in range(n_cols) if grid[i][j] == "E"
|
||||
)
|
||||
|
||||
reachable = dijkstra(
|
||||
target,
|
||||
None,
|
||||
make_neighbors_grid_fn(
|
||||
n_rows,
|
||||
n_cols,
|
||||
excluded=(
|
||||
(i, j)
|
||||
for i in range(n_rows)
|
||||
for j in range(n_cols)
|
||||
if grid[i][j] == "#"
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# note: path is inverted here
|
||||
path, cost = reachable[start]
|
||||
costs_to_target = {k: c for k, (_, c) in reachable.items()}
|
||||
|
||||
self.logger.info(f"found past from start to target with cost {cost}")
|
||||
|
||||
for cheat_length in (2, 20):
|
||||
cheats = self.find_cheats(
|
||||
list(reversed(path)),
|
||||
cost,
|
||||
costs_to_target,
|
||||
make_neighbors_fn(grid, cheat_length),
|
||||
)
|
||||
|
||||
for saving, count in sorted(Counter(cheats.values()).items()):
|
||||
self.logger.debug(
|
||||
f"There are {count} cheats that save {saving} picoseconds."
|
||||
)
|
||||
|
||||
target_saving = 100 if len(grid) > 20 else 50
|
||||
yield sum(saving >= target_saving for saving in cheats.values())
|
||||
|
Reference in New Issue
Block a user