From 146d025d41c03e796fd02e4b12f2286eac896485 Mon Sep 17 00:00:00 2001 From: Mikael CAPELLE Date: Wed, 18 Dec 2024 09:01:31 +0100 Subject: [PATCH] Add generic simple dijkstra method. --- src/holt59/aoc/2024/day18.py | 89 ++++++++++++------------------- src/holt59/aoc/tools/graphs.py | 95 ++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 56 deletions(-) create mode 100644 src/holt59/aoc/tools/graphs.py diff --git a/src/holt59/aoc/2024/day18.py b/src/holt59/aoc/2024/day18.py index d9627cf..11ec1a7 100644 --- a/src/holt59/aoc/2024/day18.py +++ b/src/holt59/aoc/2024/day18.py @@ -1,81 +1,58 @@ -import heapq -from typing import Any, Iterator, TypeAlias, cast +from typing import Any, Iterator from ..base import BaseSolver - -Node: TypeAlias = tuple[int, int] - - -def dijkstra( - grid: list[Node], - n_rows: int, - n_cols: int, - start: Node = (0, 0), - target: Node | None = None, -) -> tuple[Node, ...] | None: - corrupted = set(grid) - target = target or (n_rows - 1, n_cols - 1) - - queue: list[tuple[int, Node, tuple[Node, ...]]] = [(0, start, (start,))] - preds: dict[Node, tuple[Node, ...]] = {} - - while queue: - dis, node, path = heapq.heappop(queue) - - if node in preds: - continue - - preds[node] = path - - if node == target: - break - - row, col = node - for dr, dc in ((-1, 0), (0, 1), (1, 0), (0, -1)): - row_n, col_n = row + dr, col + dc - - if ( - 0 <= row_n < n_rows - and 0 <= col_n < n_cols - and (row_n, col_n) not in corrupted - and (row_n, col_n) not in preds - ): - heapq.heappush( - queue, (dis + 1, (row_n, col_n), path + ((row_n, col_n),)) - ) - - return preds.get(target, None) +from ..tools import graphs class Solver(BaseSolver): def print_grid(self, grid: list[tuple[int, int]], n_rows: int, n_cols: int): values = set(grid) - for row in range(n_rows): - self.logger.info( - "".join("#" if (row, col) in values else "." for col in range(n_cols)) + if self.files: + self.files.create( + "graph.txt", + "\n".join( + "".join( + "#" if (row, col) in values else "." for col in range(n_cols) + ) + for row in range(n_rows) + ).encode(), + text=True, ) + else: + for row in range(n_rows): + self.logger.info( + "".join( + "#" if (row, col) in values else "." for col in range(n_cols) + ) + ) + + def dijkstra(self, corrupted: list[tuple[int, int]], n_rows: int, n_cols: int): + return graphs.dijkstra( + (0, 0), + (n_rows - 1, n_cols - 1), + graphs.make_neighbors_grid_fn(n_rows, n_cols, set(corrupted)), + ) def solve(self, input: str) -> Iterator[Any]: values = [ - cast(tuple[int, int], tuple(map(int, row.split(",")))) - for row in input.splitlines() + (int(p[0]), int(p[1])) for r in input.splitlines() if (p := r.split(",")) ] - n_rows, n_cols = (7, 7) if len(values) < 100 else (71, 71) + _is_test = len(values) < 100 + + n_rows, n_cols, n_bytes_p1 = (7, 7, 12) if _is_test else (71, 71, 1024) - n_bytes_p1 = 12 if len(values) < 100 else 1024 bytes_p1 = values[:n_bytes_p1] self.print_grid(bytes_p1, n_rows, n_cols) - path_p1 = dijkstra(bytes_p1, n_rows, n_cols) - assert path_p1 is not None - yield len(path_p1) - 1 + path_p1, cost_p1 = self.dijkstra(bytes_p1, n_rows, n_cols) or ((), -1) + yield cost_p1 path = path_p1 for b in range(n_bytes_p1, len(values)): if values[b] not in path: continue - path = dijkstra(values[: b + 1], n_rows, n_cols) + path, _ = self.dijkstra(values[: b + 1], n_rows, n_cols) or (None, -1) if path is None: yield ",".join(map(str, values[b])) break diff --git a/src/holt59/aoc/tools/graphs.py b/src/holt59/aoc/tools/graphs.py new file mode 100644 index 0000000..4f301ae --- /dev/null +++ b/src/holt59/aoc/tools/graphs.py @@ -0,0 +1,95 @@ +import heapq +from typing import Callable, Iterable, TypeVar + +_Node = TypeVar("_Node") + + +def make_neighbors_grid_fn( + rows: int | Iterable[int], + cols: int | Iterable[int], + excluded: Iterable[tuple[int, int]] = set(), + diagonals: bool = False, +): + """ + Create a neighbors function suitable for graph function for a simple grid. + + Args: + rows: Rows of the grid. If an int is specified, the rows are assumed to be + numbered from 0 to rows - 1, otherwise the iterable should contain the list + of valid rows. + cols: Columns of the grid. If an int is specified, the columns are assumed to be + numbered from 0 to cols - 1, otherwise the iterable should contain the list + of valid columns. + excluded: Cells of the grid that cannot be used as valid nodes for the graph. + diagonals: If True, neighbors will include diagonal cells, otherwise, only + horizontal and vertical neighbors will be included. + + """ + ds = ((-1, 0), (0, 1), (1, 0), (0, -1)) + if diagonals: + ds = ds + ((-1, -1), (-1, 1), (1, -1), (1, 1)) + + if isinstance(rows, int): + rows = range(rows) + elif not isinstance(rows, range): + rows = set(rows) + + if isinstance(cols, int): + cols = range(cols) + elif not isinstance(cols, range): + cols = set(cols) + + excluded = set(excluded) + + def _fn(node: tuple[int, int]): + return ( + ((row_n, col_n), 1) + for dr, dc in ds + if (row_n := node[0] + dr) in rows + and (col_n := node[1] + dc) in cols + and (row_n, col_n) not in excluded + ) + + return _fn + + +def dijkstra( + start: _Node, + target: _Node, + neighbors: Callable[[_Node], Iterable[tuple[_Node, float]]], +) -> tuple[tuple[_Node, ...], float] | None: + """ + Solve shortest-path problem using simple Dijkstra algorithm from start to target, + using the given neighbors function. + + Args: + start: Starting node of the path. + target: Target node for the path. + neighbors: Function that should return, for a given node, the list of + its neighbors with the cost to go from the node to the neighbor. + + Returns: + One of the shortest-path from start to target with its associated cost, if one + is found, otherwise None. + """ + queue: list[tuple[float, _Node, tuple[_Node, ...]]] = [(0, start, (start,))] + preds: dict[_Node, tuple[tuple[_Node, ...], float]] = {} + + while queue: + dis, node, path = heapq.heappop(queue) + + if node in preds: + continue + + preds[node] = (path, dis) + + if node == target: + break + + for neighbor, cost in neighbors(node): + if neighbor in preds: + continue + + heapq.heappush(queue, (dis + cost, neighbor, path + (neighbor,))) + + return preds.get(target, None)