From 726a6aecac93feba9444cac36781bb0460d80327 Mon Sep 17 00:00:00 2001 From: Mikael CAPELLE Date: Mon, 12 Dec 2022 15:59:19 +0100 Subject: [PATCH] Generic Dijkstra for day 12. --- 2022/day12.py | 127 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 83 insertions(+), 44 deletions(-) diff --git a/2022/day12.py b/2022/day12.py index f84587f..814505b 100644 --- a/2022/day12.py +++ b/2022/day12.py @@ -2,77 +2,75 @@ import heapq import sys -from typing import Iterator +from typing import Callable, Iterator, TypeVar + +Node = TypeVar("Node") def dijkstra( - start: tuple[int, int], end: tuple[int, int], grid: list[list[int]] -) -> list[tuple[int, int]] | None: - n_rows = len(grid) - n_cols = len(grid[0]) + start: Node, + end: Node, + neighbors: Callable[[Node], Iterator[Node]], + cost: Callable[[Node, Node], float], + heuristic: Callable[[Node, Node], float] | None = None, +) -> list[Node] | None: - def heuristic(row: int, col: int) -> int: - return abs(end[0] - row) + abs(end[1] - col) + queue: list[tuple[tuple[float, float, float], Node]] = [] - def neighbors(row: int, col: int) -> Iterator[tuple[int, int]]: - for n_row, n_col in ( - (c_row - 1, c_col), - (c_row + 1, c_col), - (c_row, c_col - 1), - (c_row, c_col + 1), - ): + visited: set[Node] = set() + lengths: dict[Node, float] = {start: 0} + parents: dict[Node, Node] = {} - if not (n_row >= 0 and n_row < n_rows and n_col >= 0 and n_col < n_cols): - continue + if heuristic is None: - if grid[n_row][n_col] > grid[c_row][c_col] + 1: - continue + def priority(node: Node): + c = lengths[node] + return (c, c, c) - yield n_row, n_col + else: - queue: list[tuple[tuple[int, int], tuple[int, int]]] = [] + def priority(node: Node): + assert heuristic is not None + h = heuristic(node, end) + c = lengths[node] + return (h + c, h, c) - visited: set[tuple[int, int]] = set() - lengths: dict[tuple[int, int], int] = {} - parents: dict[tuple[int, int], tuple[int, int]] = {} - - heapq.heappush(queue, ((heuristic(start[0], start[1]), 0), start)) + heapq.heappush(queue, (priority(start), start)) while queue and (end not in visited): - (_, length), (c_row, c_col) = heapq.heappop(queue) + (_, _, length), current = heapq.heappop(queue) - visited.add((c_row, c_col)) + if current in visited: + continue - for n_row, n_col in neighbors(c_row, c_col): + visited.add(current) - if (n_row, n_col) in visited: + for neighbor in neighbors(current): + + if neighbor in visited: continue - if length + 1 < lengths.get((n_row, n_col), n_rows * n_cols): - lengths[n_row, n_col] = length + 1 - parents[n_row, n_col] = (c_row, c_col) + neighbor_cost = length + cost(current, neighbor) - heapq.heappush( - queue, - ( - (heuristic(n_row, n_col) + length + 1, length + 1), - (n_row, n_col), - ), - ) + if neighbor_cost < lengths.get(neighbor, float("inf")): + lengths[neighbor] = length + 1 + parents[neighbor] = current + + heapq.heappush(queue, (priority(neighbor), neighbor)) if end not in visited: return None - path: list[tuple[int, int]] = [end] + path: list[Node] = [end] - while path[-1] != start: + while path[-1] is not start: path.append(parents[path[-1]]) return list(reversed(path)) def print_path(path: list[tuple[int, int]], n_rows: int, n_cols: int) -> None: - _, end = path[0], path[-1] + end = path[-1] graph = [["." for _c in range(n_cols)] for _r in range(n_rows)] graph[end[0]][end[1]] = "E" @@ -119,7 +117,39 @@ for i_row, row in enumerate(grid): grid[start[0]][start[1]] = 0 grid[end[0]][end[1]] = ord("z") - ord("a") -path = dijkstra(start, end, grid) +n_rows = len(grid) +n_cols = len(grid[0]) + + +def heuristic(lhs: tuple[int, int], rhs: tuple[int, int]) -> float: + return abs(lhs[0] - rhs[0]) + abs(lhs[1] - rhs[1]) + + +def neighbors(node: tuple[int, int]) -> Iterator[tuple[int, int]]: + c_row, c_col = node + for n_row, n_col in ( + (c_row - 1, c_col), + (c_row + 1, c_col), + (c_row, c_col - 1), + (c_row, c_col + 1), + ): + + if not (n_row >= 0 and n_row < n_rows and n_col >= 0 and n_col < n_cols): + continue + + if grid[n_row][n_col] > grid[c_row][c_col] + 1: + continue + + yield n_row, n_col + + +path = dijkstra( + start=start, + end=end, + neighbors=neighbors, + cost=lambda lhs, rhs: 1, + heuristic=heuristic, +) assert path is not None print_path(path, n_rows=len(grid), n_cols=len(grid[0])) @@ -130,6 +160,15 @@ print(f"answer 1 is {answer_1}") answer_2 = min( len(path) - 1 for start in start_s - if (path := dijkstra(start, end, grid)) is not None + if ( + path := dijkstra( + start=start, + end=end, + neighbors=neighbors, + cost=lambda lhs, rhs: 1, + heuristic=heuristic, + ) + ) + is not None ) print(f"answer 2 is {answer_2}")