From c7d2eb21718ee0032f566349e23840309783a4b8 Mon Sep 17 00:00:00 2001 From: Mikael CAPELLE Date: Tue, 12 Dec 2023 18:54:08 +0100 Subject: [PATCH] WIP 2. --- 2023/day12.py | 227 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 202 insertions(+), 25 deletions(-) diff --git a/2023/day12.py b/2023/day12.py index 4fccea6..402d462 100644 --- a/2023/day12.py +++ b/2023/day12.py @@ -2,6 +2,8 @@ import itertools import sys from collections import defaultdict from dataclasses import dataclass +from functools import lru_cache +from multiprocessing.pool import ThreadPool from typing import Iterator from tqdm import tqdm @@ -9,6 +11,25 @@ from tqdm import tqdm lines = sys.stdin.read().splitlines() +def v1(pattern: str, counts: list[int]) -> int: + count = 0 + missing = [i for i in range(len(pattern)) if pattern[i] == "?"] + + for replacements in itertools.product(".#", repeat=len(missing)): + c_pattern = list(pattern) + for i_missing, replacement in zip(missing, replacements): + c_pattern[i_missing] = replacement + + parts = [p for p in "".join(c_pattern).split(".") if p] + + if len(parts) == len(counts) and all( + len(p) == c for p, c in zip(parts, counts) + ): + count += 1 + + return count + + def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]: if not pattern: yield "", counts @@ -23,7 +44,9 @@ def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]: c_pattern = ".".join("#" * c for c in counts) if all(pattern[i] == "?" for i in range(len(pattern)) if c_pattern[i] == "."): yield c_pattern, [] - return + for other_pattern, other_rest in fn1(pattern, counts[:-1]): + yield other_pattern, other_rest + [counts[-1]] + return if pattern.find("#") == -1: yield "." * len(pattern), counts @@ -33,6 +56,7 @@ def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]: if pattern.find("?") == -1 and counts[0] == len(pattern): yield pattern, counts[1:] + return for i in range(len(pattern)): c_pattern = "." * i + "#" * counts[0] @@ -55,7 +79,8 @@ def fn2( patterns: list[str], counts: list[int], depth: int = 0 ) -> Iterator[tuple[str, ...]]: if not patterns: - yield () + if not counts: + yield () return with_hash = sum(1 for p in patterns[1:] if p.find("#") >= 0) @@ -65,42 +90,194 @@ def fn2( remaining = [] if with_hash == 0 else counts[-with_hash:] for fp, fc in fn1(patterns[0], to_fit): - if depth == 0: - ... # print(fp, fc) + # print(f'{"|" * depth} {patterns[0]} ({to_fit}): {fp}, {fc}') for fp2 in fn2(patterns[1:], fc + remaining, depth + 1): - # print(fp, fc) + # if depth == 0: + # print((fp,) + fp2) yield (fp,) + fp2 -# print(list(fn1("??", [1, 1, 3, 1]))) -# exit() - - -repeat = 5 -count = 0 -for i_line, line in enumerate(lines): - parts = line.split(" ") - pattern = "?".join(parts[0] for _ in range(repeat)) - counts = [int(c) for c in parts[1].split(",")] * repeat +def v2(pattern: str, counts: list[int]) -> int: + blocks = list(filter(len, pattern.split("."))) # print("---") # print(i_line, parts[0]) # print("--") - print("---") - print(parts[0], counts[: len(counts) // 5]) - r1 = list(fn2(list(filter(len, pattern.split("."))), counts)) - r2 = set(r1) - print(len(r1), len(r2)) # , r2) - count += len(r1) + # print("---") + # print(parts[0], counts[: len(counts) // repeat]) + fillers = list(fn2(blocks, counts)) - # if i_line == 1: - # break + # for filler in fillers: + # assert len(filler) == len(blocks) + # for f, b in zip(filler, blocks): + # assert all(fi == bi or bi == "?" for fi, bi in zip(f, b)) + # # print(b, f) + + # hashes = list(filter(len, ".".join(filler).split("."))) + # print(hashes, counts) + + # print(len(fillers)) + return len(fillers) + + +@lru_cache +def fn3p(pattern: str, counts: tuple[int, ...]) -> int: + """ + fn3p tries to fit ALL values in counts() inside the pattern. + """ + + # no pattern -> ok if nothing to fit, otherwise ko + if not pattern: + count = 1 if not counts else 0 + + # no count -> ok if pattern has no mandatory entry, else ko + elif not counts: + count = 1 if pattern.find("#") == -1 else 0 + + # cannot fit all values -> ko + elif len(pattern) < sum(counts) + len(counts) - 1: + count = 0 + + else: + count = 0 + # if pattern[0] == "?": + # count += fn3p(pattern[1:], counts) + + # count += fn3p(pattern[counts[0] + 1], counts[1:]) + + for i in range(len(pattern)): + c_pattern = "." * i + "#" * counts[0] + + if len(c_pattern) > len(pattern): + break + + if len(c_pattern) < len(pattern): + c_pattern += "." + + # print(">", c_pattern, counts[0]) + + if all( + pattern[i] == "?" for i in range(len(c_pattern)) if c_pattern[i] == "." + ): + count += fn3p(pattern[len(c_pattern) :], counts[1:]) + + return count + + +def fn3(pattern: str, counts: tuple[int, ...]) -> Iterator[tuple[int, tuple[int, ...]]]: + # empty pattern + if not pattern: + yield 1, counts + return + + if not counts: + if pattern.find("#") == -1: + yield 1, () + return + + if pattern.find("#") != -1 and len(pattern) < counts[0]: + return + + if pattern.find("?") == -1 and counts[0] == len(pattern): + yield 1, counts[1:] + return + + for i in range(len(counts) + 1): + to_fit = counts[:i] + + yield fn3p(pattern, to_fit), counts[i:] + + +@lru_cache +def fn4(patterns: tuple[str], counts: tuple[int, ...], depth: int = 0) -> int: + if not patterns: + if not counts: + return 1 + return 0 + + with_hash = sum(1 for p in patterns[1:] if p.find("#") >= 0) + + if with_hash > len(counts): + return 0 + + to_fit = counts if with_hash == 0 else counts[:-with_hash] + remaining = () if with_hash == 0 else counts[-with_hash:] + + count = 0 + for fp, fc in fn3(patterns[0], to_fit): + if fp == 0: + continue + # print("|" * depth, patterns[0], to_fit, remaining, fp, fc) + count += fp * fn4(patterns[1:], fc + remaining, depth + 1) + + return count + + +def v3(pattern: str, counts: list[int]) -> int: + blocks = list(filter(len, pattern.split("."))) + return fn4(tuple(blocks), tuple(counts)) + + +# print(list(fn1("??#?", [2, 1]))) +# print(list(fn2(["??????", "??"], [2, 1, 1]))) +# exit() + +# r = 1 +# print(v2("?".join("??.???#.???" for _ in range(r)), [4, 1] * r)) +# print(v3("?".join("??.???#.???" for _ in range(r)), [4, 1] * r)) +# exit() + +# r = 3 +# print(v2("?".join(".??..??...?##." for _ in range(r)), [1, 1, 3] * r)) +# print(v3("?".join(".??..??...?##." for _ in range(r)), [1, 1, 3] * r)) +# exit() + +all_blocks = [] +set_blocks = set() +for line in lines: + blocks = list(filter(len, (line.split()[0] * 5).split("."))) + all_blocks.extend(blocks) + set_blocks.update(blocks) + +print(len(set_blocks), len(all_blocks)) + + +def compute_possible_arrangements(repeat: int) -> int: + # with tqdm(total=len(lines)) as pbar: + + # def _fn(line: str) -> int: + # parts = line.split(" ") + # pattern = "?".join(parts[0] for _ in range(repeat)) + # counts = [int(c) for c in parts[1].split(",")] * repeat + # count = v2(pattern, counts) + # pbar.update(1) + # return count + + # with ThreadPool() as tp: + # return sum(tp.imap_unordered(_fn, lines)) + + count = 0 + + for i_line, line in enumerate(tqdm(lines)): + parts = line.split(" ") + pattern = "?".join(parts[0] for _ in range(repeat)) + counts = [int(c) for c in parts[1].split(",")] * repeat + # c1 = v1(pattern, counts) + # c2 = v2(pattern, counts) + c3 = v3(pattern, counts) + + # if c2 != c3: + # print(i_line, line, counts, c2, c3) + + count += c3 + + return count # part 1 -answer_1 = count +answer_1 = compute_possible_arrangements(1) print(f"answer 1 is {answer_1}") # part 2 -answer_2 = ... +answer_2 = compute_possible_arrangements(5) print(f"answer 2 is {answer_2}")