from collections import defaultdict
from typing import Any, Iterator, Literal, Sequence, TypeAlias, cast

from ..base import BaseSolver

DirectionType: TypeAlias = Literal[">", "<", "^", "v", ".", "#"]

# A grid cell as (row, column).
Node: TypeAlias = tuple[int, int]

# Weighted adjacency list: node -> list of (next node, number of steps to it).
Links: TypeAlias = dict[Node, list[tuple[Node, int]]]

# Unit step (row delta, column delta) for each slope character.
Direction: dict[DirectionType, tuple[int, int]] = {
    ">": (0, +1),
    "<": (0, -1),
    "^": (-1, 0),
    "v": (+1, 0),
}

# Steps allowed when *leaving* a cell of the given type: all four directions
# from a path tile, none from a wall, and only the arrow's own direction from
# a slope.
Neighbors = cast(
    "dict[DirectionType, tuple[tuple[int, int], ...]]",
    {
        ".": ((+1, 0), (-1, 0), (0, +1), (0, -1)),
        "#": (),
    }
    | {k: (v,) for k, v in Direction.items()},
)


def neighbors(
    grid: list[Sequence[DirectionType]],
    node: Node,
    ignore: set[Node] | frozenset[Node] = frozenset(),
) -> Iterator[Node]:
    """
    Yield the cells that can be entered in one step from the given node.

    A step leaves ``node`` in one of the directions allowed by its own tile
    (see ``Neighbors``). It is then discarded if:
    - it lands outside the grid;
    - it lands on a cell in ``ignore``;
    - it lands on a wall;
    - it enters a slope against its arrow, i.e. you cannot go uphill.

    Entering a slope from the side or along its arrow is allowed.
    """
    n_rows, n_cols = len(grid), len(grid[0])
    i, j = node
    for di, dj in Neighbors[grid[i][j]]:
        ti, tj = i + di, j + dj
        if not (0 <= ti < n_rows and 0 <= tj < n_cols):
            continue
        if (ti, tj) in ignore:
            continue
        v = grid[ti][tj]
        if v == "#":
            continue
        # Moving exactly opposite to a slope's arrow would be going uphill.
        if Direction.get(v) == (-di, -dj):
            continue
        yield ti, tj


def reachable(
    grid: list[Sequence[DirectionType]], start: Node, target: Node
) -> tuple[Node, int]:
    """
    Walk the corridor beginning at ``start`` until it is left.

    The walk moves over path tiles without revisiting a cell and without
    going uphill. On each tile it takes the first step produced by
    ``neighbors``, so ``start`` is presumably inside a single-lane corridor.

    The walk stops in one of two ways:
    - it reaches ``target``, which is then returned;
    - it stands on a slope, in which case the cell the slope points to is
      returned.

    Returns:
        ``(node, distance)``, where ``distance`` is the number of steps from
        ``start`` to ``node``.

    Raises:
        ValueError: if the corridor dead-ends before a slope or the target.
    """
    # A mutable set updated in place; rebuilding it each step was quadratic.
    seen: set[Node] = set()
    node, distance = start, 0
    while True:
        if node == target:
            return target, distance
        i, j = node
        if grid[i][j] != ".":
            di, dj = Direction[grid[i][j]]
            return (i + di, j + dj), distance + 1
        seen.add(node)
        # Use a default instead of letting StopIteration escape: inside the
        # ``solve`` generator PEP 479 would turn it into a RuntimeError.
        step = next(neighbors(grid, node, seen), None)
        if step is None:
            raise ValueError(f"dead end at {node} while walking from {start}")
        node = step
        distance += 1


def compute_direct_links(
    grid: list[Sequence[DirectionType]], start: Node, target: Node
) -> Links:
    """
    Build the directed graph of nodes reachable from ``start``.

    Each key is a node other than ``target``. Its value lists one
    ``(next node, distance)`` entry per slope leaving it downhill, in the
    order ``neighbors`` yields them. ``target`` itself gets no entry.

    Each step out of a node is assumed to land on a slope: the
    ``Direction`` lookup raises KeyError otherwise.

    Each node is expanded only once, even when several paths lead to it.
    """
    direct: Links = {}
    pending = [start]
    while pending:
        node = pending.pop()
        if node == target or node in direct:
            continue
        edges = direct[node] = []
        for i, j in neighbors(grid, node):
            di, dj = Direction[grid[i][j]]
            # Step onto the slope (1) and off it (1), then follow the corridor.
            reach, distance = reachable(grid, (i + di, j + dj), target)
            edges.append((reach, distance + 2))
            pending.append(reach)
    return direct


class Solver(BaseSolver):
    def longest_path_length(
        self,
        links: Links,
        start: Node,
        target: Node,
    ) -> int:
        """
        Return the longest simple-path length from ``start`` to ``target``.

        The search is an exhaustive depth-first search over ``links`` that
        never revisits a node on the current path. It returns -1 when
        ``target`` cannot be reached, and logs how many search states were
        processed.
        """
        max_distance: int = -1
        stack: list[tuple[Node, int, frozenset[Node]]] = [
            (start, 0, frozenset({start}))
        ]
        nodes = 0
        while stack:
            node, distance, path = stack.pop()
            nodes += 1
            if node == target:
                max_distance = max(distance, max_distance)
                continue
            stack.extend(
                (reach, distance + length, path | {reach})
                for reach, length in links.get(node, [])
                if reach not in path
            )
        self.logger.info(f"processed {nodes} nodes")
        return max_distance

    def solve(self, input: str) -> Iterator[Any]:
        """
        Yield the answers to part 1 and part 2.

        Part 1: slopes are one-way. Part 2: slopes may be climbed.
        """
        lines = cast(list[Sequence[DirectionType]], input.splitlines())
        start = (0, 1)
        target = (len(lines) - 1, len(lines[0]) - 2)

        # The start corridor leads to the first node after a slope.
        # compute_direct_links discovers everything from there on.
        direct_links: Links = {start: [reachable(lines, start, target)]}
        direct_links.update(
            compute_direct_links(lines, direct_links[start][0][0], target)
        )

        # part 1: only the downhill (directed) links are usable
        yield self.longest_path_length(direct_links, start, target)

        # part 2: every link is also usable backwards. Reversed edges into
        # ``start`` are skipped, since start is always on the path already.
        reverse_links: dict[Node, list[tuple[Node, int]]] = defaultdict(list)
        for origin, edges in direct_links.items():
            if origin == start:
                continue
            for destination, distance in edges:
                reverse_links[destination].append((origin, distance))

        links = {
            k: direct_links.get(k, []) + reverse_links.get(k, [])
            for k in direct_links.keys() | reverse_links.keys()
        }
        yield self.longest_path_length(links, start, target)