This commit is contained in:
@@ -1,7 +1,50 @@
|
||||
from typing import Any, Iterator
|
||||
import string
|
||||
from collections import defaultdict
|
||||
from functools import cache
|
||||
from typing import Any, Iterator, Mapping, Sequence
|
||||
|
||||
from ..base import BaseSolver
|
||||
|
||||
|
||||
@cache
|
||||
def is_small(node: str):
|
||||
return all(c in string.ascii_lowercase for c in node)
|
||||
|
||||
|
||||
def enumerate_paths(
|
||||
neighbors: Mapping[str, Sequence[str]],
|
||||
duplicate_smalls: int = 0,
|
||||
start: str = "start",
|
||||
current: tuple[str, ...] = ("start",),
|
||||
) -> Iterator[tuple[str, ...]]:
|
||||
if start == "end":
|
||||
yield current
|
||||
|
||||
for neighbor in neighbors[start]:
|
||||
if not is_small(neighbor):
|
||||
yield from enumerate_paths(
|
||||
neighbors, duplicate_smalls, neighbor, current + (neighbor,)
|
||||
)
|
||||
elif neighbor not in current:
|
||||
yield from enumerate_paths(
|
||||
neighbors, duplicate_smalls, neighbor, current + (neighbor,)
|
||||
)
|
||||
elif duplicate_smalls > 0:
|
||||
yield from enumerate_paths(
|
||||
neighbors, duplicate_smalls - 1, neighbor, current + (neighbor,)
|
||||
)
|
||||
|
||||
|
||||
class Solver(BaseSolver):
|
||||
def solve(self, input: str) -> Iterator[Any]: ...
|
||||
def solve(self, input: str) -> Iterator[Any]:
|
||||
neighbors: dict[str, list[str]] = defaultdict(list)
|
||||
|
||||
for row in input.splitlines():
|
||||
a, b = row.split("-")
|
||||
if a != "end" and b != "start":
|
||||
neighbors[a].append(b)
|
||||
if b != "end" and a != "start":
|
||||
neighbors[b].append(a)
|
||||
|
||||
yield len(list(enumerate_paths(neighbors)))
|
||||
yield len(list(enumerate_paths(neighbors, 1)))
|
||||
|
Reference in New Issue
Block a user