"""Sum ``mul(X,Y)`` products found in a corrupted program read from stdin.

Answer 1 is the sum of every ``mul(X,Y)`` product in the input.
Answer 2 counts only products that appear while instructions are enabled:
scanning starts enabled, ``don't()`` disables, and ``do()`` re-enables.
"""

import re
import sys
from typing import Iterator, Tuple

# X and Y are 1-3 digit numbers.
# NOTE(review): ``\s*`` also accepts whitespace (including newlines) after the
# comma, e.g. ``mul(4, 5)``; confirm this leniency is intended.
_MUL_RE = re.compile(r"mul\(([0-9]{1,3}),\s*([0-9]{1,3})\)")


def extract_multiply(line: str) -> Iterator[int]:
    """Yield the product X*Y for each ``mul(X,Y)`` match in *line*, in order."""
    for m in _MUL_RE.finditer(line):
        yield int(m.group(1)) * int(m.group(2))


def valid_memory_blocks(line: str) -> Iterator[str]:
    """Yield the chunks of *line* in which instructions are enabled.

    Scanning starts enabled. A ``don't()`` disables instructions until the
    next ``do()``. Each yielded chunk ends just before a ``don't()`` (or at the
    end of *line*). Every chunk after the first begins with the ``do()`` that
    re-enabled it. An empty chunk is yielded when a ``don't()`` immediately
    starts an enabled region.

    Splitting right before ``don't()`` never cuts a ``mul(...)`` match: after
    ``mul(`` the pattern only accepts digits, a comma, whitespace and ``)``.
    """
    # Walk by index instead of re-slicing the remaining text on every toggle,
    # so the tail of the string is not copied repeatedly.
    pos = 0
    while True:
        dont_i = line.find("don't()", pos)
        if dont_i == -1:
            # Enabled through the end of the input.
            if pos < len(line):
                yield line[pos:]
            return
        yield line[pos:dont_i]
        do_i = line.find("do()", dont_i)
        if do_i == -1:
            # Disabled through the end of the input.
            return
        pos = do_i


def solve(text: str) -> Tuple[int, int]:
    """Return ``(answer_1, answer_2)`` for the program *text*.

    Leading and trailing whitespace is stripped before scanning.
    """
    line = text.strip()
    answer_1 = sum(extract_multiply(line))
    answer_2 = sum(sum(extract_multiply(block)) for block in valid_memory_blocks(line))
    return answer_1, answer_2


def main() -> None:
    """Read the program from stdin and print both answers."""
    answer_1, answer_2 = solve(sys.stdin.read())
    print(f"answer 1 is {answer_1}")
    print(f"answer 2 is {answer_2}")


# Only touch stdin when run as a script, so the module can be imported.
if __name__ == "__main__":
    main()