One to many Dijkstra.
This commit is contained in:
		| @@ -9,36 +9,20 @@ Node = TypeVar("Node") | |||||||
|  |  | ||||||
| def dijkstra( | def dijkstra( | ||||||
|     start: Node, |     start: Node, | ||||||
|     end: Node, |  | ||||||
|     neighbors: Callable[[Node], Iterator[Node]], |     neighbors: Callable[[Node], Iterator[Node]], | ||||||
|     cost: Callable[[Node, Node], float], |     cost: Callable[[Node, Node], float], | ||||||
|     heuristic: Callable[[Node, Node], float] | None = None, | ) -> tuple[dict[Node, float], dict[Node, Node]]: | ||||||
| ) -> list[Node] | None: |  | ||||||
|  |  | ||||||
|     queue: list[tuple[tuple[float, float, float], Node]] = [] |     queue: list[tuple[float, Node]] = [] | ||||||
|  |  | ||||||
|     visited: set[Node] = set() |     visited: set[Node] = set() | ||||||
|     lengths: dict[Node, float] = {start: 0} |     lengths: dict[Node, float] = {start: 0} | ||||||
|     parents: dict[Node, Node] = {} |     parents: dict[Node, Node] = {} | ||||||
|  |  | ||||||
|     if heuristic is None: |     heapq.heappush(queue, (0, start)) | ||||||
|  |  | ||||||
|         def priority(node: Node): |     while queue: | ||||||
|             c = lengths[node] |         length, current = heapq.heappop(queue) | ||||||
|             return (c, c, c) |  | ||||||
|  |  | ||||||
|     else: |  | ||||||
|  |  | ||||||
|         def priority(node: Node): |  | ||||||
|             assert heuristic is not None |  | ||||||
|             h = heuristic(node, end) |  | ||||||
|             c = lengths[node] |  | ||||||
|             return (h + c, h, c) |  | ||||||
|  |  | ||||||
|     heapq.heappush(queue, (priority(start), start)) |  | ||||||
|  |  | ||||||
|     while queue and (end not in visited): |  | ||||||
|         (_, _, length), current = heapq.heappop(queue) |  | ||||||
|  |  | ||||||
|         if current in visited: |         if current in visited: | ||||||
|             continue |             continue | ||||||
| @@ -53,12 +37,17 @@ def dijkstra( | |||||||
|             neighbor_cost = length + cost(current, neighbor) |             neighbor_cost = length + cost(current, neighbor) | ||||||
|  |  | ||||||
|             if neighbor_cost < lengths.get(neighbor, float("inf")): |             if neighbor_cost < lengths.get(neighbor, float("inf")): | ||||||
|                 lengths[neighbor] = length + 1 |                 lengths[neighbor] = neighbor_cost | ||||||
|                 parents[neighbor] = current |                 parents[neighbor] = current | ||||||
|  |  | ||||||
|                 heapq.heappush(queue, (priority(neighbor), neighbor)) |                 heapq.heappush(queue, (neighbor_cost, neighbor)) | ||||||
|  |  | ||||||
|     if end not in visited: |     return lengths, parents | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def make_path(parents: dict[Node, Node], start: Node, end: Node) -> list[Node] | None: | ||||||
|  |  | ||||||
|  |     if end not in parents: | ||||||
|         return None |         return None | ||||||
|  |  | ||||||
|     path: list[Node] = [end] |     path: list[Node] = [end] | ||||||
| @@ -125,7 +114,7 @@ def heuristic(lhs: tuple[int, int], rhs: tuple[int, int]) -> float: | |||||||
|     return abs(lhs[0] - rhs[0]) + abs(lhs[1] - rhs[1]) |     return abs(lhs[0] - rhs[0]) + abs(lhs[1] - rhs[1]) | ||||||
|  |  | ||||||
|  |  | ||||||
| def neighbors(node: tuple[int, int]) -> Iterator[tuple[int, int]]: | def neighbors(node: tuple[int, int], up: bool) -> Iterator[tuple[int, int]]: | ||||||
|     c_row, c_col = node |     c_row, c_col = node | ||||||
|     for n_row, n_col in ( |     for n_row, n_col in ( | ||||||
|         (c_row - 1, c_col), |         (c_row - 1, c_col), | ||||||
| @@ -137,38 +126,27 @@ def neighbors(node: tuple[int, int]) -> Iterator[tuple[int, int]]: | |||||||
|         if not (n_row >= 0 and n_row < n_rows and n_col >= 0 and n_col < n_cols): |         if not (n_row >= 0 and n_row < n_rows and n_col >= 0 and n_col < n_cols): | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
|         if grid[n_row][n_col] > grid[c_row][c_col] + 1: |         if up and grid[n_row][n_col] > grid[c_row][c_col] + 1: | ||||||
|  |             continue | ||||||
|  |         elif not up and grid[n_row][n_col] < grid[c_row][c_col] - 1: | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
|         yield n_row, n_col |         yield n_row, n_col | ||||||
|  |  | ||||||
|  |  | ||||||
| path = dijkstra( | lengths_1, parents_1 = dijkstra( | ||||||
|     start=start, |     start=start, neighbors=lambda n: neighbors(n, True), cost=lambda lhs, rhs: 1 | ||||||
|     end=end, |  | ||||||
|     neighbors=neighbors, |  | ||||||
|     cost=lambda lhs, rhs: 1, |  | ||||||
|     heuristic=heuristic, |  | ||||||
| ) | ) | ||||||
| assert path is not None | path_1 = make_path(parents_1, start, end) | ||||||
|  | assert path_1 is not None | ||||||
|  |  | ||||||
| print_path(path, n_rows=len(grid), n_cols=len(grid[0])) | print_path(path_1, n_rows=len(grid), n_cols=len(grid[0])) | ||||||
|  |  | ||||||
| answer_1 = len(path) - 1 | answer_1 = lengths_1[end] - 1 | ||||||
| print(f"answer 1 is {answer_1}") | print(f"answer 1 is {answer_1}") | ||||||
|  |  | ||||||
| answer_2 = min( | lengths_2, parents_2 = dijkstra( | ||||||
|     len(path) - 1 |     start=end, neighbors=lambda n: neighbors(n, False), cost=lambda lhs, rhs: 1 | ||||||
|     for start in start_s |  | ||||||
|     if ( |  | ||||||
|         path := dijkstra( |  | ||||||
|             start=start, |  | ||||||
|             end=end, |  | ||||||
|             neighbors=neighbors, |  | ||||||
|             cost=lambda lhs, rhs: 1, |  | ||||||
|             heuristic=heuristic, |  | ||||||
|         ) |  | ||||||
|     ) |  | ||||||
|     is not None |  | ||||||
| ) | ) | ||||||
|  | answer_2 = min(lengths_2.get(start, float("inf")) for start in start_s) | ||||||
| print(f"answer 2 is {answer_2}") | print(f"answer 2 is {answer_2}") | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user