# advent-of-code/src/holt59/aoc/2023/day20.py
# 2024-12-10 16:54:18 +01:00 (173 lines, 5.7 KiB, Python)

from collections import defaultdict, deque
from math import lcm
from typing import Any, Iterator, Literal, TypeAlias

from ..base import BaseSolver
# Kind of a module; in the puzzle input it is given by the name
# ("broadcaster") or its prefix ("&" conjunction, otherwise flip-flop).
ModuleType: TypeAlias = Literal["broadcaster", "conjunction", "flip-flop"]
# Value carried by a single pulse.
PulseType: TypeAlias = Literal["high", "low"]
class Solver(BaseSolver):
    """Solver for Advent of Code 2023, day 20 (pulse propagation).

    Each input line has the form ``<name> -> <out1>, <out2>, ...`` where
    ``<name>`` is ``broadcaster``, ``&<id>`` (conjunction) or any other
    one-character prefix followed by the id (treated as a flip-flop).

    ``self.logger`` and ``self.files`` are not defined here; presumably they
    are provided by ``BaseSolver`` -- confirm against the base class.
    """

    # Module name -> (module type, names of the destinations it sends to).
    _modules: dict[str, tuple[ModuleType, list[str]]]

    def _process(
        self,
        start: tuple[str, str, PulseType],
        flip_flop_states: dict[str, Literal["on", "off"]],
        conjunction_states: dict[str, dict[str, PulseType]],
    ) -> tuple[dict[PulseType, int], dict[str, dict[PulseType, int]]]:
        """Propagate pulses from ``start`` until no pulse remains in flight.

        Pulses are processed in FIFO order. ``flip_flop_states`` and
        ``conjunction_states`` are mutated in place, so successive calls
        continue from the state left by the previous one.

        Args:
            start: Initial pulse as ``(sender, receiver, pulse)``.
            flip_flop_states: Current ``"on"``/``"off"`` state of each flip-flop.
            conjunction_states: For each conjunction, the last pulse remembered
                from each of its inputs.

        Returns:
            ``(counts, received)``: ``counts`` is the number of low/high pulses
            processed (including ``start``); ``received`` maps each destination
            name to the number of low/high pulses it received (a defaultdict,
            so unseen names read as zero counts).
        """
        # deque gives O(1) pops from the front, unlike list.pop(0)
        pulses: deque[tuple[str, str, PulseType]] = deque([start])
        counts: dict[PulseType, int] = {"low": 0, "high": 0}
        received: dict[str, dict[PulseType, int]] = defaultdict(
            lambda: {"low": 0, "high": 0}
        )

        self.logger.info("starting process... ")

        while pulses:
            sender, name, pulse = pulses.popleft()
            self.logger.info(f"{sender} -{pulse}-> {name}")

            counts[pulse] += 1
            received[name][pulse] += 1

            # destinations that are not modules (e.g. "rx") only count pulses
            if name not in self._modules:
                continue

            module_type, outputs = self._modules[name]

            if module_type == "flip-flop":
                # flip-flops ignore high pulses and toggle on low ones
                if pulse == "high":
                    continue
                if flip_flop_states[name] == "off":
                    flip_flop_states[name] = "on"
                    pulse = "high"
                else:
                    flip_flop_states[name] = "off"
                    pulse = "low"
            elif module_type == "conjunction":
                # remember the latest pulse from this input; emit low only when
                # every remembered input pulse is high
                conjunction_states[name][sender] = pulse
                if all(state == "high" for state in conjunction_states[name].values()):
                    pulse = "low"
                else:
                    pulse = "high"
            # the broadcaster forwards the received pulse unchanged

            pulses.extend((name, output, pulse) for output in outputs)

        return counts, received

    def _parse_modules(self, input: str) -> None:
        """Populate ``self._modules`` from the puzzle input (blank lines skipped)."""
        self._modules = {}
        for line in input.splitlines():
            if not line.strip():
                continue
            name, outputs_s = line.split(" -> ")
            outputs = outputs_s.split(", ")
            if name == "broadcaster":
                self._modules["broadcaster"] = ("broadcaster", outputs)
            else:
                # the first character is the type marker, the rest is the id
                self._modules[name[1:]] = (
                    "conjunction" if name.startswith("&") else "flip-flop",
                    outputs,
                )

    def _module_inputs(self) -> dict[str, list[str]]:
        """Map each destination name to the modules that send to it, in module order."""
        module_inputs: dict[str, list[str]] = {}
        for name, (_, outputs) in self._modules.items():
            # dict.fromkeys drops repeated outputs while keeping their order
            for output in dict.fromkeys(outputs):
                module_inputs.setdefault(output, []).append(name)
        return module_inputs

    def _initial_states(
        self, module_inputs: dict[str, list[str]]
    ) -> tuple[dict[str, Literal["on", "off"]], dict[str, dict[str, PulseType]]]:
        """Return fresh states: every flip-flop off, every conjunction input low."""
        flip_flop_states: dict[str, Literal["on", "off"]] = {
            name: "off"
            for name, (module_type, _) in self._modules.items()
            if module_type == "flip-flop"
        }
        conjunction_states: dict[str, dict[str, PulseType]] = {
            name: {sender: "low" for sender in module_inputs.get(name, [])}
            for name, (module_type, _) in self._modules.items()
            if module_type == "conjunction"
        }
        return flip_flop_states, conjunction_states

    def _to_dot(self) -> str:
        """Render the module graph as Graphviz DOT text, with ``rx`` highlighted."""
        shapes = {"conjunction": "diamond", "flip-flop": "box"}
        lines = ["digraph G {", "rx [shape=circle, color=red, style=filled];"]
        lines.extend(
            f"{name} [shape={shapes.get(module_type, 'circle')}];"
            for name, (module_type, _) in self._modules.items()
        )
        lines.extend(
            f"{name} -> {output};"
            for name, (_, outputs) in self._modules.items()
            for output in outputs
        )
        lines.append("}")
        return "".join(line + "\n" for line in lines)

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the part 1 answer, then the part 2 answer, for ``input``."""
        self._parse_modules(input)
        module_inputs = self._module_inputs()

        if self.files:
            self.files.create("day20.dot", self._to_dot().encode(), False)

        # part 1
        flip_flop_states, conjunction_states = self._initial_states(module_inputs)
        counts: dict[PulseType, int] = {"low": 0, "high": 0}
        for _ in range(1000):
            result, _ = self._process(
                ("button", "broadcaster", "low"), flip_flop_states, conjunction_states
            )
            for pulse in ("low", "high"):
                counts[pulse] += result[pulse]
        yield counts["low"] * counts["high"]

        # part 2
        # start again from the initial state (all off / all low)
        flip_flop_states, conjunction_states = self._initial_states(module_inputs)

        # find the conjunction connected to rx
        to_rx = module_inputs.get("rx", [])
        assert len(to_rx) == 1, "cannot handle multiple module inputs for rx"
        assert (
            self._modules[to_rx[0]][0] == "conjunction"
        ), "can only handle conjunction as input to rx"

        to_rx_inputs = module_inputs.get(to_rx[0], [])
        assert all(
            self._modules[i][0] == "conjunction" and len(self._modules[i][1]) == 1
            for i in to_rx_inputs
        ), "can only handle inversion as second-order inputs to rx"

        # record the first two presses at which each second-order input
        # receives exactly one low pulse during the press
        first: dict[str, int] = {}
        second: dict[str, int] = {}
        presses = 0
        while len(second) != len(to_rx_inputs):
            presses += 1
            _, received = self._process(
                ("button", "broadcaster", "low"), flip_flop_states, conjunction_states
            )
            for node in to_rx_inputs:
                if received[node]["low"] == 1:
                    if node not in first:
                        first[node] = presses
                    elif node not in second:
                        second[node] = presses

        # the LCM is only valid if each period starts from press 0
        assert all(
            second[k] == first[k] * 2 for k in to_rx_inputs
        ), "can only handle cycles starting at the beginning"
        yield lcm(*first.values())