2023 day 17, v2.

This commit is contained in:
Mikael CAPELLE 2023-12-19 14:26:16 +01:00
parent f15908876d
commit 40ab70271e

View File

@ -18,7 +18,6 @@ class Label:
col: int col: int
direction: Direction direction: Direction
count: int
parent: Label | None = None parent: Label | None = None
@ -81,7 +80,7 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
visited: dict[tuple[int, int], tuple[Label, int]] = {} visited: dict[tuple[int, int], tuple[Label, int]] = {}
queue: list[tuple[int, Label]] = [ queue: list[tuple[int, Label]] = [
(0, Label(row=n_rows - 1, col=n_cols - 1, direction="^", count=0)) (0, Label(row=n_rows - 1, col=n_cols - 1, direction="^"))
] ]
while queue and len(visited) != n_rows * n_cols: while queue and len(visited) != n_rows * n_cols:
@ -117,7 +116,6 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
row=row, row=row,
col=col, col=col,
direction=direction, direction=direction,
count=0,
parent=label, parent=label,
), ),
), ),
@ -137,7 +135,7 @@ def shortest_path(
target = (len(grid) - 1, len(grid[0]) - 1) target = (len(grid) - 1, len(grid[0]) - 1)
# for each tuple (row, col, direction, count), the associated label when visited # for each tuple (row, col, direction, count), the associated label when visited
visited: dict[tuple[int, int, str, int], Label] = {} visited: dict[tuple[int, int, Direction], Label] = {}
# list of all visited labels for a cell (with associated distance) # list of all visited labels for a cell (with associated distance)
per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list) per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list)
@ -145,17 +143,17 @@ def shortest_path(
# need to add two start labels, otherwise one of the two possible direction will # need to add two start labels, otherwise one of the two possible direction will
# not be possible # not be possible
queue: list[tuple[int, int, Label]] = [ queue: list[tuple[int, int, Label]] = [
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^", count=0)), (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^")),
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<", count=0)), (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<")),
] ]
while queue: while queue:
_, distance, label = heapq.heappop(queue) _, distance, label = heapq.heappop(queue)
if (label.row, label.col, label.direction, label.count) in visited: if (label.row, label.col, label.direction) in visited:
continue continue
visited[label.row, label.col, label.direction, label.count] = label visited[label.row, label.col, label.direction] = label
per_cell[label.row, label.col].append((label, distance)) per_cell[label.row, label.col].append((label, distance))
if (label.row, label.col) == target: if (label.row, label.col) == target:
@ -163,57 +161,40 @@ def shortest_path(
for direction, (c_row, c_col, i_direction) in MAPPINGS.items(): for direction, (c_row, c_col, i_direction) in MAPPINGS.items():
# cannot move in the opposite direction # cannot move in the opposite direction
if label.direction == i_direction: if label.direction == i_direction or label.direction == direction:
continue continue
# other direction, move 'min_straight' in the new direction distance_to = distance
elif label.direction != direction:
row, col, count = ( for amount in range(1, max_straight + 1):
label.row + min_straight * c_row, row, col = (
label.col + min_straight * c_col, label.row + amount * c_row,
min_straight, label.col + amount * c_col,
) )
# same direction, too many count # exclude labels outside the grid or with too many moves in the same
elif label.count == max_straight: # direction
continue if not (0 <= row < n_rows and 0 <= col < n_cols):
break
# same direction, keep going and increment count distance_to += grid[row][col]
else:
row, col, count = (
label.row + c_row,
label.col + c_col,
label.count + 1,
)
# exclude labels outside the grid or with too many moves in the same
# direction
if row not in range(0, n_rows) or col not in range(0, n_cols):
continue
distance_to = ( if amount < min_straight:
distance continue
+ sum(
grid[r][c]
for r in range(min(row, label.row), max(row, label.row) + 1)
for c in range(min(col, label.col), max(col, label.col) + 1)
)
- grid[label.row][label.col]
)
heapq.heappush( heapq.heappush(
queue, queue,
( (
distance_to + lower_bounds[row, col], distance_to + lower_bounds[row, col],
distance_to, distance_to,
Label( Label(
row=row, row=row,
col=col, col=col,
direction=direction, direction=direction,
count=count, parent=label,
parent=label, ),
), ),
), )
)
if VERBOSE: if VERBOSE:
print_shortest_path(grid, target, per_cell) print_shortest_path(grid, target, per_cell)