advent-of-code/2023/day17.py

215 lines
6.2 KiB
Python
Raw Normal View History

2023-12-17 08:37:31 +00:00
from __future__ import annotations
import heapq
import os
2023-12-04 18:32:41 +00:00
import sys
from collections import defaultdict
from dataclasses import dataclass
2023-12-17 08:37:31 +00:00
from typing import Literal, TypeAlias
VERBOSE = os.getenv("AOC_VERBOSE") == "True"
Direction: TypeAlias = Literal[">", "<", "^", "v"]
@dataclass(order=True, frozen=True)
class Label:
    """Immutable search node: a grid cell reached by a move in some direction.

    Instances compare field by field in declaration order (row, col, direction,
    parent); this ordering is what breaks ties between heap entries of equal
    priority. ``parent`` points to the label this one was expanded from, so
    following it back rebuilds the path that led here.
    """

    row: int  # grid row of the reached cell
    col: int  # grid column of the reached cell
    direction: Direction  # direction of the move that reached the cell
    parent: Label | None = None  # predecessor label; None for a start label
# For each direction: (row delta, column delta, opposite direction).
# Insertion order matters: it is the order in which neighbours are pushed.
MAPPINGS: dict[Direction, tuple[int, int, Direction]] = {
    direction: (d_row, d_col, opposite)
    for direction, d_row, d_col, opposite in (
        (">", 0, 1, "<"),
        ("<", 0, -1, ">"),
        ("v", 1, 0, "^"),
        ("^", -1, 0, "v"),
    )
}
def print_shortest_path(
    grid: list[list[int]],
    target: tuple[int, int],
    per_cell: dict[tuple[int, int], list[tuple[Label, int]]],
):
    """Print ``grid`` to stdout with the path ending at ``target`` highlighted.

    ANSI colors are used as follows:

    - blue: cells holding at least one label in ``per_cell``;
    - yellow: cells crossed by the path;
    - red: the end cell of each straight segment of the path;
    - green: the start cell of the path.

    Args:
        grid: the grid of digits to display.
        target: cell whose label ends the path; ``per_cell[target]`` must hold
            exactly one ``(label, distance)`` entry.
        per_cell: labels recorded per cell with their distance. It is only
            read, never modified: cells missing from it are not highlighted.
    """
    target_entries = per_cell.get(target, [])
    assert len(target_entries) == 1

    # Follow the parent links back to the start, then restore forward order
    # (append + reverse instead of repeated insert(0, ...), which is quadratic).
    path: list[Label] = []
    node: Label | None = target_entries[0][0]
    while node is not None:
        path.append(node)
        node = node.parent
    path.reverse()

    p_grid = [[str(c) for c in r] for r in grid]

    # .get() instead of [] so a defaultdict argument is not filled with empty
    # entries and a plain dict does not raise KeyError for unvisited cells.
    for i, grid_row in enumerate(grid):
        for j, value in enumerate(grid_row):
            if per_cell.get((i, j)):
                p_grid[i][j] = f"\033[94m{value}\033[0m"

    # consecutive labels are joined by a straight segment
    for prev, label in zip(path, path[1:]):
        for r in range(min(prev.row, label.row), max(prev.row, label.row) + 1):
            for c in range(min(prev.col, label.col), max(prev.col, label.col) + 1):
                if (r, c) != (prev.row, prev.col):
                    p_grid[r][c] = f"\033[93m{grid[r][c]}\033[0m"
        p_grid[label.row][label.col] = f"\033[91m{grid[label.row][label.col]}\033[0m"

    start = path[0]
    p_grid[start.row][start.col] = f"\033[92m{grid[start.row][start.col]}\033[0m"
    print("\n".join("".join(row) for row in p_grid))
def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
    """Compute, for every cell, the minimum heat loss to reach the bottom-right cell.

    This is a plain Dijkstra search run backwards from the bottom-right cell,
    ignoring any straight-line constraint. The cost of a path from a cell is the
    sum of every cell entered after it (bottom-right cell included, starting
    cell excluded). Since dropping constraints can only lower the cost, the
    result is a lower bound for the remaining cost of any constrained path and
    can serve as an A* heuristic.

    Args:
        grid: rectangular grid of non-negative heat-loss values.

    Returns:
        A mapping from every ``(row, col)`` of the grid to that minimum cost;
        empty for an empty grid.
    """
    if not grid or not grid[0]:
        return {}

    n_rows, n_cols = len(grid), len(grid[0])
    target = (n_rows - 1, n_cols - 1)

    distances: dict[tuple[int, int], int] = {}
    queue: list[tuple[int, int, int]] = [(0, *target)]
    while queue and len(distances) != n_rows * n_cols:
        distance, row, col = heapq.heappop(queue)
        if (row, col) in distances:
            continue
        distances[row, col] = distance

        # Stepping backwards from (row, col) to a neighbour means a forward
        # path moves from that neighbour into (row, col), paying its value.
        step = distance + grid[row][col]
        for d_row, d_col in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            n_row, n_col = row + d_row, col + d_col
            if (
                0 <= n_row < n_rows
                and 0 <= n_col < n_cols
                and (n_row, n_col) not in distances
            ):
                heapq.heappush(queue, (step, n_row, n_col))

    return {(r, c): distances[r, c] for r in range(n_rows) for c in range(n_cols)}
