"""
Longest hike through a grid made of paths ('.'), forest ('#') and one-way
slopes ('>', '<', '^', 'v').

The grid is read from stdin. Part 1 respects the slopes (they can only be
walked downhill); part 2 treats every corridor as walkable in both
directions. Set AOC_VERBOSE=True to log search statistics.
"""

import logging
import os
import sys
from collections import defaultdict
from collections.abc import Container, Iterator

VERBOSE = os.getenv("AOC_VERBOSE") == "True"

logging.basicConfig(level=logging.INFO if VERBOSE else logging.WARNING)

# A grid position as (row, column).
Node = tuple[int, int]
# Adjacency list: node -> list of (next node, number of steps to reach it).
Links = dict[Node, list[tuple[Node, int]]]

# The (d_row, d_col) step that each slope tile forces.
Direction = {
    ">": (0, +1),
    "<": (0, -1),
    "^": (-1, 0),
    "v": (+1, 0),
}

# Moves allowed *from* a tile: all four directions from a path tile, none
# from forest, and only the slope's own direction from a slope tile.
Neighbors = {".": ((+1, 0), (-1, 0), (0, +1), (0, -1)), "#": ()} | {
    k: (v,) for k, v in Direction.items()
}


def neighbors(
    grid: list[str], node: Node, ignore: Container[Node] = frozenset()
) -> Iterator[Node]:
    """
    Yield the positions reachable in one step from `node`.

    A step never leaves the grid, never enters forest ('#'), never enters a
    position contained in `ignore`, and never climbs a slope against its
    direction (you cannot go uphill). Candidates are yielded in the order
    listed in `Neighbors` for the tile at `node`.
    """
    i, j = node
    for di, dj in Neighbors[grid[i][j]]:
        ti, tj = i + di, j + dj
        # Bounds come from the grid itself (not module globals), so any grid
        # works, including non-rectangular ones.
        if not (0 <= ti < len(grid) and 0 <= tj < len(grid[ti])):
            continue
        if (ti, tj) in ignore:
            continue
        tile = grid[ti][tj]
        if tile == "#":
            continue
        slope = Direction.get(tile)
        # Moves are unit steps, so stepping exactly opposite to a slope's
        # direction is the only way to climb it.
        if slope is not None and (di, dj) == (-slope[0], -slope[1]):
            continue
        yield ti, tj


def reachable(grid: list[str], start: Node, target: Node) -> tuple[Node, int]:
    """
    Walk from `start` to the next 'reachable' node and return it with the
    number of steps taken.

    The walk follows the corridor (never revisiting a tile and never going
    uphill) and stops either on `target` or on the first slope. In the
    slope case the returned node is the tile just past the slope, one step
    further.

    Raises:
        ValueError: if the walk hits a dead end before either stop condition.
    """
    node, distance = start, 0
    visited = {start}
    while True:
        if node == target:
            return target, distance
        i, j = node
        tile = grid[i][j]
        if tile != ".":
            di, dj = Direction[tile]
            return (i + di, j + dj), distance + 1
        step = next(neighbors(grid, node, visited), None)
        if step is None:
            raise ValueError(f"dead end at {node} while walking from {start}")
        # Grow the visited set in place: rebuilding it each step made the
        # walk quadratic in the corridor length.
        visited.add(step)
        node = step
        distance += 1


def compute_direct_links(grid: list[str], start: Node, target: Node) -> Links:
    """
    Map every junction reachable from `start` (excluding `target`) to its
    outgoing links as (next junction, steps) pairs.

    From each junction, every admissible neighbor is expected to be a slope.
    A plain '.' neighbor raises KeyError; this presumably matches the puzzle
    input layout, so verify it if the input differs. The next junction is
    found by stepping onto the slope, past it, and calling `reachable`,
    hence the `+ 2`. Each junction is expanded once, even when several
    corridors lead to it.
    """
    direct: Links = {}
    pending = [start]
    while pending:
        junction = pending.pop()
        if junction == target or junction in direct:
            continue
        outgoing = direct[junction] = []
        for i, j in neighbors(grid, junction):
            di, dj = Direction[grid[i][j]]
            reach, distance = reachable(grid, (i + di, j + dj), target)
            outgoing.append((reach, distance + 2))
            pending.append(reach)
    return direct


def longest_path_length(links: Links, start: Node, target: Node) -> int:
    """
    Return the length of the longest simple path from `start` to `target`.

    Depth-first search over `links` that enumerates every path not
    revisiting a node, so its cost grows with the number of such paths.
    Returns -1 when `target` is unreachable. The number of search states
    popped is logged at INFO level.
    """
    max_distance = -1
    stack: list[tuple[Node, int, frozenset[Node]]] = [
        (start, 0, frozenset({start}))
    ]
    popped = 0
    while stack:
        node, distance, path = stack.pop()
        popped += 1
        if node == target:
            max_distance = max(max_distance, distance)
            continue
        stack.extend(
            (reach, distance + length, path | {reach})
            for reach, length in links.get(node, [])
            if reach not in path
        )
    logging.info("processed %d nodes", popped)
    return max_distance


def bidirectional_links(direct_links: Links, start: Node) -> Links:
    """
    Return `direct_links` with every link also added in reverse (part 2).

    Links leaving `start` are not reversed, so no node gets a link back
    to `start`.
    """
    reverse_links: Links = defaultdict(list)
    for origin, outgoing in direct_links.items():
        if origin == start:
            continue
        for destination, distance in outgoing:
            reverse_links[destination].append((origin, distance))
    return {
        node: direct_links.get(node, []) + reverse_links.get(node, [])
        for node in direct_links.keys() | reverse_links.keys()
    }


def build_graph(grid: list[str]) -> tuple[Node, Node, Links]:
    """
    Build the slope-respecting junction graph of `grid`.

    The start is (0, 1) and the target is the second-to-last column of the
    last row.

    Returns:
        (start, target, direct links).

    Raises:
        ValueError: if `grid` is empty.
    """
    if not grid or not grid[0]:
        raise ValueError("empty puzzle input")
    start = (0, 1)
    target = (len(grid) - 1, len(grid[0]) - 2)
    first_hop = reachable(grid, start, target)
    links: Links = {start: [first_hop]}
    links.update(compute_direct_links(grid, first_hop[0], target))
    return start, target, links


def main() -> None:
    """Read the grid from stdin and print both answers."""
    lines = sys.stdin.read().splitlines()
    start, target, direct_links = build_graph(lines)

    # part 1
    answer_1 = longest_path_length(direct_links, start, target)
    print(f"answer 1 is {answer_1}")

    # part 2
    links = bidirectional_links(direct_links, start)
    answer_2 = longest_path_length(links, start, target)
    print(f"answer 2 is {answer_2}")


if __name__ == "__main__":
    main()