"""Sum the products of well-formed ``mul(X,Y)`` instructions read from stdin.

Answer 1 sums every well-formed instruction. Answer 2 additionally ignores
instructions that appear after a ``don't()`` marker, until the next ``do()``
marker re-enables them.
"""

import re
import sys
from typing import Iterator

# Only the exact form ``mul(X,Y)`` with 1-3 digit operands is accepted. The
# earlier pattern allowed whitespace after the comma (``mul(2, 4)``) while
# rejecting it everywhere else in the instruction, so spaced instructions
# were wrongly counted.
MUL_PATTERN = re.compile(r"mul\(([0-9]{1,3}),([0-9]{1,3})\)")

DO_MARKER = "do()"
DONT_MARKER = "don't()"


def extract_multiply(line: str) -> Iterator[int]:
    """Yield ``X * Y`` for each well-formed ``mul(X,Y)`` in *line*, left to right.

    >>> list(extract_multiply("mul(2,4)xmul(3, 5)mul(1234,1)mul(11,8)"))
    [8, 88]
    """
    for left, right in MUL_PATTERN.findall(line):
        yield int(left) * int(right)


def remove_disabled_parts(line: str) -> str:
    """Return *line* with every region disabled by ``don't()`` removed.

    Scanning starts in the enabled state. Text from each ``don't()`` up to
    (but not including) the next ``do()`` is dropped; if no ``do()`` follows,
    the remainder of *line* is dropped.

    The ``do()`` markers themselves are kept in the output. Every fragment
    after the first therefore begins with ``do()``, so two enabled fragments
    can never fuse into a new ``mul(...)`` across a removed region.
    """
    kept = []
    pos, enabled = 0, True
    while pos < len(line):
        if enabled:
            stop = line.find(DONT_MARKER, pos)
            if stop == -1:
                kept.append(line[pos:])
                break
            kept.append(line[pos:stop])
            # Park on the don't() marker; the disabled branch searches past it.
            pos, enabled = stop, False
        else:
            # "don't()" does not contain "do()", so this always advances
            # past the marker we are parked on, guaranteeing termination.
            start = line.find(DO_MARKER, pos)
            if start == -1:
                break
            pos, enabled = start, True
    return "".join(kept)


def main() -> None:
    """Read the whole program from stdin and print both answers."""
    line = sys.stdin.read().strip()
    answer_1 = sum(extract_multiply(line))
    answer_2 = sum(extract_multiply(remove_disabled_parts(line)))
    print(f"answer 1 is {answer_1}")
    print(f"answer 2 is {answer_2}")


if __name__ == "__main__":
    main()