"""Count the ways to fill unknown cells so that a row's '#' runs match its group sizes.

Each stdin line looks like ``<pattern> <c1,c2,...>``: *pattern* is made of '.',
'#' and '?', and the comma-separated integers are the lengths of the maximal
runs of '#' that must appear, in order, once every '?' is replaced by '.' or '#'.

Part 1 counts arrangements for the rows as given.  Part 2 first "unfolds" each
row: the pattern is repeated five times joined by '?', and the group list is
repeated five times.
"""

import functools
import itertools
import sys
from collections import defaultdict
from dataclasses import dataclass
from typing import Iterable, Iterator, Sequence

try:
    from tqdm import tqdm
except ImportError:  # the progress bar is cosmetic; run without it when tqdm is absent

    def tqdm(iterable, *_args, **_kwargs):
        return iterable


def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]:
    """Yield every way to fill one '.'-free chunk with a prefix of *counts*.

    *pattern* is expected to contain only '#' and '?' (rows are split on '.'
    before calling this).  Each result is ``(filled, rest)``: *filled* is
    *pattern* with every '?' resolved to '.' or '#' so that its '#' runs are
    exactly the leading groups of *counts*, and *rest* holds the groups left
    over for the following chunks.  Every filling is yielded exactly once.
    """
    # Place no (more) groups in this chunk: only possible if no cell is forced to '#'.
    if "#" not in pattern:
        yield "." * len(pattern), counts
    if not counts:
        return
    size = counts[0]
    for start in range(len(pattern) - size + 1):
        if "#" in pattern[:start]:
            # A '#' before the group would be left uncovered; later starts only skip more.
            break
        end = start + size
        if end == len(pattern):
            yield "." * start + "#" * size, counts[1:]
        elif pattern[end] != "#":
            # The cell right after the group becomes the '.' separator.
            head = "." * start + "#" * size + "."
            for tail, rest in fn1(pattern[end + 1 :], counts[1:]):
                yield head + tail, rest


def fn2(
    patterns: list[str], counts: list[int], depth: int = 0
) -> Iterator[tuple[str, ...]]:
    """Yield every complete arrangement of the consecutive chunks *patterns*.

    Each result is a tuple with one filled chunk per input chunk (see ``fn1``),
    and together the chunks consume *counts* exactly.  Arrangements are listed
    one by one, so the cost is exponential in the worst case; use
    ``count_arrangements`` when only the number is needed.  *depth* only tracks
    the recursion depth and does not affect the result.
    """
    if not patterns:
        # Leftover groups mean this branch did not place every run: not a match.
        if not counts:
            yield ()
        return
    for filled, rest in fn1(patterns[0], counts):
        for tail in fn2(patterns[1:], rest, depth + 1):
            yield (filled,) + tail


def count_arrangements(pattern: str, counts: Sequence[int]) -> int:
    """Return how many '?'-fillings of *pattern* have '#' runs equal to *counts*.

    *pattern* may contain '.', '#' and '?'.  Uses memoized recursion over
    (position, next group), i.e. O(len(pattern) * len(counts)) states; the
    recursion depth grows with len(pattern).
    """
    groups = tuple(counts)
    length = len(pattern)

    @functools.cache
    def ways(pos: int, group: int) -> int:
        # Number of fillings of pattern[pos:] that use exactly groups[group:].
        if group == len(groups):
            return 0 if "#" in pattern[pos:] else 1
        if pos >= length:
            return 0
        total = 0
        cell = pattern[pos]
        if cell != "#":  # '.' or '?': this cell is not damaged
            total += ways(pos + 1, group)
        if cell != ".":  # '#' or '?': the next group starts here
            end = pos + groups[group]
            if (
                end <= length
                and "." not in pattern[pos:end]
                and (end == length or pattern[end] != "#")
            ):
                # Skip the separator cell after the group as well.
                total += ways(end + 1, group + 1)
        return total

    return ways(0, 0)


def solve(lines: Iterable[str], repeat: int = 1) -> int:
    """Sum ``count_arrangements`` over all rows, each unfolded *repeat* times.

    Unfolding joins *repeat* copies of the pattern with '?' and repeats the
    group list *repeat* times.  Blank lines are skipped.
    """
    total = 0
    # disable=None lets tqdm hide the bar when stderr is not a terminal.
    for line in tqdm(lines, desc=f"repeat={repeat}", disable=None):
        if not line.strip():
            continue
        record, raw_counts = line.split()
        pattern = "?".join([record] * repeat)
        counts = [int(c) for c in raw_counts.split(",")] * repeat
        total += count_arrangements(pattern, counts)
    return total


if __name__ == "__main__":
    lines = sys.stdin.read().splitlines()

    # part 1
    answer_1 = solve(lines, repeat=1)
    print(f"answer 1 is {answer_1}")

    # part 2
    answer_2 = solve(lines, repeat=5)
    print(f"answer 2 is {answer_2}")