234 lines
7.0 KiB
Python
234 lines
7.0 KiB
Python
from __future__ import annotations
|
|
|
|
import heapq
|
|
import os
|
|
import sys
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import Literal, TypeAlias
|
|
|
|
VERBOSE = os.getenv("AOC_VERBOSE") == "True"
|
|
|
|
Direction: TypeAlias = Literal[">", "<", "^", "v"]
|
|
|
|
|
|
@dataclass(frozen=True, order=True)
class Label:
    """An immutable search label: a cell reached by a move in some direction.

    Labels form a chain through ``parent`` back to the start label, which is
    how a path is reconstructed. ``order=True`` makes labels comparable so they
    can be stored inside ``heapq`` tuples; the generated comparisons use every
    field in declaration order, ``parent`` included.
    """

    # grid position of this label
    row: int

    col: int

    # direction of the move that reached this cell (start labels use a
    # placeholder direction), and how many consecutive cells have been
    # travelled in it (shortest_many_paths always stores 0 here)
    direction: Direction

    count: int

    # label this one was expanded from; None for the start label(s)
    parent: Label | None = None
|
|
|
|
|
|
# Mapping from each direction to (row shift, col shift, opposite direction).
# Both searches iterate this dict, so its key order decides the order in
# which neighbouring labels are pushed onto the heap.
MAPPINGS: dict[Direction, tuple[int, int, Direction]] = {
    ">": (0, +1, "<"),
    "<": (0, -1, ">"),
    "v": (+1, 0, "^"),
    "^": (-1, 0, "v"),
}
|
|
|
|
|
|
def print_shortest_path(
    grid: list[list[int]],
    target: tuple[int, int],
    per_cell: dict[tuple[int, int], list[tuple[Label, int]]],
):
    """Print *grid* with the search result highlighted using ANSI colors.

    Colors: blue for every cell that has at least one recorded label in
    *per_cell*, yellow for cells crossed by the path, red for the end of each
    path segment, and green for the first cell of the path.

    Args:
        grid: the cost grid (one digit per cell).
        target: the cell whose single recorded label ends the path.
        per_cell: labels recorded per cell, as built by ``shortest_path``.
            Cells without an entry are treated as unvisited, so a plain
            ``dict`` works as well as a ``defaultdict``; no entries are added
            for unvisited cells.
    """
    # the search stops as soon as it settles the target, so it has one label
    assert len(per_cell[target]) == 1
    label = per_cell[target][0][0]

    # walk parent links back to the start, then reverse to get start -> target
    path: list[Label] = []
    while label is not None:
        path.append(label)
        label = label.parent
    path.reverse()

    def colored(code: int, r: int, c: int) -> str:
        return f"\033[{code}m{grid[r][c]}\033[0m"

    p_grid = [[str(c) for c in row] for row in grid]

    for i, row in enumerate(grid):
        for j in range(len(row)):
            # .get avoids KeyError on plain dicts and keeps defaultdicts intact
            if per_cell.get((i, j)):
                p_grid[i][j] = colored(94, i, j)

    # consecutive labels may be several cells apart (jumps after a turn), so
    # color every cell of the straight segment between them
    for prev, label in zip(path, path[1:]):
        for r in range(min(prev.row, label.row), max(prev.row, label.row) + 1):
            for c in range(min(prev.col, label.col), max(prev.col, label.col) + 1):
                if (r, c) != (prev.row, prev.col):
                    p_grid[r][c] = colored(93, r, c)

        p_grid[label.row][label.col] = colored(91, label.row, label.col)

    start = path[0]
    p_grid[start.row][start.col] = colored(92, start.row, start.col)

    print("\n".join("".join(row) for row in p_grid))
|
|
|
|
|
|
def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
    """Compute, for every cell, the cheapest cost to reach the bottom-right cell.

    This is a plain Dijkstra run backwards from the bottom-right corner, with
    no constraint on straight-line runs. Entering a cell costs its value, so a
    cell's result counts every cell on the way to the target (the target
    included) but not the cell itself; the target maps to 0.

    These values are suitable as lower bounds for ``shortest_path``'s A*
    search, whose movement rules only remove paths from this unconstrained
    setting.

    Args:
        grid: rectangular grid of non-negative cell costs.

    Returns:
        a mapping ``(row, col) -> cost`` covering every cell of *grid*, in
        row-major order (empty for an empty grid).
    """
    if not grid or not grid[0]:
        return {}

    n_rows, n_cols = len(grid), len(grid[0])
    target = (n_rows - 1, n_cols - 1)

    distances: dict[tuple[int, int], int] = {}
    queue: list[tuple[int, int, int]] = [(0, *target)]

    while queue and len(distances) != n_rows * n_cols:
        distance, row, col = heapq.heappop(queue)

        # first pop of a cell carries its minimal distance
        if (row, col) in distances:
            continue
        distances[row, col] = distance

        # stepping backwards to a neighbour means the forward path goes
        # neighbour -> (row, col), which costs grid[row][col]
        for d_row, d_col in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            n_row, n_col = row + d_row, col + d_col
            if (
                0 <= n_row < n_rows
                and 0 <= n_col < n_cols
                and (n_row, n_col) not in distances
            ):
                heapq.heappush(queue, (distance + grid[row][col], n_row, n_col))

    return {(r, c): distances[r, c] for r in range(n_rows) for c in range(n_cols)}
