import itertools as it
from collections import defaultdict
from typing import Any, Iterator, cast

from ..base import BaseSolver


def compute_antinodes(
    a1: tuple[int, int],
    a2: tuple[int, int],
    n_rows: int,
    n_cols: int,
    min_distance: int = 1,
    max_distance: int | None = 1,
) -> tuple[tuple[int, int], ...]:
    """Return the in-grid antinodes produced by the antenna pair ``(a1, a2)``.

    With ``d = a2 - a1``, the candidate points are ``a1 - c * d`` and
    ``a2 + c * d`` for every integer ``c`` in ``[min_distance, max_distance]``.
    Candidates outside ``[0, n_rows) x [0, n_cols)`` are dropped. The set of
    returned points does not depend on the order of ``a1`` and ``a2``.

    Args:
        a1: ``(row, col)`` of the first antenna.
        a2: ``(row, col)`` of the second antenna.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        min_distance: First multiple of ``d`` to consider; ``0`` includes the
            antennas themselves.
        max_distance: Last multiple of ``d`` to consider, or ``None`` to keep
            going until both ends have left the grid.

    Returns:
        The in-grid ``(row, col)`` points, in generation order.
    """
    d_row, d_col = a2[0] - a1[0], a2[1] - a1[1]

    if max_distance is None:
        # Each step moves at least one cell along some axis (for distinct
        # antennas), so after max(n_rows, n_cols) steps both ends are off the
        # grid. Bounding by n_rows alone missed cells on wide grids when
        # d_row == 0.
        max_distance = max(n_rows, n_cols)

    def in_grid(row: int, col: int) -> bool:
        return 0 <= row < n_rows and 0 <= col < n_cols

    points: list[tuple[int, int]] = []
    for c in range(min_distance, max_distance + 1):
        row_1, col_1 = a1[0] - c * d_row, a1[1] - c * d_col
        row_2, col_2 = a2[0] + c * d_row, a2[1] + c * d_col
        valid_1, valid_2 = in_grid(row_1, col_1), in_grid(row_2, col_2)
        if not (valid_1 or valid_2):
            # Each end walks a ray starting at an antenna; the grid is convex,
            # so once both ends have left it no larger c can land inside.
            break
        if valid_1:
            points.append((row_1, col_1))
        if valid_2:
            points.append((row_2, col_2))
    return tuple(points)


def _count_antinodes(
    antennas: dict[str, list[tuple[int, int]]],
    n_rows: int,
    n_cols: int,
    min_distance: int,
    max_distance: int | None,
) -> int:
    """Count distinct cells that are an antinode of some same-frequency pair."""
    antinodes: set[tuple[int, int]] = set()
    for positions in antennas.values():
        # compute_antinodes yields the same point set for (a1, a2) and
        # (a2, a1), so each unordered pair only needs one visit.
        for a1, a2 in it.combinations(positions, 2):
            antinodes.update(
                compute_antinodes(
                    a1, a2, n_rows, n_cols, min_distance, max_distance
                )
            )
    return len(antinodes)


class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Yield the two antinode counts for the antenna map *input*.

        Every non-``.`` character is an antenna whose frequency is that
        character. The first answer counts antinodes at exactly one offset
        beyond each antenna of a pair. The second counts every in-grid point
        on the pair's line at integer multiples of the offset, including the
        antennas themselves.
        """
        lines = input.splitlines()
        n_rows = len(lines)
        # Guard against empty input instead of raising IndexError.
        n_cols = len(lines[0]) if lines else 0

        antennas: dict[str, list[tuple[int, int]]] = defaultdict(list)
        for i, j in it.product(range(n_rows), range(n_cols)):
            if lines[i][j] != ".":
                antennas[lines[i][j]].append((i, j))

        yield _count_antinodes(antennas, n_rows, n_cols, 1, 1)
        yield _count_antinodes(antennas, n_rows, n_cols, 0, None)