diff --git a/2023/day17.py b/2023/day17.py index 250af9b..8594265 100644 --- a/2023/day17.py +++ b/2023/day17.py @@ -18,7 +18,6 @@ class Label: col: int direction: Direction - count: int parent: Label | None = None @@ -81,7 +80,7 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]: visited: dict[tuple[int, int], tuple[Label, int]] = {} queue: list[tuple[int, Label]] = [ - (0, Label(row=n_rows - 1, col=n_cols - 1, direction="^", count=0)) + (0, Label(row=n_rows - 1, col=n_cols - 1, direction="^")) ] while queue and len(visited) != n_rows * n_cols: @@ -117,7 +116,6 @@ def shortest_many_paths(grid: list[list[int]]) -> dict[tuple[int, int], int]: row=row, col=col, direction=direction, - count=0, parent=label, ), ), @@ -137,7 +135,7 @@ def shortest_path( target = (len(grid) - 1, len(grid[0]) - 1) # for each tuple (row, col, direction, count), the associated label when visited - visited: dict[tuple[int, int, str, int], Label] = {} + visited: dict[tuple[int, int, Direction], Label] = {} # list of all visited labels for a cell (with associated distance) per_cell: dict[tuple[int, int], list[tuple[Label, int]]] = defaultdict(list) @@ -145,17 +143,17 @@ def shortest_path( # need to add two start labels, otherwise one of the two possible direction will # not be possible queue: list[tuple[int, int, Label]] = [ - (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^", count=0)), - (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<", count=0)), + (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="^")), + (lower_bounds[0, 0], 0, Label(row=0, col=0, direction="<")), ] while queue: _, distance, label = heapq.heappop(queue) - if (label.row, label.col, label.direction, label.count) in visited: + if (label.row, label.col, label.direction) in visited: continue - visited[label.row, label.col, label.direction, label.count] = label + visited[label.row, label.col, label.direction] = label per_cell[label.row, label.col].append((label, distance)) if (label.row, label.col) == target: @@ -163,57 +161,40 @@ def shortest_path( for direction, (c_row, c_col, i_direction) in MAPPINGS.items(): # cannot move in the opposite direction - if label.direction == i_direction: + if label.direction == i_direction or label.direction == direction: continue - # other direction, move 'min_straight' in the new direction - elif label.direction != direction: - row, col, count = ( - label.row + min_straight * c_row, - label.col + min_straight * c_col, - min_straight, + distance_to = distance + + for amount in range(1, max_straight + 1): + row, col = ( + label.row + amount * c_row, + label.col + amount * c_col, ) - # same direction, too many count - elif label.count == max_straight: - continue + # exclude labels outside the grid or with too many moves in the same + # direction + if not (0 <= row < n_rows and 0 <= col < n_cols): + break - # same direction, keep going and increment count - else: - row, col, count = ( - label.row + c_row, - label.col + c_col, - label.count + 1, - ) - # exclude labels outside the grid or with too many moves in the same - # direction - if row not in range(0, n_rows) or col not in range(0, n_cols): - continue + distance_to += grid[row][col] - distance_to = ( - distance - + sum( - grid[r][c] - for r in range(min(row, label.row), max(row, label.row) + 1) - for c in range(min(col, label.col), max(col, label.col) + 1) - ) - - grid[label.row][label.col] - ) + if amount < min_straight: + continue - heapq.heappush( - queue, - ( - distance_to + lower_bounds[row, col], - distance_to, - Label( - row=row, - col=col, - direction=direction, - count=count, - parent=label, + heapq.heappush( + queue, + ( + distance_to + lower_bounds[row, col], + distance_to, + Label( + row=row, + col=col, + direction=direction, + parent=label, + ), ), - ), - ) + ) if VERBOSE: print_shortest_path(grid, target, per_cell)