diff --git a/src/holt59/aoc/2024/day3.py b/src/holt59/aoc/2024/day3.py index 9ec9139..09afee0 100644 --- a/src/holt59/aoc/2024/day3.py +++ b/src/holt59/aoc/2024/day3.py @@ -4,19 +4,19 @@ from typing import Iterator def extract_multiply(line: str) -> Iterator[int]: - for m in re.findall(r"mul\(([0-9]{1,3}),\s*([0-9]{1,3})\)", line): - yield int(m[0]) * int(m[1]) + for m in re.finditer(r"mul\(([0-9]{1,3}),\s*([0-9]{1,3})\)", line): + yield int(m.group(1)) * int(m.group(2)) -def remove_disabled_parts(line: str) -> str: - fixed_line, accumulate = "", True +def valid_memory_blocks(line: str) -> Iterator[str]: + accumulate = True while line: if accumulate: if (dont_i := line.find("don't()")) != -1: - fixed_line += line[:dont_i] + yield line[:dont_i] line, accumulate = line[dont_i:], False else: - fixed_line += line + yield line line = "" else: if (do_i := line.find("do()")) != -1: @@ -24,13 +24,11 @@ def remove_disabled_parts(line: str) -> str: else: line = "" - return fixed_line - line = sys.stdin.read().strip() answer_1 = sum(extract_multiply(line)) -answer_2 = sum(extract_multiply(remove_disabled_parts(line))) +answer_2 = sum(sum(extract_multiply(block)) for block in valid_memory_blocks(line)) print(f"answer 1 is {answer_1}") print(f"answer 2 is {answer_2}")