131 lines
3.9 KiB
Python
131 lines
3.9 KiB
Python
from typing import Any, Iterator
|
|
|
|
from ..base import BaseSolver
|
|
|
|
|
|
def combo(registers: dict[str, int], operand: int) -> int:
    """Return the value denoted by a combo operand.

    Operands 0-3 are literal values; operands 4, 5 and 6 read registers
    A, B and C respectively.

    Raises:
        ValueError: if *operand* is 7 or larger, which does not denote any value.
    """
    if operand < 4:
        return operand
    if operand > 6:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError(f"invalid combo operand: {operand}")
    return registers["ABC"[operand - 4]]
|
|
|
|
|
|
def adv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 0: shift register A right by the combo operand, in place.

    Always returns None, so execution continues with the next instruction.
    """
    shift = combo(registers, operand)
    registers["A"] >>= shift
    return None
|
|
|
|
|
|
def bxl(registers: dict[str, int], operand: int) -> int | None:
|
|
registers["B"] ^= operand
|
|
|
|
|
|
def bst(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 2: store the combo operand modulo 8 into register B.

    Always returns None, so execution continues with the next instruction.
    """
    value = combo(registers, operand)
    registers["B"] = value % 8
    return None
|
|
|
|
|
|
def jnz(registers: dict[str, int], operand: int) -> int | None:
|
|
if registers["A"] != 0:
|
|
return operand
|
|
|
|
|
|
def bxc(registers: dict[str, int], operand: int) -> int | None:
|
|
registers["B"] = registers["B"] ^ registers["C"]
|
|
|
|
|
|
def bdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 6: store register A shifted right by the combo operand into B.

    Register A itself is left unchanged. Always returns None.
    """
    shift = combo(registers, operand)
    registers["B"] = registers["A"] >> shift
    return None
|
|
|
|
|
|
def cdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 7: store register A shifted right by the combo operand into C.

    Register A itself is left unchanged. Always returns None.
    """
    shift = combo(registers, operand)
    registers["C"] = registers["A"] >> shift
    return None
|
|
|
|
|
|
def run(registers: dict[str, int], program: list[int]) -> list[int]:
    """Execute *program* and return every value it outputs, in order.

    *registers* is mutated in place, so callers can inspect the final
    register state afterwards. Execution halts when the instruction index no
    longer addresses a complete (opcode, operand) pair.
    """
    outputs: list[int] = []

    def out(registers: dict[str, int], operand: int) -> int | None:
        """Opcode 5: emit the combo operand modulo 8."""
        outputs.append(combo(registers, operand) % 8)
        return None

    # Position in this list is the opcode (0-7).
    instructions = [adv, bxl, bst, jnz, bxc, out, bdv, cdv]

    index = 0
    # ``index + 1 < len(program)`` rather than ``index < len(program)``:
    # an opcode with no operand after it (odd-length program, or a jump to the
    # last cell) halts instead of raising IndexError.
    while index + 1 < len(program):
        instruction, operand = instructions[program[index]], program[index + 1]
        # A non-None return is a jump target; None means "next instruction".
        target = instruction(registers, operand)
        index = index + 2 if target is None else target

    return outputs
|
|
|
|
|
|
def _disassemble(program: list[int]) -> list[str]:
    """Render *program* as one human-readable line per (opcode, operand) pair.

    A trailing opcode without an operand is not rendered.
    """
    templates = [
        "A = A >> {}",
        "B = B ^ {}",
        "B = {} % 8",
        "JMP {}",
        "B = B ^ C",
        "OUT {} % 8",
        "B = A >> {}",
        "C = A >> {}",
    ]
    rendered: list[str] = []
    for opcode, operand in zip(program[::2], program[1::2]):
        shown: int | str
        if opcode == 4:
            # B = B ^ C does not use its operand.
            shown = ""
        elif opcode in (1, 3) or operand < 4:
            # Opcodes 1 and 3 take literal operands; combo operands 0-3 are literals too.
            shown = operand
        else:
            # Combo operands 4-6 name registers A, B, C.
            shown = "ABC"[operand - 4]
        rendered.append(templates[opcode].format(shown))
    return rendered


def _lowest_self_reproducing_a(registers: dict[str, int], program: list[int]) -> int:
    """Return the lowest initial value of register A for which *program* outputs itself.

    The assertions below require the program to be a single loop whose tail is
    ``OUT B % 8``, ``A = A >> 3``, ``JNZ 0``. Each iteration therefore emits
    one value and then drops the low 3 bits of A. A is rebuilt 3 bits at a
    time, starting from the value behind the last output.

    Each candidate is evaluated from the initial B and C in *registers*.
    NOTE(review): this assumes the loop body derives B and C from A before
    reading them — the final ``run`` check guards that assumption.

    Raises:
        ValueError: if no initial value of A reproduces the program.
    """
    # The last instruction is JNZ 0 (jump to the beginning), and it is the only jump.
    jnz_indices = [i for i in range(0, len(program), 2) if program[i] == 3]
    assert jnz_indices == [len(program) - 2] and program[-1] == 0

    # The instruction before it is A = A >> 3.
    assert program[-4:-2] == [0, 3]

    # The instruction before that is OUT B % 8, and it is the only OUT.
    out_indices = [i for i in range(0, len(program), 2) if program[i] == 5]
    assert out_indices == [len(program) - 6] and program[len(program) - 5] == 5

    body = program[:-6]  # everything before OUT / shift / JNZ
    valid: list[int] = [0]
    for expected in reversed(program):
        next_valid: list[int] = []
        for high in valid:
            for low in range(8):
                candidate = (high << 3) | low
                if candidate == 0:
                    # The A of the last iteration (the first value built here)
                    # must be nonzero. Otherwise the loop would have exited one
                    # iteration earlier and the output would be too short.
                    continue
                trial = registers | {"A": candidate}
                run(trial, body)
                if trial["B"] % 8 == expected:
                    next_valid.append(candidate)
        valid = next_valid

    if not valid:
        raise ValueError("no initial value of register A makes the program output itself")

    answer = min(valid)
    assert run(registers | {"A": answer}, program) == program
    return answer


class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Yield part 1 then part 2 for the register/program puzzle input.

        Part 1 is the program's comma-joined output. Part 2 is the lowest
        initial register A for which the program outputs itself.
        """
        register_s, program_s = input.strip().split("\n\n", 1)

        # "Register A: 729" -> {"A": 729}
        registers: dict[str, int] = {}
        for line in register_s.splitlines():
            if not line.strip():
                continue
            name, _, value = line.partition(":")
            registers[name.strip()[-1]] = int(value.strip())

        program = [int(c) for c in program_s.partition(":")[2].strip().split(",")]

        self.logger.info(f"program ({len(program)}): " + ",".join(map(str, program)))
        self.logger.info("PROGRAM:")
        for line in _disassemble(program):
            self.logger.info(line)

        # Part 1 runs on a copy so ``registers`` keeps the initial B and C for part 2.
        yield ",".join(map(str, run(registers.copy(), program)))

        yield _lowest_self_reproducing_a(registers, program)
|