"""Count the ways contiguous groups of given sizes fit into '#'/'?' patterns.

Each input row (read from stdin) has the form ``"<pattern> <c1>,<c2>,..."``.
"""
import itertools
import sys
from functools import lru_cache
from typing import Iterable, Iterator, Optional

# NOTE: a bare ``@lru_cache`` silently uses maxsize=128, which evicts most of
# the memoised sub-results for the unfolded (repeat=5) rows; the recursive
# helpers below therefore use an unbounded cache.


@lru_cache(maxsize=None)
def fn3p(pattern: str, counts: tuple[int, ...]) -> int:
    """Count the ways to place *all* group sizes of ``counts`` inside ``pattern``.

    ``pattern`` is expected to be a single '.'-free block (as produced by
    ``v3``), so every position may hold part of a group: a '#' position must
    be covered by a group, a '?' position may or may not be.

    Args:
        pattern: block of '#' and '?' characters.
        counts: group sizes to place, in order, separated by at least one cell.

    Returns:
        The number of distinct placements.
    """
    if not pattern:
        # Nothing left to place into: succeed only if no groups remain.
        return 0 if counts else 1
    if not counts:
        # No groups left: any remaining mandatory '#' would be uncovered.
        return 0 if "#" in pattern else 1
    # Minimum room needed: every group plus one separator between each pair.
    if len(pattern) < sum(counts) + len(counts) - 1:
        return 0

    first = counts[0]
    ways = 0
    if pattern[0] == "?":
        # Leave the leading '?' empty and try one position further on.
        ways += fn3p(pattern[1:], counts)
    # Place the first group at the very start of the pattern.
    if len(pattern) == first:
        # The group fills the block exactly; the length check above
        # guarantees it was the only group left.
        ways += 1
    elif pattern[first] != "#":
        # The cell right after the group acts as the separator, so it must
        # not be a mandatory '#'.
        ways += fn3p(pattern[first + 1 :], counts[1:])
    return ways


@lru_cache(maxsize=None)
def fn3(pattern: str, counts: tuple[int, ...]) -> list[tuple[int, tuple[int, ...]]]:
    """List the ways to consume a *prefix* of ``counts`` within one block.

    Returns ``(ways, remaining)`` pairs. ``ways`` is the number of placements
    of ``counts[:i]`` in ``pattern``, and ``remaining`` is the unconsumed
    suffix ``counts[i:]``. Pairs with ``ways == 0`` may appear.

    The returned list is cached and shared between calls; callers must not
    mutate it.
    """
    if not pattern:
        return [(1, counts)]
    if not counts:
        return [] if "#" in pattern else [(1, ())]
    if "#" in pattern and len(pattern) < counts[0]:
        # Some '#' must be covered, and the first group to cover it is too long.
        return []
    if "?" not in pattern and len(pattern) == counts[0]:
        # Shortcut: a block with no '?' matches exactly one group of its length.
        return [(1, counts[1:])]
    return [(fn3p(pattern, counts[:i]), counts[i:]) for i in range(len(counts) + 1)]


@lru_cache(maxsize=None)
def fn4(patterns: tuple[str, ...], counts: tuple[int, ...], depth: int = 0) -> int:
    """Count the arrangements of ``counts`` across the ordered blocks ``patterns``.

    Groups are handed to blocks in order; each block takes a (possibly empty)
    consecutive run of the remaining ``counts``.

    Args:
        patterns: the '.'-free blocks of a row, left to right.
        counts: group sizes still to place.
        depth: recursion depth. It does not affect the result, although it is
            part of the cache key.

    Returns:
        The number of valid arrangements.
    """
    if not patterns:
        return 0 if counts else 1

    head, rest = patterns[0], patterns[1:]
    # Every later block containing a '#' needs at least one group, so the
    # first block may only use the groups not reserved for them.
    with_hash = sum(1 for p in rest if "#" in p)
    if with_hash > len(counts):
        return 0
    split = len(counts) - with_hash
    to_fit, reserved = counts[:split], counts[split:]

    total = 0
    for ways, leftover in fn3(head, to_fit):
        if ways:
            total += ways * fn4(rest, leftover + reserved, depth + 1)
    return total


def v3(pattern: str, counts: list[int]) -> int:
    """Count the valid arrangements of ``counts`` in a full row ``pattern``.

    The row is split on '.' into non-empty blocks, which are then solved
    by ``fn4``.
    """
    blocks = tuple(block for block in pattern.split(".") if block)
    return fn4(blocks, tuple(counts))


@lru_cache(maxsize=None)
def _input_lines() -> tuple[str, ...]:
    """Read standard input once and return its lines (cached for reuse)."""
    return tuple(sys.stdin.read().splitlines())


def compute_possible_arrangements(
    repeat: int, lines: Optional[Iterable[str]] = None
) -> int:
    """Sum the arrangement counts of every row after unfolding it ``repeat`` times.

    Unfolding joins ``repeat`` copies of the pattern with '?' and repeats the
    group list ``repeat`` times. Blank lines are skipped.

    Args:
        repeat: number of copies of each row (the script uses 1 and 5).
        lines: rows of the form ``"<pattern> <c1>,<c2>,..."``. Defaults to the
            lines of standard input, which are read once and reused.

    Returns:
        The total number of arrangements over all rows.
    """
    if lines is None:
        lines = _input_lines()
    total = 0
    for line in lines:
        parts = line.split()
        if not parts:
            continue
        pattern = "?".join([parts[0]] * repeat)
        counts = [int(c) for c in parts[1].split(",")] * repeat
        total += v3(pattern, counts)
    return total


def main() -> None:
    """Print the answers for both parts using the rows read from stdin."""
    lines = _input_lines()

    # part 1
    answer_1 = compute_possible_arrangements(1, lines)
    print(f"answer 1 is {answer_1}")

    # part 2
    answer_2 = compute_possible_arrangements(5, lines)
    print(f"answer 2 is {answer_2}")


if __name__ == "__main__":
    main()