Refactor code for API #3

Merged
mikael.capelle merged 13 commits from dev/refactor-for-ui into master 2024-12-08 13:06:42 +00:00
2 changed files with 46 additions and 57 deletions
Showing only changes of commit 8c707c00ba - Show all commits

View File

@ -1,5 +1,7 @@
import sys
from collections import defaultdict
from typing import Any, Iterator
from ..base import BaseSolver
def in_correct_order(update: list[int], requirements: dict[int, set[int]]) -> bool:
@ -39,7 +41,9 @@ def to_correct_order(
return update
part1, part2 = sys.stdin.read().strip().split("\n\n")
class Solver(BaseSolver):
def solve(self, input: str) -> Iterator[Any]:
part1, part2 = input.split("\n\n")
requirements: dict[int, set[int]] = defaultdict(set)
for line in part1.splitlines():
@ -48,17 +52,14 @@ for line in part1.splitlines():
updates = [list(map(int, line.split(","))) for line in part2.splitlines()]
answer_1 = sum(
yield sum(
update[len(update) // 2]
for update in updates
if in_correct_order(update, requirements)
)
answer_2 = sum(
yield sum(
to_correct_order(update, requirements, len(update) // 2 + 1)[-1]
for update in updates
if not in_correct_order(update, requirements)
)
print(f"answer 1 is {answer_1}")
print(f"answer 2 is {answer_2}")

View File

@ -1,6 +1,7 @@
import itertools as it
import sys
from typing import TypeAlias
from typing import Any, Iterator, TypeAlias
from ..base import BaseSolver
NodeType: TypeAlias = tuple[tuple[int, int], tuple[int, int]]
EdgesType: TypeAlias = dict[NodeType, tuple[NodeType, set[tuple[int, int]]]]
@ -91,28 +92,17 @@ def is_loop(lines: list[str], edges: EdgesType, position: tuple[int, int]):
return current_node in found
def print_grid(
lines: list[str], marked: set[tuple[int, int]], current_pos: tuple[int, int] | None
):
chars = list(map(list, lines))
for i, j in marked:
chars[i][j] = "X"
if current_pos:
chars[current_pos[0]][current_pos[1]] = "T"
for line in chars:
print("".join(line))
print()
class Solver(BaseSolver):
def solve(self, input: str) -> Iterator[Any]:
# read lines
lines = sys.stdin.read().splitlines()
lines = input.splitlines()
# find and delete original position
start_pos = next(
(i, j) for i, row in enumerate(lines) for j, col in enumerate(row) if col == "^"
(i, j)
for i, row in enumerate(lines)
for j, col in enumerate(row)
if col == "^"
)
lines[start_pos[0]] = lines[start_pos[0]].replace("^", ".")
@ -127,8 +117,6 @@ while current_node[0] != FINAL_POS:
current_node, current_marked = edges[current_node]
marked = marked.union(current_marked)
answer_1 = len(marked)
print(f"answer 1 is {answer_1}")
yield len(marked)
answer_2 = sum(is_loop(lines, edges, pos) for pos in marked if pos != start_pos)
print(f"answer 2 is {answer_2}")
yield sum(is_loop(lines, edges, pos) for pos in marked if pos != start_pos)