from __future__ import annotations import heapq import os import sys from collections import defaultdict from dataclasses import dataclass from typing import Literal, TypeAlias VERBOSE = os.getenv("AOC_VERBOSE") == "True" Direction: TypeAlias = Literal[">", "<", "^", "v"] @dataclass(frozen=True, order=True) class Label: row: int col: int direction: Direction count: int parent: Label | None = None # mappings from direction to row shift / col shift / opposite direction MAPPINGS: dict[Direction, tuple[int, int, Direction]] = { ">": (0, +1, "<"), "<": (0, -1, ">"), "v": (+1, 0, "^"), "^": (-1, 0, "v"), } def print_shortest_path( grid: list[list[int]], target: tuple[int, int], per_cell: dict[tuple[int, int], list[tuple[Label, int]]], ): assert len(per_cell[target]) == 1 label = per_cell[target][0][0] path: list[Label] = [] while True: path.insert(0, label) if label.parent is None: break label = label.parent p_grid = [[str(c) for c in r] for r in grid] for i in range(len(grid)): for j in range(len(grid[0])): if per_cell[i, j]: p_grid[i][j] = f"\033[94m{grid[i][j]}\033[0m" prev_label = path[0] for label in path[1:]: for r in range( min(prev_label.row, label.row), max(prev_label.row, label.row) + 1 ): for c in range( min(prev_label.col, label.col), max(prev_label.col, label.col) + 1, ): if (r, c) != (prev_label.row, prev_label.col): p_grid[r][c] = f"\033[93m{grid[r][c]}\033[0m" p_grid[label.row][label.col] = f"\033[91m{grid[label.row][label.col]}\033[0m" prev_label = label p_grid[0][0] = f"\033[92m{grid[0][0]}\033[0m" print("\n".join("".join(row) for row in p_grid)) def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]: n_rows, n_cols = len(grid), len(grid[0]) visited: dict[tuple[int, int], tuple[Label, int]] = {} queue: list[tuple[int, Label]] = [ (0, Label(row=n_rows - 1, col=n_cols - 1, direction="^", count=0)) ] while queue and len(visited) != n_rows * n_cols: distance, label = heapq.heappop(queue) if (label.row, label.col) in visited: continue visited[label.row, label.col] = (label, distance) for direction, (c_row, c_col, i_direction) in MAPPINGS.items(): if label.direction == i_direction: continue else: row, col = (label.row + c_row, label.col + c_col) # exclude labels outside the grid or with too many moves in the same # direction if row not in range(0, n_rows) or col not in range(0, n_cols): continue heapq.heappush( queue, ( distance + sum( grid[r][c] for r in range(min(row, label.row), max(row, label.row) + 1) for c in range(min(col, label.col), max(col, label.col) + 1) ) - grid[row][col], Label( row=row, col=col, direction=direction, count=0, parent=label, ), ), ) return {(r, c): visited[r, c][1] for r in range(n_rows) for c in range(n_cols)} def shortest_path( grid: list[list[int]], min_straight: int, max_straight: int, lower_bounds: dict[tuple[int, int], int], ) -> int: n_rows, n_cols = len(grid), len(grid[0]) target = (len(grid) - 1, len(grid[0]) - 1) # for each tuple (row, col, direction, count), the associated label when visited visited: dict[tuple[int, int, str, int], Label] = {} # list of all visited labels for a cell (with associated distance) per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list) # need to add two start labels, otherwise one of the two possible direction will # not be possible queue: list[tuple[int, int, Label]] = [ (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^", count=0)), (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<", count=0)), ] while queue: _, distance, label = heapq.heappop(queue) if (label.row, label.col, label.direction, label.count) in visited: continue visited[label.row, label.col, label.direction, label.count] = label per_cell[label.row, label.col].append((label, distance)) if (label.row, label.col) == target: break for direction, (c_row, c_col, i_direction) in MAPPINGS.items(): # cannot move in the opposite direction if label.direction == i_direction: continue # other direction, move 'min_straight' in the new direction elif label.direction != direction: row, col, count = ( label.row + min_straight * c_row, label.col + min_straight * c_col, min_straight, ) # same direction, too many count elif label.count == max_straight: continue # same direction, keep going and increment count else: row, col, count = ( label.row + c_row, label.col + c_col, label.count + 1, ) # exclude labels outside the grid or with too many moves in the same # direction if row not in range(0, n_rows) or col not in range(0, n_cols): continue distance_to = ( distance + sum( grid[r][c] for r in range(min(row, label.row), max(row, label.row) + 1) for c in range(min(col, label.col), max(col, label.col) + 1) ) - grid[label.row][label.col] ) heapq.heappush( queue, ( distance_to + lower_bounds[row, col], distance_to, Label( row=row, col=col, direction=direction, count=count, parent=label, ), ), ) if VERBOSE: print_shortest_path(grid, target, per_cell) return per_cell[target][0][1] data = [[int(c) for c in r] for r in sys.stdin.read().splitlines()] estimates = shortest_many_paths(data) # part 1 answer_1 = shortest_path(data, 1, 3, lower_bounds=estimates) print(f"answer 1 is {answer_1}") # part 2 answer_2 = shortest_path(data, 4, 10, lower_bounds=estimates) print(f"answer 2 is {answer_2}")