65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
import string
|
|
from collections import defaultdict
|
|
from functools import cache
|
|
from typing import Any, Iterator, Mapping, Sequence
|
|
|
|
from ..base import BaseSolver
|
|
|
|
|
|
@cache
def is_small(node: str):
    """Report whether *node* counts as a "small" node.

    A node is small when its name is made up only of ASCII lowercase
    letters (an empty name trivially qualifies). Answers are memoised by
    ``functools.cache``.
    """
    # Peeling every ASCII lowercase letter off both ends leaves an empty
    # string exactly when no other character occurs anywhere in the name.
    leftover = node.strip(string.ascii_lowercase)
    return not leftover
|
|
|
|
|
|
def enumerate_paths(
    neighbors: Mapping[str, Sequence[str]],
    duplicate_smalls: int = 0,
    start: str = "start",
    current: tuple[str, ...] = ("start",),
) -> Iterator[tuple[str, ...]]:
    """Yield every path that continues from ``start`` and ends at ``"end"``.

    Args:
        neighbors: Adjacency mapping; ``neighbors[n]`` lists the nodes that
            can be entered from ``n`` in one step. A node missing from the
            mapping is treated as having no outgoing edges.
        duplicate_smalls: How many times the walk may re-enter a small node
            (see ``is_small``) that is already on the path.
        start: The node the walk is currently at (the last element of
            ``current``).
        current: The path walked so far, including ``start``.

    Yields:
        Each complete path as a tuple of node names whose last element is
        ``"end"``.

    Note:
        This function does not itself forbid stepping back into the first
        node of ``current``; edges into it must be left out of ``neighbors``
        if that is not allowed.
    """
    if start == "end":
        # "end" terminates a path: report it and never walk beyond it.
        yield current
        return

    # .get() tolerates plain mappings without an entry for this node and
    # avoids inserting empty lists when a defaultdict is passed in.
    for neighbor in neighbors.get(start, ()):
        next_path = current + (neighbor,)
        if not is_small(neighbor) or neighbor not in current:
            # Big nodes may be revisited freely; small ones only once.
            yield from enumerate_paths(
                neighbors, duplicate_smalls, neighbor, next_path
            )
        elif duplicate_smalls > 0:
            # Spend one allowance to re-enter an already-visited small node.
            yield from enumerate_paths(
                neighbors, duplicate_smalls - 1, neighbor, next_path
            )
|
|
|
|
|
|
class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Count the start-to-end paths in the graph described by *input*.

        *input* holds one undirected edge per line, written ``a-b``; blank
        lines are ignored. Two counts are yielded, in order:

        1. paths found by ``enumerate_paths`` with no repeated small nodes;
        2. paths found when one repeated small-node visit is allowed.

        When ``self.files`` is set, the graph is also written to
        ``graph.dot`` in Graphviz format.
        """
        neighbors: dict[str, list[str]] = defaultdict(list)

        for row in input.splitlines():
            line = row.strip()
            if not line:
                # Skip blank/whitespace-only lines instead of failing the unpack.
                continue
            a, b = line.split("-")
            # Keep only edge directions that leave "end" alone and never lead
            # back into "start": "end" gets no outgoing edges and "start" no
            # incoming ones.
            if a != "end" and b != "start":
                neighbors[a].append(b)
            if b != "end" and a != "start":
                neighbors[b].append(a)

        if self.files:
            self.files.create("graph.dot", self._to_dot(neighbors).encode(), False)

        # Count lazily rather than materialising every path tuple in a list.
        yield sum(1 for _ in enumerate_paths(neighbors))
        yield sum(1 for _ in enumerate_paths(neighbors, 1))

    @staticmethod
    def _to_dot(neighbors: Mapping[str, Sequence[str]]) -> str:
        """Render the adjacency lists built by ``solve`` as a Graphviz graph.

        Produces one text line per node, listing that node's edges.
        """
        lines = ["graph {"]
        for node, neighbors_of in neighbors.items():
            lines.append(
                " ".join(
                    f"{node} -- {neighbor};"
                    for neighbor in neighbors_of
                    # Ordinary edges are stored in both directions, so keep
                    # one copy. Edges touching "start"/"end" are stored only
                    # from "start" or towards "end", so always keep them.
                    if node <= neighbor or node == "start" or neighbor == "end"
                )
            )
        lines.append("}")
        return "\n".join(lines) + "\n"
|