diff --git a/2022/day12.py b/2022/day12.py index 814505b..28c54e4 100644 --- a/2022/day12.py +++ b/2022/day12.py @@ -9,36 +9,20 @@ Node = TypeVar("Node") def dijkstra( start: Node, - end: Node, neighbors: Callable[[Node], Iterator[Node]], cost: Callable[[Node, Node], float], - heuristic: Callable[[Node, Node], float] | None = None, -) -> list[Node] | None: +) -> tuple[dict[Node, float], dict[Node, Node]]: - queue: list[tuple[tuple[float, float, float], Node]] = [] + queue: list[tuple[float, Node]] = [] visited: set[Node] = set() lengths: dict[Node, float] = {start: 0} parents: dict[Node, Node] = {} - if heuristic is None: + heapq.heappush(queue, (0, start)) - def priority(node: Node): - c = lengths[node] - return (c, c, c) - - else: - - def priority(node: Node): - assert heuristic is not None - h = heuristic(node, end) - c = lengths[node] - return (h + c, h, c) - - heapq.heappush(queue, (priority(start), start)) - - while queue and (end not in visited): - (_, _, length), current = heapq.heappop(queue) + while queue: + length, current = heapq.heappop(queue) if current in visited: continue @@ -53,12 +37,17 @@ def dijkstra( neighbor_cost = length + cost(current, neighbor) if neighbor_cost < lengths.get(neighbor, float("inf")): - lengths[neighbor] = length + 1 + lengths[neighbor] = neighbor_cost parents[neighbor] = current - heapq.heappush(queue, (priority(neighbor), neighbor)) + heapq.heappush(queue, (neighbor_cost, neighbor)) - if end not in visited: + return lengths, parents + + +def make_path(parents: dict[Node, Node], start: Node, end: Node) -> list[Node] | None: + + if end not in parents: return None path: list[Node] = [end] @@ -125,7 +114,7 @@ def heuristic(lhs: tuple[int, int], rhs: tuple[int, int]) -> float: return abs(lhs[0] - rhs[0]) + abs(lhs[1] - rhs[1]) -def neighbors(node: tuple[int, int]) -> Iterator[tuple[int, int]]: +def neighbors(node: tuple[int, int], up: bool) -> Iterator[tuple[int, int]]: c_row, c_col = node for n_row, n_col in ( (c_row - 1, c_col), @@ -137,38 +126,27 @@ def neighbors(node: tuple[int, int]) -> Iterator[tuple[int, int]]: if not (n_row >= 0 and n_row < n_rows and n_col >= 0 and n_col < n_cols): continue - if grid[n_row][n_col] > grid[c_row][c_col] + 1: + if up and grid[n_row][n_col] > grid[c_row][c_col] + 1: + continue + elif not up and grid[n_row][n_col] < grid[c_row][c_col] - 1: continue yield n_row, n_col -path = dijkstra( - start=start, - end=end, - neighbors=neighbors, - cost=lambda lhs, rhs: 1, - heuristic=heuristic, +lengths_1, parents_1 = dijkstra( + start=start, neighbors=lambda n: neighbors(n, True), cost=lambda lhs, rhs: 1 ) -assert path is not None +path_1 = make_path(parents_1, start, end) +assert path_1 is not None -print_path(path, n_rows=len(grid), n_cols=len(grid[0])) +print_path(path_1, n_rows=len(grid), n_cols=len(grid[0])) -answer_1 = len(path) - 1 +answer_1 = lengths_1[end] - 1 print(f"answer 1 is {answer_1}") -answer_2 = min( - len(path) - 1 - for start in start_s - if ( - path := dijkstra( - start=start, - end=end, - neighbors=neighbors, - cost=lambda lhs, rhs: 1, - heuristic=heuristic, - ) - ) - is not None +lengths_2, parents_2 = dijkstra( + start=end, neighbors=lambda n: neighbors(n, False), cost=lambda lhs, rhs: 1 ) +answer_2 = min(lengths_2.get(start, float("inf")) for start in start_s) print(f"answer 2 is {answer_2}")