def shortest_path(
    grid: list[list[int]],
    min_straight: int,
    max_straight: int,
    lower_bounds: dict[tuple[int, int], int],
) -> int:
    """Return the minimum heat loss from the top-left to the bottom-right cell.

    A* search over (row, col, direction) states. Each expansion turns 90
    degrees and moves straight for ``min_straight`` to ``max_straight`` cells.
    Because continuing straight is forbidden, a whole straight run is one move,
    so the length constraint also applies to the final run into the target.

    Args:
        grid: rectangular grid of heat-loss values.
        min_straight: minimum number of cells moved before turning or stopping.
        max_straight: maximum number of cells moved in a straight line.
        lower_bounds: per-cell lower bound on the remaining heat loss to the
            bottom-right cell (e.g. from ``shortest_many_paths``). It must be a
            consistent estimate, otherwise the returned value may not be minimal.

    Returns:
        The heat loss of the best path: the sum of every cell entered, the
        starting cell excluded.

    Raises:
        ValueError: if the grid is empty or no path satisfies the constraints.
    """
    if not grid or not grid[0]:
        raise ValueError("the grid is empty")

    n_rows, n_cols = len(grid), len(grid[0])
    target = (n_rows - 1, n_cols - 1)

    # settled label for each (row, col, direction) state; each is expanded once
    visited: dict[tuple[int, int, Direction], Label] = {}

    # every settled label of a cell, with its distance (used for display)
    per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list)

    # A label can neither turn back nor keep its direction, so the start is
    # seeded twice with fake incoming directions: "^" lets the first move go
    # right and "<" lets it go down (the other options leave the grid).
    queue: list[tuple[int, int, Label]] = [
        (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^")),
        (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<")),
    ]

    while queue:
        # entries are (estimated total cost, cost so far, label)
        _, distance, label = heapq.heappop(queue)
        state = (label.row, label.col, label.direction)
        if state in visited:
            continue
        visited[state] = label
        per_cell[label.row, label.col].append((label, distance))

        # with a consistent heuristic, the first time the target is popped it
        # is reached at minimal cost
        if (label.row, label.col) == target:
            break

        for direction, (c_row, c_col, i_direction) in MAPPINGS.items():
            # cannot turn back, and straight runs are already covered by the
            # multi-cell move that produced this label
            if label.direction in (direction, i_direction):
                continue

            distance_to = distance
            for amount in range(1, max_straight + 1):
                row = label.row + amount * c_row
                col = label.col + amount * c_col
                # every further cell in this direction is outside as well
                if not (0 <= row < n_rows and 0 <= col < n_cols):
                    break
                distance_to += grid[row][col]
                # too short to turn or stop here: keep accumulating heat loss
                if amount < min_straight:
                    continue
                heapq.heappush(
                    queue,
                    (
                        distance_to + lower_bounds[row, col],
                        distance_to,
                        Label(row=row, col=col, direction=direction, parent=label),
                    ),
                )

    if target not in per_cell:
        raise ValueError(
            f"no path reaches {target} with straight runs of "
            f"{min_straight} to {max_straight} cells"
        )

    if VERBOSE:
        print_shortest_path(grid, target, per_cell)

    return per_cell[target][0][1]
# Read the puzzle grid from stdin: one row per line, one heat-loss digit per cell.
data = [list(map(int, line)) for line in sys.stdin.read().splitlines()]

# Unconstrained per-cell lower bounds, shared as the A* heuristic by both parts.
estimates = shortest_many_paths(data)

# part 1: straight runs of 1 to 3 cells
answer_1 = shortest_path(
    data, min_straight=1, max_straight=3, lower_bounds=estimates
)
print(f"answer 1 is {answer_1}")

# part 2: straight runs of 4 to 10 cells
answer_2 = shortest_path(
    data, min_straight=4, max_straight=10, lower_bounds=estimates
)
print(f"answer 2 is {answer_2}")