|
|
|
|
|
|
def shortest_path(
    grid: list[list[int]],
    min_straight: int,
    max_straight: int,
    lower_bounds: dict[tuple[int, int], int],
) -> int:
    """Return the minimal cost from the top-left to the bottom-right cell.

    A* search over states ``(row, col, direction, count)``, where ``count`` is
    the number of consecutive cells travelled in ``direction``. Moves may not
    reverse direction; a turn jumps *min_straight* cells at once in the new
    direction; a straight run stops growing once it reaches *max_straight*.

    Args:
        grid: rectangular grid of cell costs; entering a cell costs its value.
        min_straight: number of cells covered by the first move after a turn.
        max_straight: maximum number of consecutive cells in one direction.
        lower_bounds: per-cell estimate of the remaining cost to the target,
            used as the A* heuristic (e.g. from ``shortest_many_paths``).

    Returns:
        the cost of the cheapest valid path (the start cell is not counted).

    Raises:
        ValueError: if no path satisfies the movement constraints, e.g. when
            the grid is too small for *min_straight*.
    """
    n_rows, n_cols = len(grid), len(grid[0])

    target = (n_rows - 1, n_cols - 1)

    # for each state (row, col, direction, count), the label it was settled with
    visited: dict[tuple[int, int, str, int], Label] = {}

    # every settled label of a cell, with its distance (used for printing)
    per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list)

    # two start labels are needed: the "^" label can then only turn ">" and
    # the "<" label can only turn "v" (its opposite ">" is forbidden), so a
    # single start label would rule out one of the two valid first moves
    queue: list[tuple[int, int, Label]] = [
        (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^", count=0)),
        (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<", count=0)),
    ]

    while queue:
        _, distance, label = heapq.heappop(queue)

        state = (label.row, label.col, label.direction, label.count)
        if state in visited:
            continue

        visited[state] = label
        per_cell[label.row, label.col].append((label, distance))

        if (label.row, label.col) == target:
            break

        for direction, (c_row, c_col, i_direction) in MAPPINGS.items():
            # cannot move in the opposite direction
            if label.direction == i_direction:
                continue

            # other direction: turn and move 'min_straight' cells at once
            elif label.direction != direction:
                row, col, count = (
                    label.row + min_straight * c_row,
                    label.col + min_straight * c_col,
                    min_straight,
                )

            # same direction, straight-line limit already reached
            elif label.count == max_straight:
                continue

            # same direction, keep going and increment count
            else:
                row, col, count = (
                    label.row + c_row,
                    label.col + c_col,
                    label.count + 1,
                )

            # exclude labels that leave the grid
            if row not in range(0, n_rows) or col not in range(0, n_cols):
                continue

            # the move is a straight segment: add every cell of the rectangle
            # between the two labels, minus the cell being left
            distance_to = (
                distance
                + sum(
                    grid[r][c]
                    for r in range(min(row, label.row), max(row, label.row) + 1)
                    for c in range(min(col, label.col), max(col, label.col) + 1)
                )
                - grid[label.row][label.col]
            )

            heapq.heappush(
                queue,
                (
                    distance_to + lower_bounds[row, col],
                    distance_to,
                    Label(
                        row=row,
                        col=col,
                        direction=direction,
                        count=count,
                        parent=label,
                    ),
                ),
            )

    # the queue ran dry without settling the target: no valid path exists
    if not per_cell[target]:
        raise ValueError(
            f"no path from (0, 0) to {target} with "
            f"min_straight={min_straight} and max_straight={max_straight}"
        )

    if VERBOSE:
        print_shortest_path(grid, target, per_cell)

    return per_cell[target][0][1]
|
|
|
|
|
|
if __name__ == "__main__":
    # one row per non-blank input line, one digit per cell; stripping guards
    # against trailing blank lines or whitespace producing a ragged grid
    data = [
        [int(c) for c in line.strip()]
        for line in sys.stdin.read().splitlines()
        if line.strip()
    ]

    # unconstrained per-cell costs to the target, used as A* lower bounds
    estimates = shortest_many_paths(data)

    # part 1
    answer_1 = shortest_path(data, 1, 3, lower_bounds=estimates)
    print(f"answer 1 is {answer_1}")

    # part 2
    answer_2 = shortest_path(data, 4, 10, lower_bounds=estimates)
    print(f"answer 2 is {answer_2}")
|