This commit is contained in:
Mikael CAPELLE 2023-12-12 19:03:40 +01:00
parent c7d2eb2171
commit 13b15aba76

View File

@ -1,125 +1,11 @@
import itertools
import sys
from collections import defaultdict
from dataclasses import dataclass
from functools import lru_cache
from multiprocessing.pool import ThreadPool
from typing import Iterator
from tqdm import tqdm
lines = sys.stdin.read().splitlines()
def v1(pattern: str, counts: list[int]) -> int:
count = 0
missing = [i for i in range(len(pattern)) if pattern[i] == "?"]
for replacements in itertools.product(".#", repeat=len(missing)):
c_pattern = list(pattern)
for i_missing, replacement in zip(missing, replacements):
c_pattern[i_missing] = replacement
parts = [p for p in "".join(c_pattern).split(".") if p]
if len(parts) == len(counts) and all(
len(p) == c for p, c in zip(parts, counts)
):
count += 1
return count
def fn1(pattern: str, counts: list[int]) -> Iterator[tuple[str, list[int]]]:
if not pattern:
yield "", counts
return
if not counts:
if pattern.find("#") == -1:
yield "." * len(pattern), []
return
if sum(counts) + len(counts) - 1 == len(pattern):
c_pattern = ".".join("#" * c for c in counts)
if all(pattern[i] == "?" for i in range(len(pattern)) if c_pattern[i] == "."):
yield c_pattern, []
for other_pattern, other_rest in fn1(pattern, counts[:-1]):
yield other_pattern, other_rest + [counts[-1]]
return
if pattern.find("#") == -1:
yield "." * len(pattern), counts
if len(pattern) < counts[0]:
return
if pattern.find("?") == -1 and counts[0] == len(pattern):
yield pattern, counts[1:]
return
for i in range(len(pattern)):
c_pattern = "." * i + "#" * counts[0]
if len(c_pattern) > len(pattern):
break
if len(c_pattern) < len(pattern):
c_pattern += "."
# print(">", c_pattern, counts[0])
if all(pattern[i] == "?" for i in range(len(c_pattern)) if c_pattern[i] == "."):
for other_pattern, other_rest in fn1(pattern[len(c_pattern) :], counts[1:]):
# print(">>", other_pattern, other_rest)
yield c_pattern + other_pattern, other_rest
def fn2(
patterns: list[str], counts: list[int], depth: int = 0
) -> Iterator[tuple[str, ...]]:
if not patterns:
if not counts:
yield ()
return
with_hash = sum(1 for p in patterns[1:] if p.find("#") >= 0)
# print(patterns, counts, with_hash, counts[:-with_hash])
to_fit = counts if with_hash == 0 else counts[:-with_hash]
remaining = [] if with_hash == 0 else counts[-with_hash:]
for fp, fc in fn1(patterns[0], to_fit):
# print(f'{"|" * depth} {patterns[0]} ({to_fit}): {fp}, {fc}')
for fp2 in fn2(patterns[1:], fc + remaining, depth + 1):
# if depth == 0:
# print((fp,) + fp2)
yield (fp,) + fp2
def v2(pattern: str, counts: list[int]) -> int:
blocks = list(filter(len, pattern.split(".")))
# print("---")
# print(i_line, parts[0])
# print("--")
# print("---")
# print(parts[0], counts[: len(counts) // repeat])
fillers = list(fn2(blocks, counts))
# for filler in fillers:
# assert len(filler) == len(blocks)
# for f, b in zip(filler, blocks):
# assert all(fi == bi or bi == "?" for fi, bi in zip(f, b))
# # print(b, f)
# hashes = list(filter(len, ".".join(filler).split(".")))
# print(hashes, counts)
# print(len(fillers))
return len(fillers)
@lru_cache
def fn3p(pattern: str, counts: tuple[int, ...]) -> int:
"""
@ -140,52 +26,36 @@ def fn3p(pattern: str, counts: tuple[int, ...]) -> int:
else:
count = 0
# if pattern[0] == "?":
# count += fn3p(pattern[1:], counts)
# count += fn3p(pattern[counts[0] + 1], counts[1:])
if pattern[0] == "?":
count += fn3p(pattern[1:], counts)
for i in range(len(pattern)):
c_pattern = "." * i + "#" * counts[0]
if len(pattern) == counts[0]:
count += 1
if len(c_pattern) > len(pattern):
break
if len(c_pattern) < len(pattern):
c_pattern += "."
# print(">", c_pattern, counts[0])
if all(
pattern[i] == "?" for i in range(len(c_pattern)) if c_pattern[i] == "."
):
count += fn3p(pattern[len(c_pattern) :], counts[1:])
elif pattern[counts[0]] != "#":
count += fn3p(pattern[counts[0] + 1 :], counts[1:])
return count
def fn3(pattern: str, counts: tuple[int, ...]) -> Iterator[tuple[int, tuple[int, ...]]]:
@lru_cache
def fn3(pattern: str, counts: tuple[int, ...]) -> list[tuple[int, tuple[int, ...]]]:
# empty pattern
if not pattern:
yield 1, counts
return
return [(1, counts)]
if not counts:
if pattern.find("#") == -1:
yield 1, ()
return
elif not counts:
return [(1, ())] if pattern.find("#") == -1 else []
if pattern.find("#") != -1 and len(pattern) < counts[0]:
return
elif pattern.find("#") != -1 and len(pattern) < counts[0]:
return []
if pattern.find("?") == -1 and counts[0] == len(pattern):
yield 1, counts[1:]
return
elif pattern.find("?") == -1 and counts[0] == len(pattern):
return [(1, counts[1:])]
for i in range(len(counts) + 1):
to_fit = counts[:i]
yield fn3p(pattern, to_fit), counts[i:]
else:
return [(fn3p(pattern, counts[:i]), counts[i:]) for i in range(len(counts) + 1)]
@lru_cache
@ -218,58 +88,14 @@ def v3(pattern: str, counts: list[int]) -> int:
return fn4(tuple(blocks), tuple(counts))
# print(list(fn1("??#?", [2, 1])))
# print(list(fn2(["??????", "??"], [2, 1, 1])))
# exit()
# r = 1
# print(v2("?".join("??.???#.???" for _ in range(r)), [4, 1] * r))
# print(v3("?".join("??.???#.???" for _ in range(r)), [4, 1] * r))
# exit()
# r = 3
# print(v2("?".join(".??..??...?##." for _ in range(r)), [1, 1, 3] * r))
# print(v3("?".join(".??..??...?##." for _ in range(r)), [1, 1, 3] * r))
# exit()
all_blocks = []
set_blocks = set()
for line in lines:
blocks = list(filter(len, (line.split()[0] * 5).split(".")))
all_blocks.extend(blocks)
set_blocks.update(blocks)
print(len(set_blocks), len(all_blocks))
def compute_possible_arrangements(repeat: int) -> int:
# with tqdm(total=len(lines)) as pbar:
# def _fn(line: str) -> int:
# parts = line.split(" ")
# pattern = "?".join(parts[0] for _ in range(repeat))
# counts = [int(c) for c in parts[1].split(",")] * repeat
# count = v2(pattern, counts)
# pbar.update(1)
# return count
# with ThreadPool() as tp:
# return sum(tp.imap_unordered(_fn, lines))
count = 0
for i_line, line in enumerate(tqdm(lines)):
for line in lines:
parts = line.split(" ")
pattern = "?".join(parts[0] for _ in range(repeat))
counts = [int(c) for c in parts[1].split(",")] * repeat
# c1 = v1(pattern, counts)
# c2 = v2(pattern, counts)
c3 = v3(pattern, counts)
# if c2 != c3:
# print(i_line, line, counts, c2, c3)
count += c3
count += v3(pattern, counts)
return count