import heapq from typing import ( Callable, Iterable, Iterator, Mapping, TypeVar, cast, overload, ) _Node = TypeVar("_Node") def make_neighbors_grid_fn( rows: int | Iterable[int], cols: int | Iterable[int], excluded: Iterable[tuple[int, int]] = set(), diagonals: bool = False, ): """ Create a neighbors function suitable for graph function for a simple grid. Args: rows: Rows of the grid. If an int is specified, the rows are assumed to be numbered from 0 to rows - 1, otherwise the iterable should contain the list of valid rows. cols: Columns of the grid. If an int is specified, the columns are assumed to be numbered from 0 to cols - 1, otherwise the iterable should contain the list of valid columns. excluded: Cells of the grid that cannot be used as valid nodes for the graph. diagonals: If True, neighbors will include diagonal cells, otherwise, only horizontal and vertical neighbors will be included. """ ds = ((-1, 0), (0, 1), (1, 0), (0, -1)) if diagonals: ds = ds + ((-1, -1), (-1, 1), (1, -1), (1, 1)) if isinstance(rows, int): rows = range(rows) elif not isinstance(rows, range): rows = set(rows) if isinstance(cols, int): cols = range(cols) elif not isinstance(cols, range): cols = set(cols) excluded = set(excluded) def _fn(node: tuple[int, int]): return ( ((row_n, col_n), 1) for dr, dc in ds if (row_n := node[0] + dr) in rows and (col_n := node[1] + dc) in cols and (row_n, col_n) not in excluded ) return _fn @overload def dijkstra( start: _Node, target: None, neighbors: Callable[[_Node], Iterable[tuple[_Node, float]]], ) -> dict[_Node, tuple[tuple[_Node, ...], float]]: ... @overload def dijkstra( start: _Node, target: _Node, neighbors: Callable[[_Node], Iterable[tuple[_Node, float]]], ) -> tuple[tuple[_Node, ...], float] | None: ... 
def dijkstra(
    start: _Node,
    target: _Node | None,
    neighbors: Callable[[_Node], Iterable[tuple[_Node, float]]],
) -> (
    dict[_Node, tuple[tuple[_Node, ...], float]]
    | tuple[tuple[_Node, ...], float]
    | None
):
    """
    Solve a shortest-path problem with Dijkstra's algorithm from start, using
    the given neighbors function.

    Args:
        start: Starting node of the path.
        target: Target node for the path, or None to compute shortest paths to
            every node reachable from start.
        neighbors: Function that should return, for a given node, its neighbors
            with the cost to go from the node to each neighbor. Costs are
            assumed to be non-negative, as required by Dijkstra's algorithm.

    Returns:
        If target is None, a dict mapping every node reachable from start to
        one of its shortest paths (a tuple of nodes from start to that node,
        both included) and the associated cost. Otherwise, one of the
        shortest paths from start to target with its associated cost, or None
        if target cannot be reached.

    Nodes only need to be hashable and comparable with `==`; they do not need
    to be orderable.
    """
    # Heap entries are (distance, tie_breaker, node, path). The tie-breaker is
    # unique and increasing, so entries with equal distances never fall back
    # to comparing nodes or paths (which may not support `<`).
    queue: list[tuple[float, int, _Node, tuple[_Node, ...]]] = [(0, 0, start, (start,))]
    push_count = 1
    preds: dict[_Node, tuple[tuple[_Node, ...], float]] = {}
    while queue:
        dis, _, node, path = heapq.heappop(queue)
        if node in preds:
            # Stale entry: this node was already settled with a cost <= dis.
            continue
        preds[node] = (path, dis)
        if node == target:
            break
        for neighbor, cost in neighbors(node):
            if neighbor in preds:
                continue
            heapq.heappush(
                queue, (dis + cost, push_count, neighbor, path + (neighbor,))
            )
            push_count += 1

    if target is None:
        return preds
    return preds.get(target)


def iter_max_cliques(
    neighbors: Mapping[_Node, Iterable[_Node]], nodes: Iterable[_Node] | None = None
) -> Iterator[list[_Node]]:
    """
    Iterate over the maximal cliques of a graph, optionally restricted to the
    cliques containing the given nodes.

    This is the networkx `find_cliques` implementation (iterative
    Bron-Kerbosch with pivoting) with typing, using a plain mapping to avoid
    requiring networkx.

    Args:
        neighbors: Adjacency mapping; each key is a node of the graph and its
            value the nodes adjacent to it. Self-loops are ignored.
            NOTE(review): presumably the mapping should be symmetric
            (undirected graph), as in networkx; verify for callers.
        nodes: Optional nodes that every yielded clique must contain. They
            must themselves form a clique of the graph.

    Yields:
        Each maximal clique as a new list of nodes. When `nodes` is given, its
        nodes come first, in the given order. Nothing is yielded if
        `neighbors` is empty, even when `nodes` is given.

    Raises:
        ValueError: If `nodes` do not form a clique, including when one of them
            is not a key of `neighbors`. Since this is a generator, the error
            is raised when iteration starts, not when the function is called.
    """
    if not neighbors:
        return

    # Adjacency sets with self-loops removed: a node is never its own neighbor.
    adj = {u: {v for v in neighbors[u] if v != u} for u in neighbors}

    # Q is the clique being built, seeded with the given nodes; cand is
    # narrowed to the nodes adjacent to every node of Q.
    Q: list[_Node | None] = list(nodes) if nodes is not None else []
    cand = set(neighbors)
    for node in Q:
        if node not in cand:
            raise ValueError(f"The given `nodes` {nodes} do not form a clique")
        cand &= adj[node]

    if not cand:
        # No node is adjacent to all of Q: the given nodes are already maximal.
        yield cast(list[_Node], Q[:])
        return

    # subg: nodes adjacent to every node of Q (used to test maximality).
    # cand: the subset of subg that may still be added to Q at this level.
    subg = cand.copy()
    # Saved (subg, cand, ext_u) of the enclosing levels, used to backtrack.
    stack: list[tuple[set[_Node], set[_Node], set[_Node]]] = []
    # Placeholder slot for the node tried at the current depth.
    Q.append(None)

    # Pivot: the node of subg with the most neighbors in cand. Only the
    # candidates that are not adjacent to the pivot need to be branched on.
    u = max(subg, key=lambda w: len(cand & adj[w]))
    ext_u = cand - adj[u]

    while True:
        if ext_u:
            q = ext_u.pop()
            cand.remove(q)
            Q[-1] = q
            adj_q = adj[q]
            subg_q = subg & adj_q
            if not subg_q:
                # Nothing is adjacent to every node of Q: Q is a maximal clique.
                yield cast(list[_Node], Q[:])
            else:
                cand_q = cand & adj_q
                if cand_q:
                    # Descend one level with q kept in the clique.
                    stack.append((subg, cand, ext_u))
                    Q.append(None)
                    subg = subg_q
                    cand = cand_q
                    u = max(subg, key=lambda w: len(cand & adj[w]))
                    ext_u = cand - adj[u]
                # Otherwise Q is extendable (subg_q is non-empty) but by no
                # remaining candidate, so this branch is pruned.
        elif stack:
            # Current level exhausted: backtrack to the enclosing level.
            Q.pop()
            subg, cand, ext_u = stack.pop()
        else:
            # Top level exhausted: every maximal clique has been yielded.
            return