"""Count the '.'/'#' fillings of '?' cells that match run-length records.

Each input line (read from standard input) looks like ``<pattern> <c1,c2,...>``
where *pattern* uses '.', '#' and '?'.  An arrangement replaces every '?' with
'.' or '#' so that the maximal runs of '#' have exactly the lengths
``c1, c2, ...``, in order.  Part 1 counts arrangements as given; part 2 first
unfolds each line five times (copies of the pattern joined with '?', counts
repeated five times).

Three counters are provided: ``v1`` (brute force), ``v2`` (block-wise
enumeration) and ``v3`` (block-wise, memoized).  Only ``v3`` is used to compute
the answers; the others serve as reference implementations.
"""
import itertools
import sys
from collections import defaultdict  # noqa: F401  -- currently unused
from dataclasses import dataclass  # noqa: F401  -- currently unused
from functools import lru_cache
from multiprocessing.pool import ThreadPool  # noqa: F401  -- currently unused
from typing import Iterator, Optional, Sequence

try:
    from tqdm import tqdm
except ImportError:  # the progress bar is cosmetic; run without it

    def tqdm(iterable, *args, **kwargs):  # type: ignore[no-redef]
        """Stand-in for ``tqdm.tqdm`` that returns *iterable* unchanged."""
        return iterable


def v1(pattern: str, counts: list[int]) -> int:
    """Brute-force reference: try every '.'/'#' assignment of the '?' cells.

    Returns the number of assignments whose maximal '#' runs have exactly the
    lengths in *counts*, in order.  Exponential in the number of '?' cells, so
    only usable on short patterns.
    """
    missing = [i for i, ch in enumerate(pattern) if ch == "?"]
    expected = list(counts)
    count = 0
    for replacements in itertools.product(".#", repeat=len(missing)):
        cells = list(pattern)
        for i_missing, replacement in zip(missing, replacements):
            cells[i_missing] = replacement
        runs = [len(part) for part in "".join(cells).split(".") if part]
        if runs == expected:
            count += 1
    return count


def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]:
    """Enumerate fillings of one block with a prefix of *counts*.

    *pattern* is expected to be a block containing only '#' and '?' (as
    produced by splitting a pattern on '.').  Yields ``(filled, rest)`` where
    *filled* is *pattern* with each '?' replaced by '.' or '#' and *rest* is
    the suffix of *counts* that the filling did not consume.
    """
    if not pattern:
        yield "", counts
        return
    if not counts:
        if "#" not in pattern:
            yield "." * len(pattern), []
        return
    if sum(counts) + len(counts) - 1 == len(pattern):
        # All runs fit exactly (one '.' between each): that is one option;
        # the others leave at least the last run for later blocks.
        c_pattern = ".".join("#" * c for c in counts)
        if all(pattern[j] == "?" for j in range(len(pattern)) if c_pattern[j] == "."):
            yield c_pattern, []
        for other_pattern, other_rest in fn1(pattern, counts[:-1]):
            yield other_pattern, other_rest + [counts[-1]]
        return
    if "#" not in pattern:
        # The block may stay entirely '.', consuming no run.
        yield "." * len(pattern), counts
    if len(pattern) < counts[0]:
        return
    if "?" not in pattern and counts[0] == len(pattern):
        yield pattern, counts[1:]
        return
    for i in range(len(pattern)):
        # Candidate prefix: i dots, the first run, then a separating dot.
        c_pattern = "." * i + "#" * counts[0]
        if len(c_pattern) > len(pattern):
            break
        if len(c_pattern) < len(pattern):
            c_pattern += "."
        if all(pattern[j] == "?" for j in range(len(c_pattern)) if c_pattern[j] == "."):
            for other_pattern, other_rest in fn1(pattern[len(c_pattern):], counts[1:]):
                yield c_pattern + other_pattern, other_rest


def fn2(
    patterns: list[str], counts: list[int], depth: int = 0
) -> Iterator[tuple[str, ...]]:
    """Enumerate fillings of all *patterns* blocks that use every run in *counts*.

    Yields one tuple of filled blocks per valid arrangement.  *depth* only
    tracks recursion depth and does not affect the result.
    """
    if not patterns:
        if not counts:
            yield ()
        return
    # Every later block containing '#' needs at least one run of its own, so
    # reserve that many runs from the end of *counts* for them.
    with_hash = sum(1 for p in patterns[1:] if "#" in p)
    to_fit = counts if with_hash == 0 else counts[:-with_hash]
    remaining = [] if with_hash == 0 else counts[-with_hash:]
    for fp, fc in fn1(patterns[0], to_fit):
        for fp2 in fn2(patterns[1:], fc + remaining, depth + 1):
            yield (fp,) + fp2


def v2(pattern: str, counts: list[int]) -> int:
    """Count arrangements by enumerating block fillings with :func:`fn2`."""
    blocks = [block for block in pattern.split(".") if block]
    return sum(1 for _ in fn2(blocks, counts))


