Compare commits

...

1 Commits

Author SHA1 Message Date
Mikael CAPELLE
40ab70271e 2023 day 17, v2. 2023-12-19 14:26:16 +01:00

View File

@ -18,7 +18,6 @@ class Label:
col: int col: int
direction: Direction direction: Direction
count: int
parent: Label | None = None parent: Label | None = None
@ -81,7 +80,7 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
visited: dict[tuple[int, int], tuple[Label, int]] = {} visited: dict[tuple[int, int], tuple[Label, int]] = {}
queue: list[tuple[int, Label]] = [ queue: list[tuple[int, Label]] = [
(0, Label(row=n_rows - 1, col=n_cols - 1, direction="^", count=0)) (0, Label(row=n_rows - 1, col=n_cols - 1, direction="^"))
] ]
while queue and len(visited) != n_rows * n_cols: while queue and len(visited) != n_rows * n_cols:
@ -117,7 +116,6 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]:
row=row, row=row,
col=col, col=col,
direction=direction, direction=direction,
count=0,
parent=label, parent=label,
), ),
), ),
@ -137,7 +135,7 @@ def shortest_path(
target = (len(grid) - 1, len(grid[0]) - 1) target = (len(grid) - 1, len(grid[0]) - 1)
# for each tuple (row, col, direction, count), the associated label when visited # for each tuple (row, col, direction, count), the associated label when visited
visited: dict[tuple[int, int, str, int], Label] = {} visited: dict[tuple[int, int, Direction], Label] = {}
# list of all visited labels for a cell (with associated distance) # list of all visited labels for a cell (with associated distance)
per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list) per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list)
@ -145,17 +143,17 @@ def shortest_path(
# need to add two start labels, otherwise one of the two possible direction will # need to add two start labels, otherwise one of the two possible direction will
# not be possible # not be possible
queue: list[tuple[int, int, Label]] = [ queue: list[tuple[int, int, Label]] = [
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^", count=0)), (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^")),
(lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<", count=0)), (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<")),
] ]
while queue: while queue:
_, distance, label = heapq.heappop(queue) _, distance, label = heapq.heappop(queue)
if (label.row, label.col, label.direction, label.count) in visited: if (label.row, label.col, label.direction) in visited:
continue continue
visited[label.row, label.col, label.direction, label.count] = label visited[label.row, label.col, label.direction] = label
per_cell[label.row, label.col].append((label, distance)) per_cell[label.row, label.col].append((label, distance))
if (label.row, label.col) == target: if (label.row, label.col) == target:
@ -163,42 +161,26 @@ def shortest_path(
for direction, (c_row, c_col, i_direction) in MAPPINGS.items(): for direction, (c_row, c_col, i_direction) in MAPPINGS.items():
# cannot move in the opposite direction # cannot move in the opposite direction
if label.direction == i_direction: if label.direction == i_direction or label.direction == direction:
continue continue
# other direction, move 'min_straight' in the new direction distance_to = distance
elif label.direction != direction:
row, col, count = ( for amount in range(1, max_straight + 1):
label.row + min_straight * c_row, row, col = (
label.col + min_straight * c_col, label.row + amount * c_row,
min_straight, label.col + amount * c_col,
) )
# same direction, too many count
elif label.count == max_straight:
continue
# same direction, keep going and increment count
else:
row, col, count = (
label.row + c_row,
label.col + c_col,
label.count + 1,
)
# exclude labels outside the grid or with too many moves in the same # exclude labels outside the grid or with too many moves in the same
# direction # direction
if row not in range(0, n_rows) or col not in range(0, n_cols): if not (0 <= row < n_rows and 0 <= col < n_cols):
continue break
distance_to = ( distance_to += grid[row][col]
distance
+ sum( if amount < min_straight:
grid[r][c] continue
for r in range(min(row, label.row), max(row, label.row) + 1)
for c in range(min(col, label.col), max(col, label.col) + 1)
)
- grid[label.row][label.col]
)
heapq.heappush( heapq.heappush(
queue, queue,
@ -209,7 +191,6 @@ def shortest_path(
row=row, row=row,
col=col, col=col,
direction=direction, direction=direction,
count=count,
parent=label, parent=label,
), ),
), ),