diff --git a/2023/day12.py b/2023/day12.py index 402d462..331200f 100644 --- a/2023/day12.py +++ b/2023/day12.py @@ -1,125 +1,11 @@ import itertools import sys -from collections import defaultdict -from dataclasses import dataclass from functools import lru_cache -from multiprocessing.pool import ThreadPool from typing import Iterator -from tqdm import tqdm - lines = sys.stdin.read().splitlines() -def v1(pattern: str, counts: list[int]) -> int: - count = 0 - missing = [i for i in range(len(pattern)) if pattern[i] == "?"] - - for replacements in itertools.product(".#", repeat=len(missing)): - c_pattern = list(pattern) - for i_missing, replacement in zip(missing, replacements): - c_pattern[i_missing] = replacement - - parts = [p for p in "".join(c_pattern).split(".") if p] - - if len(parts) == len(counts) and all( - len(p) == c for p, c in zip(parts, counts) - ): - count += 1 - - return count - - -def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]: - if not pattern: - yield "", counts - return - - if not counts: - if pattern.find("#") == -1: - yield "." * len(pattern), [] - return - - if sum(counts) + len(counts) - 1 == len(pattern): - c_pattern = ".".join("#" * c for c in counts) - if all(pattern[i] == "?" for i in range(len(pattern)) if c_pattern[i] == "."): - yield c_pattern, [] - for other_pattern, other_rest in fn1(pattern, counts[:-1]): - yield other_pattern, other_rest + [counts[-1]] - return - - if pattern.find("#") == -1: - yield "." * len(pattern), counts - - if len(pattern) < counts[0]: - return - - if pattern.find("?") == -1 and counts[0] == len(pattern): - yield pattern, counts[1:] - return - - for i in range(len(pattern)): - c_pattern = "." * i + "#" * counts[0] - - if len(c_pattern) > len(pattern): - break - - if len(c_pattern) < len(pattern): - c_pattern += "." - - # print(">", c_pattern, counts[0]) - - if all(pattern[i] == "?" for i in range(len(c_pattern)) if c_pattern[i] == "."): - for other_pattern, other_rest in fn1(pattern[len(c_pattern) :], counts[1:]): - # print(">>", other_pattern, other_rest) - yield c_pattern + other_pattern, other_rest - - -def fn2( - patterns: list[str], counts: list[int], depth: int = 0 -) -> Iterator[tuple[str, ...]]: - if not patterns: - if not counts: - yield () - return - - with_hash = sum(1 for p in patterns[1:] if p.find("#") >= 0) - # print(patterns, counts, with_hash, counts[:-with_hash]) - - to_fit = counts if with_hash == 0 else counts[:-with_hash] - remaining = [] if with_hash == 0 else counts[-with_hash:] - - for fp, fc in fn1(patterns[0], to_fit): - # print(f'{"|" * depth} {patterns[0]} ({to_fit}): {fp}, {fc}') - for fp2 in fn2(patterns[1:], fc + remaining, depth + 1): - # if depth == 0: - # print((fp,) + fp2) - yield (fp,) + fp2 - - -def v2(pattern: str, counts: list[int]) -> int: - blocks = list(filter(len, pattern.split("."))) - - # print("---") - # print(i_line, parts[0]) - # print("--") - # print("---") - # print(parts[0], counts[: len(counts) // repeat]) - fillers = list(fn2(blocks, counts)) - - # for filler in fillers: - # assert len(filler) == len(blocks) - # for f, b in zip(filler, blocks): - # assert all(fi == bi or bi == "?" for fi, bi in zip(f, b)) - # # print(b, f) - - # hashes = list(filter(len, ".".join(filler).split("."))) - # print(hashes, counts) - - # print(len(fillers)) - return len(fillers) - - @lru_cache def fn3p(pattern: str, counts: tuple[int, ...]) -> int: """ @@ -140,52 +26,36 @@ def fn3p(pattern: str, counts: tuple[int, ...]) -> int: else: count = 0 - # if pattern[0] == "?": - # count += fn3p(pattern[1:], counts) - # count += fn3p(pattern[counts[0] + 1], counts[1:]) + if pattern[0] == "?": + count += fn3p(pattern[1:], counts) - for i in range(len(pattern)): - c_pattern = "." * i + "#" * counts[0] + if len(pattern) == counts[0]: + count += 1 - if len(c_pattern) > len(pattern): - break - - if len(c_pattern) < len(pattern): - c_pattern += "." - - # print(">", c_pattern, counts[0]) - - if all( - pattern[i] == "?" for i in range(len(c_pattern)) if c_pattern[i] == "." - ): - count += fn3p(pattern[len(c_pattern) :], counts[1:]) + elif pattern[counts[0]] != "#": + count += fn3p(pattern[counts[0] + 1 :], counts[1:]) return count -def fn3(pattern: str, counts: tuple[int, ...]) -> Iterator[tuple[int, tuple[int, ...]]]: +@lru_cache +def fn3(pattern: str, counts: tuple[int, ...]) -> list[tuple[int, tuple[int, ...]]]: # empty pattern if not pattern: - yield 1, counts - return + return [(1, counts)] - if not counts: - if pattern.find("#") == -1: - yield 1, () - return + elif not counts: + return [(1, ())] if pattern.find("#") == -1 else [] - if pattern.find("#") != -1 and len(pattern) < counts[0]: - return + elif pattern.find("#") != -1 and len(pattern) < counts[0]: + return [] - if pattern.find("?") == -1 and counts[0] == len(pattern): - yield 1, counts[1:] - return + elif pattern.find("?") == -1 and counts[0] == len(pattern): + return [(1, counts[1:])] - for i in range(len(counts) + 1): - to_fit = counts[:i] - - yield fn3p(pattern, to_fit), counts[i:] + else: + return [(fn3p(pattern, counts[:i]), counts[i:]) for i in range(len(counts) + 1)] @lru_cache @@ -218,58 +88,14 @@ def v3(pattern: str, counts: list[int]) -> int: return fn4(tuple(blocks), tuple(counts)) -# print(list(fn1("??#?", [2, 1]))) -# print(list(fn2(["??????", "??"], [2, 1, 1]))) -# exit() - -# r = 1 -# print(v2("?".join("??.???#.???" for _ in range(r)), [4, 1] * r)) -# print(v3("?".join("??.???#.???" for _ in range(r)), [4, 1] * r)) -# exit() - -# r = 3 -# print(v2("?".join(".??..??...?##." for _ in range(r)), [1, 1, 3] * r)) -# print(v3("?".join(".??..??...?##." for _ in range(r)), [1, 1, 3] * r)) -# exit() - -all_blocks = [] -set_blocks = set() -for line in lines: - blocks = list(filter(len, (line.split()[0] * 5).split("."))) - all_blocks.extend(blocks) - set_blocks.update(blocks) - -print(len(set_blocks), len(all_blocks)) - - def compute_possible_arrangements(repeat: int) -> int: - # with tqdm(total=len(lines)) as pbar: - - # def _fn(line: str) -> int: - # parts = line.split(" ") - # pattern = "?".join(parts[0] for _ in range(repeat)) - # counts = [int(c) for c in parts[1].split(",")] * repeat - # count = v2(pattern, counts) - # pbar.update(1) - # return count - - # with ThreadPool() as tp: - # return sum(tp.imap_unordered(_fn, lines)) - count = 0 - for i_line, line in enumerate(tqdm(lines)): + for line in lines: parts = line.split(" ") pattern = "?".join(parts[0] for _ in range(repeat)) counts = [int(c) for c in parts[1].split(",")] * repeat - # c1 = v1(pattern, counts) - # c2 = v2(pattern, counts) - c3 = v3(pattern, counts) - - # if c2 != c3: - # print(i_line, line, counts, c2, c3) - - count += c3 + count += v3(pattern, counts) return count