@lru_cache(maxsize=None)
def fn3p(pattern: str, counts: tuple[int, ...]) -> int:
    """Count the ways to place *all* runs of *counts* inside one block.

    *pattern* is expected to contain only '#' and '?' (a block between dots).
    Runs are placed left to right; cells between runs become '.', so they must
    be '?', and every '#' must end up covered by a run.

    The cache is unbounded on purpose: the default ``lru_cache`` size (128)
    evicts entries in the middle of the recursion and defeats memoization.
    :func:`v3` clears it after each call so memory stays bounded.
    """
    if not pattern:
        return 1 if not counts else 0
    if not counts:
        return 1 if "#" not in pattern else 0
    if len(pattern) < sum(counts) + len(counts) - 1:
        return 0  # not even the tightest packing fits

    first, rest = counts[0], counts[1:]
    n = len(pattern)
    total = 0
    for start in range(n):
        end = start + first
        if end > n:
            break
        # Cells before the run become '.', so they must all be '?'.  Once a
        # non-'?' cell is in that gap, every later start fails as well.
        if start and pattern[start - 1] != "?":
            break
        if end == n:
            total += fn3p("", rest)
        elif pattern[end] == "?":
            # The cell right after the run is the separating '.'.
            total += fn3p(pattern[end + 1:], rest)
    return total


def fn3(pattern: str, counts: tuple[int, ...]) -> Iterator[tuple[int, tuple[int, ...]]]:
    """For one block, yield ``(ways, rest)`` per prefix of *counts* it may hold.

    ``ways`` is the number of ways that prefix fits in *pattern* (possibly 0)
    and ``rest`` is the suffix of *counts* left for the following blocks.
    """
    if not pattern:
        yield 1, counts
        return
    if not counts:
        if "#" not in pattern:
            yield 1, ()
        return
    if "#" in pattern and len(pattern) < counts[0]:
        # The block needs at least one run, but even the first one is too long.
        return
    if "?" not in pattern and counts[0] == len(pattern):
        # A solid '#' block that matches the first run exactly: the only option.
        yield 1, counts[1:]
        return
    for i in range(len(counts) + 1):
        yield fn3p(pattern, counts[:i]), counts[i:]


def fn4(patterns: tuple[str, ...], counts: tuple[int, ...], depth: int = 0) -> int:
    """Count the ways to distribute all of *counts* over the *patterns* blocks.

    *depth* is accepted for backward compatibility only (it was used for debug
    indentation) and is deliberately kept out of the memoization key, which
    previously fragmented the cache.
    """
    return _fn4(patterns, counts)


@lru_cache(maxsize=None)
def _fn4(patterns: tuple[str, ...], counts: tuple[int, ...]) -> int:
    """Memoized worker for :func:`fn4`, keyed only on ``(patterns, counts)``."""
    if not patterns:
        return 1 if not counts else 0
    # Every later block containing '#' needs at least one run of its own, so
    # the first block may use at most len(counts) - with_hash runs.
    with_hash = sum(1 for p in patterns[1:] if "#" in p)
    if with_hash > len(counts):
        return 0
    split = len(counts) - with_hash
    to_fit, remaining = counts[:split], counts[split:]
    count = 0
    for ways, rest in fn3(patterns[0], to_fit):
        if ways:
            count += ways * _fn4(patterns[1:], rest + remaining)
    return count


def v3(pattern: str, counts: list[int]) -> int:
    """Count arrangements with the memoized block-wise recursion (:func:`fn4`)."""
    blocks = tuple(block for block in pattern.split(".") if block)
    try:
        return fn4(blocks, tuple(counts))
    finally:
        # The memo caches are unbounded; clearing them here keeps memory
        # bounded by the work of a single call.
        fn3p.cache_clear()
        _fn4.cache_clear()


@lru_cache(maxsize=1)
def _read_stdin_lines() -> tuple[str, ...]:
    """Read standard input once and cache its lines for repeated calls."""
    return tuple(sys.stdin.read().splitlines())


def compute_possible_arrangements(
    repeat: int, lines: Optional[Sequence[str]] = None
) -> int:
    """Sum the arrangement counts of every input line, unfolded *repeat* times.

    Each line is ``<pattern> <c1,c2,...>``.  Unfolding joins *repeat* copies of
    the pattern with '?' and repeats the counts *repeat* times.  When *lines*
    is omitted, the lines are read (once) from standard input.  Blank lines are
    skipped.
    """
    if lines is None:
        lines = _read_stdin_lines()
    count = 0
    for line in tqdm(lines):
        parts = line.split()
        if not parts:
            continue
        pattern = "?".join([parts[0]] * repeat)
        counts = [int(c) for c in parts[1].split(",")] * repeat
        count += v3(pattern, counts)
    return count


def main() -> None:
    """Solve both parts for the input given on standard input."""
    lines = _read_stdin_lines()

    # part 1
    answer_1 = compute_possible_arrangements(1, lines)
    print(f"answer 1 is {answer_1}")

    # part 2
    answer_2 = compute_possible_arrangements(5, lines)
    print(f"answer 2 is {answer_2}")


if __name__ == "__main__":
    main()