One to many Dijkstra.
This commit is contained in:
parent
4899583b15
commit
38b1d86514
@ -9,36 +9,20 @@ Node = TypeVar("Node")
|
|||||||
|
|
||||||
def dijkstra(
|
def dijkstra(
|
||||||
start: Node,
|
start: Node,
|
||||||
end: Node,
|
|
||||||
neighbors: Callable[[Node], Iterator[Node]],
|
neighbors: Callable[[Node], Iterator[Node]],
|
||||||
cost: Callable[[Node, Node], float],
|
cost: Callable[[Node, Node], float],
|
||||||
heuristic: Callable[[Node, Node], float] | None = None,
|
) -> tuple[dict[Node, float], dict[Node, Node]]:
|
||||||
) -> list[Node] | None:
|
|
||||||
|
|
||||||
queue: list[tuple[tuple[float, float, float], Node]] = []
|
queue: list[tuple[float, Node]] = []
|
||||||
|
|
||||||
visited: set[Node] = set()
|
visited: set[Node] = set()
|
||||||
lengths: dict[Node, float] = {start: 0}
|
lengths: dict[Node, float] = {start: 0}
|
||||||
parents: dict[Node, Node] = {}
|
parents: dict[Node, Node] = {}
|
||||||
|
|
||||||
if heuristic is None:
|
heapq.heappush(queue, (0, start))
|
||||||
|
|
||||||
def priority(node: Node):
|
while queue:
|
||||||
c = lengths[node]
|
length, current = heapq.heappop(queue)
|
||||||
return (c, c, c)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def priority(node: Node):
|
|
||||||
assert heuristic is not None
|
|
||||||
h = heuristic(node, end)
|
|
||||||
c = lengths[node]
|
|
||||||
return (h + c, h, c)
|
|
||||||
|
|
||||||
heapq.heappush(queue, (priority(start), start))
|
|
||||||
|
|
||||||
while queue and (end not in visited):
|
|
||||||
(_, _, length), current = heapq.heappop(queue)
|
|
||||||
|
|
||||||
if current in visited:
|
if current in visited:
|
||||||
continue
|
continue
|
||||||
@ -53,12 +37,17 @@ def dijkstra(
|
|||||||
neighbor_cost = length + cost(current, neighbor)
|
neighbor_cost = length + cost(current, neighbor)
|
||||||
|
|
||||||
if neighbor_cost < lengths.get(neighbor, float("inf")):
|
if neighbor_cost < lengths.get(neighbor, float("inf")):
|
||||||
lengths[neighbor] = length + 1
|
lengths[neighbor] = neighbor_cost
|
||||||
parents[neighbor] = current
|
parents[neighbor] = current
|
||||||
|
|
||||||
heapq.heappush(queue, (priority(neighbor), neighbor))
|
heapq.heappush(queue, (neighbor_cost, neighbor))
|
||||||
|
|
||||||
if end not in visited:
|
return lengths, parents
|
||||||
|
|
||||||
|
|
||||||
|
def make_path(parents: dict[Node, Node], start: Node, end: Node) -> list[Node] | None:
|
||||||
|
|
||||||
|
if end not in parents:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
path: list[Node] = [end]
|
path: list[Node] = [end]
|
||||||
@ -125,7 +114,7 @@ def heuristic(lhs: tuple[int, int], rhs: tuple[int, int]) -> float:
|
|||||||
return abs(lhs[0] - rhs[0]) + abs(lhs[1] - rhs[1])
|
return abs(lhs[0] - rhs[0]) + abs(lhs[1] - rhs[1])
|
||||||
|
|
||||||
|
|
||||||
def neighbors(node: tuple[int, int]) -> Iterator[tuple[int, int]]:
|
def neighbors(node: tuple[int, int], up: bool) -> Iterator[tuple[int, int]]:
|
||||||
c_row, c_col = node
|
c_row, c_col = node
|
||||||
for n_row, n_col in (
|
for n_row, n_col in (
|
||||||
(c_row - 1, c_col),
|
(c_row - 1, c_col),
|
||||||
@ -137,38 +126,27 @@ def neighbors(node: tuple[int, int]) -> Iterator[tuple[int, int]]:
|
|||||||
if not (n_row >= 0 and n_row < n_rows and n_col >= 0 and n_col < n_cols):
|
if not (n_row >= 0 and n_row < n_rows and n_col >= 0 and n_col < n_cols):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if grid[n_row][n_col] > grid[c_row][c_col] + 1:
|
if up and grid[n_row][n_col] > grid[c_row][c_col] + 1:
|
||||||
|
continue
|
||||||
|
elif not up and grid[n_row][n_col] < grid[c_row][c_col] - 1:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield n_row, n_col
|
yield n_row, n_col
|
||||||
|
|
||||||
|
|
||||||
path = dijkstra(
|
lengths_1, parents_1 = dijkstra(
|
||||||
start=start,
|
start=start, neighbors=lambda n: neighbors(n, True), cost=lambda lhs, rhs: 1
|
||||||
end=end,
|
|
||||||
neighbors=neighbors,
|
|
||||||
cost=lambda lhs, rhs: 1,
|
|
||||||
heuristic=heuristic,
|
|
||||||
)
|
)
|
||||||
assert path is not None
|
path_1 = make_path(parents_1, start, end)
|
||||||
|
assert path_1 is not None
|
||||||
|
|
||||||
print_path(path, n_rows=len(grid), n_cols=len(grid[0]))
|
print_path(path_1, n_rows=len(grid), n_cols=len(grid[0]))
|
||||||
|
|
||||||
answer_1 = len(path) - 1
|
answer_1 = lengths_1[end] - 1
|
||||||
print(f"answer 1 is {answer_1}")
|
print(f"answer 1 is {answer_1}")
|
||||||
|
|
||||||
answer_2 = min(
|
lengths_2, parents_2 = dijkstra(
|
||||||
len(path) - 1
|
start=end, neighbors=lambda n: neighbors(n, False), cost=lambda lhs, rhs: 1
|
||||||
for start in start_s
|
|
||||||
if (
|
|
||||||
path := dijkstra(
|
|
||||||
start=start,
|
|
||||||
end=end,
|
|
||||||
neighbors=neighbors,
|
|
||||||
cost=lambda lhs, rhs: 1,
|
|
||||||
heuristic=heuristic,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
is not None
|
|
||||||
)
|
)
|
||||||
|
answer_2 = min(lengths_2.get(start, float("inf")) for start in start_s)
|
||||||
print(f"answer 2 is {answer_2}")
|
print(f"answer 2 is {answer_2}")
|
||||||
|
Loading…
Reference in New Issue
Block a user