Refactor code for API (#3)

Co-authored-by: Mikael CAPELLE <mikael.capelle@thalesaleniaspace.com>
Co-authored-by: Mikaël Capelle <capelle.mikael@gmail.com>
Reviewed-on: #3
This commit is contained in:
2024-12-08 13:06:41 +00:00
parent ab4e3e199c
commit ce315b8778
130 changed files with 4599 additions and 3336 deletions

View File

@@ -1,18 +1,18 @@
import sys
from math import prod
from typing import Any, Iterator
values = [[int(c) for c in row] for row in sys.stdin.read().splitlines()]
n_rows, n_cols = len(values), len(values[0])
from ..base import BaseSolver
def neighbors(point: tuple[int, int]):
def neighbors(point: tuple[int, int], n_rows: int, n_cols: int):
i, j = point
for di, dj in ((-1, 0), (+1, 0), (0, -1), (0, +1)):
if 0 <= i + di < n_rows and 0 <= j + dj < n_cols:
yield (i + di, j + dj)
def basin(start: tuple[int, int]) -> set[tuple[int, int]]:
def basin(values: list[list[int]], start: tuple[int, int]) -> set[tuple[int, int]]:
n_rows, n_cols = len(values), len(values[0])
visited: set[tuple[int, int]] = set()
queue = [start]
@@ -23,22 +23,25 @@ def basin(start: tuple[int, int]) -> set[tuple[int, int]]:
continue
visited.add((i, j))
queue.extend(neighbors((i, j)))
queue.extend(neighbors((i, j), n_rows, n_cols))
return visited
low_points = [
(i, j)
for i in range(n_rows)
for j in range(n_cols)
if all(values[ti][tj] > values[i][j] for ti, tj in neighbors((i, j)))
]
class Solver(BaseSolver):
def solve(self, input: str) -> Iterator[Any]:
values = [[int(c) for c in row] for row in input.splitlines()]
n_rows, n_cols = len(values), len(values[0])
# part 1
answer_1 = sum(values[i][j] + 1 for i, j in low_points)
print(f"answer 1 is {answer_1}")
low_points = [
(i, j)
for i in range(n_rows)
for j in range(n_cols)
if all(
values[ti][tj] > values[i][j]
for ti, tj in neighbors((i, j), n_rows, n_cols)
)
]
# part 2
answer_2 = prod(sorted(len(basin(point)) for point in low_points)[-3:])
print(f"answer 2 is {answer_2}")
yield sum(values[i][j] + 1 for i, j in low_points)
yield prod(sorted(len(basin(values, point)) for point in low_points)[-3:])