"""Step-wise simulation of a grid of energy levels whose cells "flash" above 9."""

import itertools as it
from typing import Any, Iterator

from ..base import BaseSolver


def do_step(
    values: list[list[int]],
) -> tuple[list[list[int]], set[tuple[int, int]]]:
    """Advance the grid by one step and return the new grid and flashed cells.

    Every cell is first incremented by one. Any cell whose value then exceeds
    9 flashes (at most once per step), which increments every cell in its 3x3
    neighbourhood clipped to the grid; those increments may trigger further
    flashes. Finally every cell that flashed is reset to 0.

    The input grid is not modified; a fresh grid is returned.

    Args:
        values: Grid of integer levels, as a list of rows.

    Returns:
        ``(new_values, flashed)`` where ``flashed`` holds the ``(row, col)``
        coordinates of every cell that flashed during this step.
    """
    grid = [[cell + 1 for cell in row] for row in values]

    # Seed the work-list with every cell that is already above 9.
    pending = [
        (i_row, i_col)
        for i_row, row in enumerate(grid)
        for i_col, cell in enumerate(row)
        if cell > 9
    ]
    flashed: set[tuple[int, int]] = set(pending)

    # A cell enters `pending` only when it first exceeds 9, so each cell
    # flashes at most once and bumps its neighbours exactly once. This avoids
    # rescanning the whole grid after every cascade wave.
    while pending:
        i_row, i_col = pending.pop()
        for dr, dc in it.product((-1, 0, 1), repeat=2):
            r, c = i_row + dr, i_col + dc
            # (dr, dc) == (0, 0) bumps the flashing cell itself; harmless
            # because every flashed cell is reset to 0 below.
            if 0 <= r < len(grid) and 0 <= c < len(grid[r]):
                grid[r][c] += 1
                if grid[r][c] > 9 and (r, c) not in flashed:
                    flashed.add((r, c))
                    pending.append((r, c))

    for i_row, i_col in flashed:
        grid[i_row][i_col] = 0
    return grid, flashed


class Solver(BaseSolver):
    def print_grid(
        self, values: list[list[int]], flashed: set[tuple[int, int]]
    ) -> None:
        """Log ``values`` one row per line, colouring flashed cells red.

        Flashed cells are wrapped in ANSI colour escape codes. An empty line
        is logged after the grid.
        """
        for i_row, row in enumerate(values):
            self.logger.info(
                "".join(
                    f"\033[0;31m{cell}\033[0;00m"
                    if (i_row, i_col) in flashed
                    else str(cell)
                    for i_col, cell in enumerate(row)
                )
            )
        self.logger.info("")

    def solve(self, input: str) -> Iterator[Any]:
        """Yield two answers computed from ``input``.

        ``input`` holds one row of single-digit levels per line; blank lines
        and surrounding whitespace are ignored.

        Yields:
            First, the total number of flashes over the first 100 steps.
            Second, the 1-based number of the first step in which every cell
            flashes, re-simulated from the initial grid (0 for an empty grid).
        """
        values_0 = [
            [int(c) for c in line.strip()]
            for line in input.splitlines()
            if line.strip()
        ]

        values = values_0
        total_flashed = 0
        for _ in range(100):
            values, flashed = do_step(values)
            total_flashed += len(flashed)
        yield total_flashed

        # Counting per row also works for empty grids, where values_0[0]
        # would raise IndexError.
        n_cells = sum(len(row) for row in values_0)
        values, step = values_0, 0
        flashed = set()
        # NOTE: this loop has no upper bound; it does not terminate if no step
        # ever flashes every cell.
        while len(flashed) != n_cells:
            values, flashed = do_step(values)
            step += 1
        yield step