"""Count arrangements of damaged-spring groups in rows read from stdin.

Each input line looks like ``"<springs> <g1>,<g2>,..."`` where ``springs`` is
made of ``#`` (damaged), ``.`` (operational) and ``?`` (unknown) characters and
the ``g`` values are the lengths of the contiguous damaged groups, in order.
Prints the total for the rows as given (part 1) and unfolded five times
(part 2).
"""

import os
import sys
from functools import lru_cache
from typing import Iterable

# Set AOC_VERBOSE=True in the environment to show a tqdm progress bar.
VERBOSE = os.getenv("AOC_VERBOSE") == "True"


# maxsize=None: the default lru_cache maxsize (128) is much smaller than the
# number of (pattern, counts) states of a single unfolded row, so memoized
# results were evicted before they could be reused. Memory is kept bounded by
# clearing the cache after each row in compute_all_possible_arrangements.
@lru_cache(maxsize=None)
def compute_fitting_arrangements(pattern: str, counts: tuple[int, ...]) -> int:
    """Count the ways to place ALL groups in ``counts`` inside ``pattern``.

    ``pattern`` is one contiguous segment of ``#`` (must be covered by a group)
    and ``?`` (may or may not be covered) characters; it must not contain
    ``.`` -- callers split rows on ``.`` first. Groups are placed in order and
    consecutive groups are separated by at least one uncovered cell.

    Args:
        pattern: Segment made only of ``#`` and ``?``.
        counts: Lengths of the groups to place, in order (assumed positive).

    Returns:
        The number of distinct placements that cover every ``#``.
    """
    # Nothing left to match against: valid only if nothing is left to place.
    if not pattern:
        return 1 if not counts else 0
    # Nothing left to place: valid only if no '#' remains uncovered.
    if not counts:
        return 1 if "#" not in pattern else 0
    # The groups plus their mandatory separators cannot fit. (This also
    # covers len(pattern) < counts[0], so no separate check is needed.)
    if len(pattern) < sum(counts) + len(counts) - 1:
        return 0

    first = counts[0]
    total = 0
    # Option 1: leave the first cell uncovered (only possible for '?').
    if pattern[0] == "?":
        total += compute_fitting_arrangements(pattern[1:], counts)
    # Option 2: start the first group at cell 0. The length check above
    # guarantees it fits, and since the segment has no '.', every cell it
    # covers may be damaged.
    if len(pattern) == first:
        # The group fills the segment exactly; with positive counts the
        # length check above ensures it was the last group.
        total += 1
    elif pattern[first] != "#":
        # The cell right after the group is its separator and must not be
        # '#'; skip it and place the remaining groups after it.
        total += compute_fitting_arrangements(pattern[first + 1 :], counts[1:])
    return total


@lru_cache(maxsize=None)
def compute_possible_arrangements(
    patterns: tuple[str, ...], counts: tuple[int, ...]
) -> int:
    """Count the ways to distribute the groups in ``counts`` over ``patterns``.

    ``patterns`` are the non-empty ``#``/``?`` segments of a row (the row split
    on ``.``), in order. Each group lies entirely inside one segment, groups
    keep their order, and every ``#`` must be covered by some group.

    Args:
        patterns: Row segments, none containing ``.``.
        counts: Group lengths, in order.

    Returns:
        The number of valid arrangements for the whole row.
    """
    if not patterns:
        return 1 if not counts else 0

    head, tail = patterns[0], patterns[1:]

    # Every later segment containing '#' needs at least one group, so the last
    # ``reserved`` groups are held back for them; only the rest may go in head.
    reserved = sum(1 for p in tail if "#" in p)
    if reserved > len(counts):
        return 0
    split = len(counts) - reserved
    to_fit, remaining = counts[:split], counts[split:]

    if not to_fit:
        # No group can go in head, so head must not contain '#'.
        if "#" in head:
            return 0
        return compute_possible_arrangements(tail, remaining)
    if "#" in head and len(head) < to_fit[0]:
        # head needs at least one group, and the first candidate is too long.
        return 0
    if "?" not in head:
        # head is all '#': exactly the next group must match its length.
        if len(head) != to_fit[0]:
            return 0
        return compute_possible_arrangements(tail, counts[1:])
    # Try putting the first i candidate groups into head (i may be 0), and
    # the rest into the following segments.
    return sum(
        fits * compute_possible_arrangements(tail, to_fit[i:] + remaining)
        for i in range(len(to_fit) + 1)
        if (fits := compute_fitting_arrangements(head, to_fit[:i])) > 0
    )


def compute_all_possible_arrangements(lines: Iterable[str], repeat: int) -> int:
    """Sum the arrangement counts of every row in ``lines``.

    Each row is unfolded by joining ``repeat`` copies of its springs with
    ``?`` and repeating its group list ``repeat`` times. Blank lines are
    skipped.

    Args:
        lines: Input lines of the form ``"<springs> <g1>,<g2>,..."``.
        repeat: Unfolding factor (1 for part 1, 5 for part 2).

    Returns:
        The total number of arrangements over all rows.
    """
    if VERBOSE:
        from tqdm import tqdm  # optional; only needed for the progress bar

        lines = tqdm(lines)

    total = 0
    for line in lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines (e.g. trailing newlines)
        springs, groups = fields[0], fields[1]
        patterns = tuple(
            segment
            for segment in "?".join([springs] * repeat).split(".")
            if segment
        )
        counts = tuple(int(c) for c in groups.split(",")) * repeat
        total += compute_possible_arrangements(patterns, counts)
        # Rows are independent: drop their memoized states so the unbounded
        # caches do not grow with the size of the whole input.
        compute_possible_arrangements.cache_clear()
        compute_fitting_arrangements.cache_clear()
    return total


if __name__ == "__main__":
    lines = sys.stdin.read().splitlines()

    # part 1
    answer_1 = compute_all_possible_arrangements(lines, 1)
    print(f"answer 1 is {answer_1}")

    # part 2
    answer_2 = compute_all_possible_arrangements(lines, 5)
    print(f"answer 2 is {answer_2}")