Mikaël Capelle ce315b8778 Refactor code for API (#3)
Co-authored-by: Mikael CAPELLE <mikael.capelle@thalesaleniaspace.com>
Co-authored-by: Mikaël Capelle <capelle.mikael@gmail.com>
Reviewed-on: #3
2024-12-08 13:06:41 +00:00

166 lines
4.6 KiB
Python

from collections import defaultdict
from typing import Any, Iterator, Literal, Sequence, TypeAlias, cast
from ..base import BaseSolver
DirectionType: TypeAlias = Literal[">", "<", "^", "v", ".", "#"]
Direction: dict[DirectionType, tuple[int, int]] = {
">": (0, +1),
"<": (0, -1),
"^": (-1, 0),
"v": (+1, 0),
}
Neighbors = cast(
"dict[DirectionType, tuple[tuple[int, int], ...]]",
{
".": ((+1, 0), (-1, 0), (0, +1), (0, -1)),
"#": (),
}
| {k: (v,) for k, v in Direction.items()},
)
def neighbors(
    grid: list[Sequence[DirectionType]],
    node: tuple[int, int],
    ignore: set[tuple[int, int]] | frozenset[tuple[int, int]] = frozenset(),
) -> Iterator[tuple[int, int]]:
    """
    Yield the cells that can be stepped to from the given node.

    The candidate moves come from ``Neighbors`` for the tile under ``node``:
    all four orthogonal steps on ``"."``, only the slope's own step on a slope
    tile, none on ``"#"``. A candidate is dropped when it falls outside the
    grid, is in ``ignore``, is a ``"#"`` wall, or is a slope pointing straight
    back against the move (slopes cannot be climbed uphill).

    Args:
        grid: Rows of tiles.
        node: (row, column) of the current cell.
        ignore: Cells that must not be yielded (e.g. already visited). It is
            only read, never modified; the default is an immutable empty set.

    Yields:
        (row, column) of each allowed neighbouring cell, in the order the moves
        are listed in ``Neighbors`` for the current tile.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    i, j = node
    for di, dj in Neighbors[grid[i][j]]:
        ti, tj = i + di, j + dj
        if not (0 <= ti < n_rows and 0 <= tj < n_cols) or (ti, tj) in ignore:
            continue
        tile = grid[ti][tj]
        # Stepping onto a slope whose direction is exactly opposite to our move
        # would be walking uphill. Moves are unit axis steps, so this matches
        # the four per-slope checks ("v" blocks moving up, "^" moving down,
        # ">" moving left, "<" moving right). Direction.get is None for "."/"#".
        if tile == "#" or Direction.get(tile) == (-di, -dj):
            continue
        yield ti, tj
def reachable(
    grid: list[Sequence[DirectionType]], start: tuple[int, int], target: tuple[int, int]
) -> tuple[tuple[int, int], int]:
    """
    Follow the corridor from ``start`` up to the next slope (or ``target``).

    At each open (``"."``) cell the walk takes the first neighbour yielded by
    ``neighbors`` that has not been visited during this walk, so it is meant
    for corridors with a single way forward.

    Args:
        grid: Rows of tiles.
        start: (row, column) where the walk begins.
        target: (row, column) that ends the walk immediately when reached.

    Returns:
        ``(target, steps)`` if ``target`` is reached before any slope;
        otherwise the cell just past the first slope encountered (one step in
        the slope's direction) together with the number of steps to get there.

    Raises:
        ValueError: If the walk runs into a dead end before reaching a slope or
            ``target``.
    """
    visited = {start}
    node, distance = start, 0
    while node != target:
        i, j = node
        tile = grid[i][j]
        if tile != ".":
            # On a slope: slide off it, which costs one more step.
            di, dj = Direction[tile]
            return ((i + di, j + dj), distance + 1)
        # Grow the visited set in place; `neighbors` only reads it, so there is
        # no need to copy it on every step.
        visited.add(node)
        following = next(neighbors(grid, node, visited), None)
        if following is None:
            # An explicit error instead of letting StopIteration escape, which
            # would surface as a RuntimeError inside a calling generator.
            raise ValueError(
                f"dead end at {node} before reaching a slope or {target}"
            )
        node = following
        distance += 1
    return (target, distance)
def compute_direct_links(
    grid: list[Sequence[DirectionType]], start: tuple[int, int], target: tuple[int, int]
) -> dict[tuple[int, int], list[tuple[tuple[int, int], int]]]:
    """
    Build the compressed, downhill-only graph of nodes reachable from ``start``.

    From each node, every move allowed by ``neighbors`` is followed onto a
    slope, across it, and along the corridor (``reachable``) to the next node.
    Each node is expanded once, even when several routes lead to it.

    Args:
        grid: Rows of tiles.
        start: (row, column) of the first node to expand.
        target: (row, column) of the exit; it is never expanded and never
            appears as a key.

    Returns:
        Mapping from each expanded node to its outgoing ``(next node, distance)``
        links, in ``neighbors`` order. Empty if ``start == target``.

    Note:
        Every move out of an expanded node must land on a slope tile;
        ``Direction[...]`` raises ``KeyError`` otherwise.
    """
    direct: dict[tuple[int, int], list[tuple[tuple[int, int], int]]] = {}
    pending = [start]
    while pending:
        node = pending.pop()
        # Skipping already-expanded nodes avoids re-walking shared corridors
        # once per route, and terminates even if slopes formed a cycle.
        if node == target or node in direct:
            continue
        outgoing: list[tuple[tuple[int, int], int]] = []
        direct[node] = outgoing
        for i, j in neighbors(grid, node):
            di, dj = Direction[grid[i][j]]
            reach, distance = reachable(grid, (i + di, j + dj), target)
            # +2: one step onto the slope, one step off it.
            outgoing.append((reach, distance + 2))
            pending.append(reach)
    return direct
class Solver(BaseSolver):
    def longest_path_length(
        self,
        links: dict[tuple[int, int], list[tuple[tuple[int, int], int]]],
        start: tuple[int, int],
        target: tuple[int, int],
    ) -> int:
        """
        Return the length of the longest path from ``start`` to ``target``.

        Exhaustive depth-first search over ``links``. A node is never revisited
        within one path, and paths stop as soon as they reach ``target``.

        Args:
            links: Mapping from a node to its ``(next node, distance)`` links;
                nodes without an entry have no outgoing links.
            start: Node where every path begins.
            target: Node where every path must end.

        Returns:
            The largest total distance over all such paths, or -1 if ``target``
            cannot be reached.
        """
        max_distance: int = -1
        stack: list[tuple[tuple[int, int], int, frozenset[tuple[int, int]]]] = [
            (start, 0, frozenset({start}))
        ]
        processed = 0
        while stack:
            node, distance, path = stack.pop()
            processed += 1
            if node == target:
                max_distance = max(max_distance, distance)
                continue
            stack.extend(
                (reach, distance + length, path | {reach})
                for reach, length in links.get(node, [])
                if reach not in path
            )
        self.logger.info(f"processed {processed} nodes")
        return max_distance

    def solve(self, input: str) -> Iterator[Any]:
        """
        Yield the answers for both parts of the puzzle, in order.

        Part 1 follows slopes downhill only (directed junction graph). Part 2
        also allows every junction-to-junction corridor to be walked backwards.

        Args:
            input: The grid, one row per line.
        """
        lines = cast(list[Sequence[DirectionType]], input.splitlines())

        # Entry and exit are the open ('.') tiles of the top and bottom rows;
        # locate them rather than assuming fixed columns.
        start = (0, lines[0].index("."))
        target = (len(lines) - 1, lines[-1].index("."))

        direct_links: dict[tuple[int, int], list[tuple[tuple[int, int], int]]] = {
            start: [reachable(lines, start, target)]
        }
        first_node = direct_links[start][0][0]
        direct_links.update(compute_direct_links(lines, first_node, target))

        # part 1
        yield self.longest_path_length(direct_links, start, target)

        # part 2: add the reverse of every link. Links back into ``start`` are
        # skipped: ``start`` is on every path from the outset, so
        # longest_path_length could never take them anyway.
        reverse_links: dict[tuple[int, int], list[tuple[tuple[int, int], int]]] = (
            defaultdict(list)
        )
        for origin, outgoing in direct_links.items():
            if origin == start:
                continue
            for destination, distance in outgoing:
                reverse_links[destination].append((origin, distance))
        undirected_links = {
            node: direct_links.get(node, []) + reverse_links.get(node, [])
            for node in direct_links.keys() | reverse_links.keys()
        }
        yield self.longest_path_length(undirected_links, start, target)