import string
from collections import defaultdict
from functools import cache
from typing import Any, Iterator, Mapping, Sequence

from ..base import BaseSolver


@cache
def is_small(node: str) -> bool:
    """Return True when every character of *node* is a lowercase ASCII letter.

    Small nodes may normally be visited at most once per path. Results are
    memoized because the same few node names are tested on every step of the
    path enumeration.
    """
    return all(c in string.ascii_lowercase for c in node)


def enumerate_paths(
    neighbors: Mapping[str, Sequence[str]],
    duplicate_smalls: int = 0,
    start: str = "start",
    current: tuple[str, ...] = ("start",),
) -> Iterator[tuple[str, ...]]:
    """Yield every path from *start* to the node ``"end"``.

    Args:
        neighbors: Adjacency mapping from a node to the nodes reachable from
            it in one step. Nodes missing from the mapping are treated as
            having no outgoing edges.
        duplicate_smalls: How many extra visits to an already-visited small
            node (see ``is_small``) the path may still make. Non-small nodes
            may be revisited freely.
        start: Node the walk is currently at.
        current: Path walked so far, ending with *start*. Callers that change
            *start* should pass a matching *current*.

    Yields:
        Each complete path as a tuple of node names, from the first element
        of *current* up to and including ``"end"``.

    NOTE(review): two adjacent non-small nodes would make this recurse
    without bound, since non-small nodes are never marked as visited.
    """
    if start == "end":
        yield current
        # A path is complete the first time it reaches "end"; never walk past it.
        return
    # .get() avoids KeyError on plain mappings and avoids inserting keys into
    # a caller's defaultdict while enumerating.
    for neighbor in neighbors.get(start, ()):
        path = current + (neighbor,)
        if not is_small(neighbor) or neighbor not in current:
            yield from enumerate_paths(neighbors, duplicate_smalls, neighbor, path)
        elif duplicate_smalls > 0:
            # Spend one of the allowed repeat visits on this small node.
            yield from enumerate_paths(
                neighbors, duplicate_smalls - 1, neighbor, path
            )


class Solver(BaseSolver):
    """Count paths through an undirected graph given as ``a-b`` edge lines."""

    def solve(self, input: str) -> Iterator[Any]:
        """Parse the edge list and yield the two path counts.

        Each non-blank line of *input* is an edge ``a-b``. Edges leading into
        ``"start"`` or out of ``"end"`` are dropped, so paths never return to
        the start node and stop at the end node.

        Yields:
            First, the number of ``"start"``-to-``"end"`` paths that visit
            each small node at most once; then, the number of such paths
            allowing one extra visit to a single small node.
        """
        neighbors: dict[str, list[str]] = defaultdict(list)
        for row in input.splitlines():
            row = row.strip()
            if not row:
                # Tolerate blank lines instead of failing the unpack below.
                continue
            a, b = row.split("-")
            if a != "end" and b != "start":
                neighbors[a].append(b)
            if b != "end" and a != "start":
                neighbors[b].append(a)

        if self.files:
            self.files.create("graph.dot", self._to_dot(neighbors).encode(), False)

        # Count lazily; there is no need to keep every path in memory.
        yield sum(1 for _ in enumerate_paths(neighbors))
        yield sum(1 for _ in enumerate_paths(neighbors, 1))

    @staticmethod
    def _to_dot(neighbors: Mapping[str, Sequence[str]]) -> str:
        """Render *neighbors* as an undirected Graphviz DOT graph, one line per node."""
        lines = ["graph {\n"]
        for node, neighbors_of in neighbors.items():
            # Ordinary edges are stored in both directions, so keep only the
            # copy with node <= neighbor. Edges out of "start" and into "end"
            # are stored only once, so always keep those.
            edges = (
                f"{node} -- {neighbor};"
                for neighbor in neighbors_of
                if node <= neighbor or node == "start" or neighbor == "end"
            )
            lines.append(" ".join(edges) + "\n")
        lines.append("}\n")
        return "".join(lines)