advent-of-code/src/holt59/aoc/2024/day17.py
2024-12-17 15:20:30 +01:00

131 lines
3.9 KiB
Python

from typing import Any, Iterator
from ..base import BaseSolver
def combo(registers: dict[str, int], operand: int) -> int:
    """Resolve a combo operand of the 3-bit computer.

    Operands 0-3 are literal values; operands 4, 5 and 6 read registers
    A, B and C respectively. The registers are only read, never modified.

    Raises:
        ValueError: for the reserved operand 7 or any value outside 0-6.
            (An explicit raise instead of ``assert`` so the check survives
            ``python -O``.)
    """
    if 0 <= operand < 4:
        return operand
    if 4 <= operand < 7:
        return registers["ABC"[operand - 4]]
    raise ValueError(f"invalid combo operand: {operand}")
def adv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 0: shift register A right by ``combo(operand)`` bits.

    The result (A floor-divided by 2**combo) is stored back in A. This
    instruction never jumps, so it always returns None.
    """
    shift = combo(registers, operand)
    registers["A"] >>= shift
    return None
def bxl(registers: dict[str, int], operand: int) -> int | None:
registers["B"] ^= operand
def bst(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 2: store the low three bits of ``combo(operand)`` in B (never jumps)."""
    value = combo(registers, operand)
    # `& 0b111` matches `% 8` for every Python int, negatives included
    registers["B"] = value & 0b111
    return None
def jnz(registers: dict[str, int], operand: int) -> int | None:
if registers["A"] != 0:
return operand
def bxc(registers: dict[str, int], operand: int) -> int | None:
registers["B"] = registers["B"] ^ registers["C"]
def bdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 6: B <- A >> combo(operand). A itself is left unchanged; never jumps."""
    shift = combo(registers, operand)
    registers["B"] = registers["A"] >> shift
    return None
def cdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 7: C <- A >> combo(operand). A itself is left unchanged; never jumps."""
    shift = combo(registers, operand)
    registers["C"] = registers["A"] >> shift
    return None
def run(registers: dict[str, int], program: list[int]) -> list[int]:
    """Execute *program* on the 3-bit computer and return the emitted values.

    *program* is a flat list of opcode/operand pairs; opcode ``i`` dispatches
    to ``instructions[i]`` below. *registers* is mutated in place, so pass a
    copy if the original values must be preserved.

    Execution halts when the instruction pointer leaves the program, including
    when an opcode has no operand after it (a trailing opcode, or a jump to
    the last index). Previously such a case raised IndexError.
    """
    outputs: list[int] = []

    def out(registers: dict[str, int], operand: int) -> int | None:
        # opcode 5: emit the low three bits of the combo operand
        outputs.append(combo(registers, operand) % 8)
        return None

    instructions = [adv, bxl, bst, jnz, bxc, out, bdv, cdv]
    index = 0
    # both the opcode (index) and its operand (index + 1) must be in range;
    # the lower bound stops a negative jump target from indexing from the end
    while 0 <= index < len(program) - 1:
        instruction, operand = instructions[program[index]], program[index + 1]
        ret = instruction(registers, operand)
        # handlers return a jump target, or None to advance to the next pair
        index = index + 2 if ret is None else ret
    return outputs
class Solver(BaseSolver):
    """Advent of Code 2024, day 17: simulate the 3-bit computer.

    Part 1 yields the program's outputs for the parsed registers as a
    comma-separated string. Part 2 yields the smallest initial value of
    register A for which the program outputs an exact copy of itself.
    """

    def solve(self, input: str) -> Iterator[Any]:
        register_s, program_s = input.split("\n\n")
        # each register line ends its label with the register name,
        # e.g. "Register A: 729" -> {"A": 729}
        registers = {
            p[0][-1]: int(p[1].strip())
            for line in register_s.splitlines()
            if (p := line.split(":"))
        }
        program = [int(c) for c in program_s.split(":")[1].strip().split(",")]

        self._log_program(program)

        # part 1: run on a copy so ``registers`` keeps the parsed values
        yield ",".join(map(str, run(registers.copy(), program)))

        # part 2
        yield self._find_self_output_a(registers, program)

    def _log_program(self, program: list[int]) -> None:
        """Log the raw program followed by a human-readable disassembly."""
        self.logger.info(f"program ({len(program)}): " + ",".join(map(str, program)))

        # templates indexed by opcode; ``{}`` receives the rendered operand
        instruction_s = [
            "A = A >> {}",
            "B = B ^ {}",
            "B = {} % 8",
            "JMP {}",
            "B = B ^ C",
            "OUT {} % 8",
            "B = A >> {}",
            "C = A >> {}",
        ]

        self.logger.info("PROGRAM:")
        for index in range(0, len(program), 2):
            opcode = program[index]
            arg: int | str
            if opcode == 4:
                # bxc ignores its operand
                arg = ""
            else:
                operand = program[index + 1]
                # bxl (1) and jnz (3) take literal operands; the others take
                # combo operands (0-3 literal, 4-6 registers A-C)
                if opcode in (1, 3) or operand < 4:
                    arg = operand
                else:
                    arg = "ABC"[operand - 4]
            self.logger.info(instruction_s[opcode].format(arg))

    def _find_self_output_a(self, registers: dict[str, int], program: list[int]) -> int:
        """Return the smallest initial A for which *program* outputs itself.

        Relies on the program being a single loop ``<body>; OUT B; A >>= 3;
        JNZ 0``, which the asserts below check. Each iteration emits one value
        and drops the low three bits of A, so the value emitted by iteration k
        depends on ``A >> (3 * k)``. This assumes ``<body>`` derives B and C
        from A before reading them. Candidates are therefore built from the
        last expected output backwards, prepending three low bits per step.
        The final full run re-validates the answer.

        *registers* is not modified.

        Raises:
            ValueError: if no candidate reproduces the program.
        """
        # last instruction is JNZ 0 (jump at the beginning), and it is the only
        # jump in the program
        jnz_indices = [i for i in range(0, len(program), 2) if program[i] == 3]
        assert jnz_indices == [len(program) - 2] and program[-1] == 0
        # previous instruction is dividing A by 8, or A = A >> 3
        assert program[-4:-2] == [0, 3]
        # previous instruction is a OUT B % 8, and it is the only OUT in the program
        out_indices = [i for i in range(0, len(program), 2) if program[i] == 5]
        assert out_indices == [len(program) - 6] and program[len(program) - 5] == 5

        body = program[:-6]
        valid: list[int] = [0]
        for expected in reversed(program):
            new_valid: list[int] = []
            for v in valid:
                a_high = v << 3
                for a_low in range(8):
                    a = a_high | a_low
                    if a == 0:
                        # Only reachable for the final iteration (v == 0). A == 0
                        # there means the previous iteration's JNZ saw A == 0 and
                        # halted, so this iteration could never run. Keeping it
                        # would let shorter, spurious candidates win the min().
                        continue
                    # Fresh registers per candidate, so the result cannot depend
                    # on B/C left over from the previously tried candidate.
                    state = registers | {"A": a}
                    run(state, body)
                    if state["B"] % 8 == expected:
                        new_valid.append(a)
            valid = new_valid

        if not valid:
            raise ValueError("no initial value of register A makes the program output itself")

        answer = min(valid)
        # validate against the parsed (untouched) registers
        assert run(registers | {"A": answer}, program) == program
        return answer