from typing import Any, Iterator

from ..base import BaseSolver


def combo(registers: dict[str, int], operand: int) -> int:
    """Resolve a *combo* operand to its value.

    Operands 0-3 are returned as-is; operands 4, 5 and 6 read registers A, B
    and C respectively. Operand 7 is rejected by the assertion.
    """
    if operand < 4:
        return operand
    assert operand < 7
    return registers["ABC"[operand - 4]]


# Instruction handlers. Each one mutates ``registers`` in place and returns
# ``None`` to fall through to the next instruction, or an int that ``run``
# uses as the new instruction pointer.


def adv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 0: ``A = A >> combo(operand)``."""
    registers["A"] = registers["A"] >> combo(registers, operand)


def bxl(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 1: ``B = B ^ operand`` (literal operand)."""
    registers["B"] ^= operand


def bst(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 2: ``B = combo(operand) % 8``."""
    registers["B"] = combo(registers, operand) % 8


def jnz(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 3: jump to the literal ``operand`` when A is non-zero."""
    if registers["A"] != 0:
        return operand


def bxc(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 4: ``B = B ^ C``; the operand is ignored."""
    registers["B"] = registers["B"] ^ registers["C"]


def bdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 6: ``B = A >> combo(operand)``."""
    registers["B"] = registers["A"] >> combo(registers, operand)


def cdv(registers: dict[str, int], operand: int) -> int | None:
    """Opcode 7: ``C = A >> combo(operand)``."""
    registers["C"] = registers["A"] >> combo(registers, operand)


def run(registers: dict[str, int], program: list[int]) -> list[int]:
    """Execute *program* on *registers* and return the emitted values.

    ``registers`` is mutated in place. Execution stops once the instruction
    pointer moves to or past the end of ``program``.
    """
    outputs: list[int] = []

    def out(registers: dict[str, int], operand: int) -> int | None:
        # Opcode 5 is a closure so it can append to this run's ``outputs``.
        outputs.append(combo(registers, operand) % 8)

    # list position == opcode
    instructions = [adv, bxl, bst, jnz, bxc, out, bdv, cdv]

    index = 0
    while index < len(program):
        instruction, operand = instructions[program[index]], program[index + 1]
        ret = instruction(registers, operand)
        # a jump target may be 0, hence the explicit ``is None`` check
        if ret is None:
            index += 2
        else:
            index = ret
    return outputs


def _disassemble(program: list[int]) -> list[str]:
    """Render *program* as one human-readable line per instruction.

    Literal operands (opcodes 1 and 3) and combo operands 0-3 are shown as
    numbers, combo operands 4-6 as the register name, and opcode 4 (``bxc``)
    without an operand.
    """
    templates = (
        "A = A >> {}",
        "B = B ^ {}",
        "B = {} % 8",
        "JMP {}",
        "B = B ^ C",
        "OUT {} % 8",
        "B = A >> {}",
        "C = A >> {}",
    )
    lines: list[str] = []
    for index in range(0, len(program), 2):
        opcode, operand = program[index], program[index + 1]
        argument: int | str
        if opcode == 4:
            argument = ""
        elif opcode in (1, 3) or operand < 4:
            argument = operand
        else:
            argument = "ABC"[operand - 4]
        lines.append(templates[opcode].format(argument))
    return lines


def _find_self_output_a(registers: dict[str, int], program: list[int]) -> int:
    """Return the smallest initial A for which *program* outputs itself.

    The search relies on the program shape asserted below: a single loop whose
    body ends with ``OUT B % 8``, ``A = A >> 3`` and ``JNZ 0``. Each iteration
    emits one value and then drops the low 3 bits of A, so A is rebuilt 3 bits
    at a time starting from the *last* output. It also assumes the rest of the
    loop body does not depend on B/C values left over from a previous
    iteration.

    ``registers`` is not modified; its B and C values are used for every trial
    run and for the final verification.

    Raises:
        ValueError: if no value of A reproduces the program.
    """
    # last instruction is JNZ 0 (jump at the beginning), and it is the only jump
    # in the program
    jnz_indices = [i for i in range(0, len(program), 2) if program[i] == 3]
    assert jnz_indices == [len(program) - 2] and program[-1] == 0
    # previous instruction is dividing A by 8, or A = A >> 3
    assert program[-4:-2] == [0, 3]
    # previous instruction is a OUT B % 8, and it is the only OUT in the program
    out_indices = [i for i in range(0, len(program), 2) if program[i] == 5]
    assert out_indices == [len(program) - 6] and program[len(program) - 5] == 5

    # Loop body without the trailing OUT / A >>= 3 / JNZ: after running it,
    # B % 8 is the value the OUT instruction would emit.
    body = program[:-6]

    candidates: list[int] = [0]
    for expected in reversed(program):
        next_candidates: list[int] = []
        for high in candidates:
            for low in range(8):
                a = (high << 3) | low
                if a == 0:
                    # A zero A at the start of an iteration means the previous
                    # JNZ would not have jumped, so it cannot emit this value.
                    continue
                trial = registers | {"A": a}
                run(trial, body)
                if trial["B"] % 8 == expected:
                    next_candidates.append(a)
        candidates = next_candidates

    if not candidates:
        raise ValueError("no initial value of A makes the program output itself")
    answer = min(candidates)
    assert run(registers | {"A": answer}, program) == program
    return answer


class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Yield the part 1 output string, then the part 2 value of A."""
        register_s, program_s = input.split("\n\n")
        registers: dict[str, int] = {}
        for line in register_s.splitlines():
            if not line.strip():
                continue
            # e.g. "Register A: 729" -> registers["A"] = 729
            name, _, value = line.partition(":")
            registers[name[-1]] = int(value)
        program = [int(c) for c in program_s.split(":")[1].strip().split(",")]

        self.logger.info(f"program ({len(program)}): " + ",".join(map(str, program)))
        self.logger.info("PROGRAM:")
        for line in _disassemble(program):
            self.logger.info(line)

        # part 1 runs on a copy so part 2 still sees the input registers
        yield ",".join(map(str, run(registers.copy(), program)))

        yield _find_self_output_a(registers, program)