Compare commits

...

1 Commits

Author SHA1 Message Date
Mikael CAPELLE
40ab70271e 2023 day 17, v2. 2023-12-19 14:26:16 +01:00

View File

@ -18,7 +18,6 @@ class Label:
col: int
direction: Direction
count: int
parent: Label | None = None
@ -81,7 +80,7 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
visited: dict[tuple[int, int], tuple[Label, int]] = {}
queue: list[tuple[int, Label]] = [
(0, Label(row=n_rows - 1, col=n_cols - 1, direction="^", count=0))
(0, Label(row=n_rows - 1, col=n_cols - 1, direction="^"))
]
while queue and len(visited) != n_rows * n_cols:
@ -117,7 +116,6 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
row=row,
col=col,
direction=direction,
count=0,
parent=label,
),
),
@ -137,7 +135,7 @@ def shortest_path(
target = (len(grid) - 1, len(grid[0]) - 1)
# for each tuple (row, col, direction, count), the associated label when visited
visited: dict[tuple[int, int, str, int], Label] = {}
visited: dict[tuple[int, int, Direction], Label] = {}
# list of all visited labels for a cell (with associated distance)
per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list)
@ -145,17 +143,17 @@ def shortest_path(
# need to add two start labels, otherwise one of the two possible direction will
# not be possible
queue: list[tuple[int, int, Label]] = [
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^", count=0)),
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<", count=0)),
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^")),
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<")),
]
while queue:
_, distance, label = heapq.heappop(queue)
if (label.row, label.col, label.direction, label.count) in visited:
if (label.row, label.col, label.direction) in visited:
continue
visited[label.row, label.col, label.direction, label.count] = label
visited[label.row, label.col, label.direction] = label
per_cell[label.row, label.col].append((label, distance))
if (label.row, label.col) == target:
@ -163,42 +161,26 @@ def shortest_path(
for direction, (c_row, c_col, i_direction) in MAPPINGS.items():
# cannot move in the opposite direction
if label.direction == i_direction:
if label.direction == i_direction or label.direction == direction:
continue
# other direction, move 'min_straight' in the new direction
elif label.direction != direction:
row, col, count = (
label.row + min_straight * c_row,
label.col + min_straight * c_col,
min_straight,
distance_to = distance
for amount in range(1, max_straight + 1):
row, col = (
label.row + amount * c_row,
label.col + amount * c_col,
)
# same direction, too many count
elif label.count == max_straight:
continue
# same direction, keep going and increment count
else:
row, col, count = (
label.row + c_row,
label.col + c_col,
label.count + 1,
)
# exclude labels outside the grid or with too many moves in the same
# direction
if row not in range(0, n_rows) or col not in range(0, n_cols):
continue
if not (0 <= row < n_rows and 0 <= col < n_cols):
break
distance_to = (
distance
+ sum(
grid[r][c]
for r in range(min(row, label.row), max(row, label.row) + 1)
for c in range(min(col, label.col), max(col, label.col) + 1)
)
- grid[label.row][label.col]
)
distance_to += grid[row][col]
if amount < min_straight:
continue
heapq.heappush(
queue,
@ -209,7 +191,6 @@ def shortest_path(
row=row,
col=col,
direction=direction,
count=count,
parent=label,
),
),