"""Interpreter for an eight-instruction register machine (registers A, B, C).

A program is a flat list of integers read as ``(opcode, operand)`` pairs.
Each opcode handler receives the mutable register mapping and the raw
operand. A handler returns a new instruction pointer to jump to, or ``None``
to fall through to the next pair.
"""

from typing import Any, Iterator

from ..base import BaseSolver


def combo(registers: dict[str, int], operand: int) -> int:
    """Resolve a *combo* operand.

    Operands below 4 are literal values; 4, 5 and 6 read registers A, B and C.

    Raises:
        ValueError: if ``operand`` is 7 or greater (7 is reserved / invalid).
    """
    if operand < 4:
        return operand
    if operand >= 7:
        # explicit raise instead of ``assert`` so the check survives ``python -O``
        raise ValueError(f"invalid combo operand: {operand}")
    return registers["ABC"[operand - 4]]


def adv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 0: ``A = A // 2**combo(operand)``."""
    registers["A"] = registers["A"] // (2 ** combo(registers, operand))


def bxl(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 1: ``B = B ^ operand`` (literal operand)."""
    registers["B"] ^= operand


def bst(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 2: ``B = combo(operand) % 8``."""
    registers["B"] = combo(registers, operand) % 8


def jnz(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 3: jump to the literal ``operand`` when A is non-zero."""
    if registers["A"] != 0:
        return operand
    return None


def bxc(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 4: ``B = B ^ C`` (operand ignored)."""
    registers["B"] = registers["B"] ^ registers["C"]


def bdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 6: ``B = A // 2**combo(operand)``."""
    registers["B"] = registers["A"] // (2 ** combo(registers, operand))


def cdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 7: ``C = A // 2**combo(operand)``."""
    registers["C"] = registers["A"] // (2 ** combo(registers, operand))


# Human-readable templates indexed by opcode; ``{}`` receives the operand.
_MNEMONICS = (
    "A = A // (2 ** {})",  # 0 adv
    "B = B ^ {}",  # 1 bxl
    "B = {} % 8",  # 2 bst
    "JMP {}",  # 3 jnz
    "B = B ^ C",  # 4 bxc
    "OUT {}",  # 5 out
    "B = A // (2 ** {})",  # 6 bdv
    "C = A // (2 ** {})",  # 7 cdv
)


def disassemble(program: list[int]) -> list[str]:
    """Return one readable line per complete ``(opcode, operand)`` pair.

    Literal-operand opcodes (1 and 3) show the raw operand. ``bxc`` (4) shows
    nothing. Combo operands show a literal below 4, or the register name for
    4-6. The reserved combo operand 7 is shown as ``?``; it is not allowed to
    crash the listing.
    """
    lines: list[str] = []
    # stop before a dangling opcode with no operand (odd-length program)
    for index in range(0, len(program) - 1, 2):
        opcode, operand = program[index], program[index + 1]
        if opcode == 4:
            shown = ""
        elif opcode in (1, 3) or operand < 4:
            shown = str(operand)
        elif operand < 7:
            shown = "ABC"[operand - 4]
        else:
            shown = "?"
        lines.append(_MNEMONICS[opcode].format(shown))
    return lines


def run(
    registers: dict[str, int], program: list[int], *, verbose: bool = False
) -> list[int]:
    """Execute ``program`` and return the values emitted by ``out`` (opcode 5).

    ``registers`` is mutated in place. Pass a copy to keep the original.
    Execution halts once the instruction pointer no longer addresses a
    complete ``(opcode, operand)`` pair.

    Args:
        registers: mapping with keys ``"A"``, ``"B"`` and ``"C"``.
        program: flat list of opcodes and operands.
        verbose: when true, print the disassembly before running. Printing
            used to be unconditional; it is now opt-in debug output.
    """
    outputs: list[int] = []

    def out(registers: dict[str, int], operand: int) -> int | None:
        """Opcode 5: emit ``combo(operand) % 8``."""
        outputs.append(combo(registers, operand) % 8)
        return None

    instructions = (adv, bxl, bst, jnz, bxc, out, bdv, cdv)

    if verbose:
        for line in disassemble(program):
            print(line)

    index = 0
    # ``index + 1`` must also be in range: the operand is read unconditionally
    while index + 1 < len(program):
        jump = instructions[program[index]](registers, program[index + 1])
        index = index + 2 if jump is None else jump
    return outputs


def _parse_registers(block: str) -> dict[str, int]:
    """Parse lines like ``Register A: 729`` into ``{"A": 729}``.

    Blank lines are skipped. The register name is the last non-space
    character before the colon.
    """
    registers: dict[str, int] = {}
    for line in block.splitlines():
        if not line.strip():
            continue
        label, _, value = line.partition(":")
        registers[label.strip()[-1]] = int(value)
    return registers


def _check_loop_structure(program: list[int]) -> None:
    """Ensure ``program`` has the shape the part-2 search relies on.

    Required shape:
    - exactly one ``jnz``, placed last and jumping back to 0;
    - exactly one ``adv``, whose operand is the literal 3, so each loop pass
      does ``A //= 8``.

    Raises:
        ValueError: if either condition does not hold.
    """
    pairs = list(zip(program[0::2], program[1::2]))
    jnz_positions = [i for i, (op, _) in enumerate(pairs) if op == 3]
    if jnz_positions != [len(pairs) - 1] or pairs[-1][1] != 0:
        raise ValueError("program must end with its only jnz, jumping to 0")
    adv_operands = [arg for op, arg in pairs if op == 0]
    if adv_operands != [3]:
        raise ValueError("program must contain exactly one 'adv 3'")


def _lowest_self_replicating_a(registers: dict[str, int], program: list[int]) -> int:
    """Return the lowest initial A for which ``program`` outputs itself.

    Each loop pass emits output and then shifts A right by 3 bits. So the
    last output depends only on A's top 3-bit digit, the last two outputs
    on the top two digits, and so on.

    The search builds A one octal digit at a time. For each prefix it keeps
    every extension whose full run reproduces the matching suffix of the
    program. Every candidate is validated by really running the program.

    Assumes ``_check_loop_structure`` has passed. B and C keep their parsed
    initial values.

    Raises:
        ValueError: if no such A exists.
    """
    candidates = [0]
    for start in range(len(program) - 1, -1, -1):
        target = program[start:]
        candidates = [
            a
            for prefix in candidates
            for a in range(prefix * 8, prefix * 8 + 8)
            if run({**registers, "A": a}, program) == target
        ]
        if not candidates:
            raise ValueError("no initial value of A makes the program output itself")
    return min(candidates)


class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Yield the part-1 and part-2 answers.

        Part 1 is the comma-joined output of running the program with the
        parsed registers. Part 2 is the lowest initial A whose output equals
        the program itself.
        """
        register_s, program_s = input.strip().split("\n\n", maxsplit=1)
        registers = _parse_registers(register_s)
        program = [int(c) for c in program_s.partition(":")[2].split(",")]

        # run on a copy: ``run`` mutates registers, and part 2 needs the originals
        yield ",".join(map(str, run(registers.copy(), program)))

        _check_loop_structure(program)
        yield _lowest_self_replicating_a(registers, program)