Refactor code for API (#3)

Co-authored-by: Mikael CAPELLE <mikael.capelle@thalesaleniaspace.com>
Co-authored-by: Mikaël Capelle <capelle.mikael@gmail.com>
Reviewed-on: #3
This commit is contained in:
2024-12-08 13:06:41 +00:00
parent ab4e3e199c
commit ce315b8778
130 changed files with 4599 additions and 3336 deletions

View File

@@ -1,53 +1,53 @@
import string
import sys
from collections import defaultdict
from typing import Any, Iterator
from ..base import BaseSolver
NOT_A_SYMBOL = "." + string.digits
lines = sys.stdin.read().splitlines()
values: list[int] = []
gears: dict[tuple[int, int], list[int]] = defaultdict(list)
class Solver(BaseSolver):
def solve(self, input: str) -> Iterator[Any]:
lines = input.splitlines()
for i, line in enumerate(lines):
j = 0
while j < len(line):
# skip everything until a digit is found (start of a number)
if line[j] not in string.digits:
j += 1
continue
values: list[int] = []
gears: dict[tuple[int, int], list[int]] = defaultdict(list)
# extract the range of the number and its value
k = j + 1
while k < len(line) and line[k] in string.digits:
k += 1
for i, line in enumerate(lines):
j = 0
while j < len(line):
# skip everything until a digit is found (start of a number)
if line[j] not in string.digits:
j += 1
continue
value = int(line[j:k])
# extract the range of the number and its value
k = j + 1
while k < len(line) and line[k] in string.digits:
k += 1
# lookup around the number if there is a symbol - we go through the number
# itself but that should not matter since it only contains digits
found = False
for i2 in range(max(0, i - 1), min(i + 1, len(lines) - 1) + 1):
for j2 in range(max(0, j - 1), min(k, len(line) - 1) + 1):
assert i2 >= 0 and i2 < len(lines)
assert j2 >= 0 and j2 < len(line)
value = int(line[j:k])
if lines[i2][j2] not in NOT_A_SYMBOL:
found = True
# lookup around the number if there is a symbol - we go through the number
# itself but that should not matter since it only contains digits
found = False
for i2 in range(max(0, i - 1), min(i + 1, len(lines) - 1) + 1):
for j2 in range(max(0, j - 1), min(k, len(line) - 1) + 1):
assert i2 >= 0 and i2 < len(lines)
assert j2 >= 0 and j2 < len(line)
if lines[i2][j2] == "*":
gears[i2, j2].append(value)
if lines[i2][j2] not in NOT_A_SYMBOL:
found = True
if found:
values.append(value)
if lines[i2][j2] == "*":
gears[i2, j2].append(value)
# continue starting from the end of the number
j = k
if found:
values.append(value)
# part 1
answer_1 = sum(values)
print(f"answer 1 is {answer_1}")
# continue starting from the end of the number
j = k
# part 2
answer_2 = sum(v1 * v2 for v1, v2 in filter(lambda vs: len(vs) == 2, gears.values()))
print(f"answer 2 is {answer_2}")
yield sum(values)
yield sum(v1 * v2 for v1, v2 in filter(lambda vs: len(vs) == 2, gears.values()))