From edc50cb9c247ac3fd765bed8818df6d88833ffb5 Mon Sep 17 00:00:00 2001 From: Mikael CAPELLE Date: Tue, 12 Dec 2023 14:26:07 +0100 Subject: [PATCH] WIP. --- 2023/day12.py | 100 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 84 insertions(+), 16 deletions(-) diff --git a/2023/day12.py b/2023/day12.py index 73569c3..4fccea6 100644 --- a/2023/day12.py +++ b/2023/day12.py @@ -2,31 +2,100 @@ import itertools import sys from collections import defaultdict from dataclasses import dataclass +from typing import Iterator from tqdm import tqdm lines = sys.stdin.read().splitlines() + +def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]: + if not pattern: + yield "", counts + return + + if not counts: + if pattern.find("#") == -1: + yield "." * len(pattern), [] + return + + if sum(counts) + len(counts) - 1 == len(pattern): + c_pattern = ".".join("#" * c for c in counts) + if all(pattern[i] == "?" for i in range(len(pattern)) if c_pattern[i] == "."): + yield c_pattern, [] + return + + if pattern.find("#") == -1: + yield "." * len(pattern), counts + + if len(pattern) < counts[0]: + return + + if pattern.find("?") == -1 and counts[0] == len(pattern): + yield pattern, counts[1:] + + for i in range(len(pattern)): + c_pattern = "." * i + "#" * counts[0] + + if len(c_pattern) > len(pattern): + break + + if len(c_pattern) < len(pattern): + c_pattern += "." + + # print(">", c_pattern, counts[0]) + + if all(pattern[i] == "?" for i in range(len(c_pattern)) if c_pattern[i] == "."): + for other_pattern, other_rest in fn1(pattern[len(c_pattern) :], counts[1:]): + # print(">>", other_pattern, other_rest) + yield c_pattern + other_pattern, other_rest + + +def fn2( + patterns: list[str], counts: list[int], depth: int = 0 +) -> Iterator[tuple[str, ...]]: + if not patterns: + yield () + return + + with_hash = sum(1 for p in patterns[1:] if p.find("#") >= 0) + # print(patterns, counts, with_hash, counts[:-with_hash]) + + to_fit = counts if with_hash == 0 else counts[:-with_hash] + remaining = [] if with_hash == 0 else counts[-with_hash:] + + for fp, fc in fn1(patterns[0], to_fit): + if depth == 0: + ... # print(fp, fc) + for fp2 in fn2(patterns[1:], fc + remaining, depth + 1): + # print(fp, fc) + yield (fp,) + fp2 + + +# print(list(fn1("??", [1, 1, 3, 1]))) +# exit() + + +repeat = 5 count = 0 -for line in lines: +for i_line, line in enumerate(lines): parts = line.split(" ") - pattern = parts[0] - counts = [int(c) for c in parts[1].split(",")] + pattern = "?".join(parts[0] for _ in range(repeat)) + counts = [int(c) for c in parts[1].split(",")] * repeat - missing = [i for i in range(len(pattern)) if pattern[i] == "?"] + # print("---") + # print(i_line, parts[0]) + # print("--") + print("---") + print(parts[0], counts[: len(counts) // 5]) + r1 = list(fn2(list(filter(len, pattern.split("."))), counts)) + r2 = set(r1) + print(len(r1), len(r2)) # , r2) + count += len(r1) - for replacements in itertools.product(".#", repeat=len(missing)): - c_pattern = list(pattern) - for i_missing, replacement in zip(missing, replacements): - c_pattern[i_missing] = replacement + # if i_line == 1: + # break - parts = [p for p in "".join(c_pattern).split(".") if p] - - if len(parts) == len(counts) and all( - len(p) == c for p, c in zip(parts, counts) - ): - # print("".join(c_pattern), counts) - count += 1 # part 1 answer_1 = count @@ -35,4 +104,3 @@ print(f"answer 1 is {answer_1}") # part 2 answer_2 = ... print(f"answer 2 is {answer_2}") -print(f"answer 2 is {answer_2}")