"""Word search: count "XMAS" occurrences and "X-MAS" crosses in a letter grid."""

import itertools as it
from typing import Any, Iterator, Sequence

from ..base import BaseSolver

# The eight neighbour directions as (row step, column step) pairs.
_DIRECTIONS = tuple(
    step for step in it.product((-1, 0, 1), repeat=2) if step != (0, 0)
)

# Diagonal neighbours of an "A", read top-left, top-right, bottom-left,
# bottom-right, that make both diagonals spell "MAS" (forwards or backwards).
_X_MAS_CORNERS = frozenset({"MSMS", "SSMM", "MMSS", "SMSM"})


def _char_at(grid: Sequence[str], row: int, col: int) -> str:
    """Return ``grid[row][col]``, or ``""`` if the position is off the grid.

    Explicit bounds checks are used instead of plain indexing so that negative
    coordinates do not silently wrap around to the opposite edge, and so that
    rows of differing lengths are handled safely.
    """
    if 0 <= row < len(grid) and 0 <= col < len(grid[row]):
        return grid[row][col]
    return ""


def _count_xmas(grid: Sequence[str]) -> int:
    """Count "XMAS" read in a straight line in any of the eight directions.

    Reading backwards along a line ("SAMX") is the same as reading "XMAS" in
    the opposite direction, so every occurrence is counted exactly once.
    """
    total = 0
    for row, line in enumerate(grid):
        for col, char in enumerate(line):
            if char != "X":
                continue
            for d_row, d_col in _DIRECTIONS:
                if all(
                    _char_at(grid, row + d_row * dist, col + d_col * dist) == letter
                    for dist, letter in enumerate("MAS", start=1)
                ):
                    total += 1
    return total


def _count_x_mas(grid: Sequence[str]) -> int:
    """Count "A" cells whose two diagonals each spell "MAS" in either direction."""
    total = 0
    for row, line in enumerate(grid):
        for col, char in enumerate(line):
            if char != "A":
                continue
            # Same corner order as _X_MAS_CORNERS: (-1,-1), (-1,1), (1,-1), (1,1).
            corners = "".join(
                _char_at(grid, row + d_row, col + d_col)
                for d_row, d_col in it.product((-1, 1), (-1, 1))
            )
            # Border cells yield fewer than four corner letters and never match.
            if corners in _X_MAS_CORNERS:
                total += 1
    return total


class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Solve both parts of the "XMAS" word search.

        Args:
            input: The letter grid, one row per line. The grid need not be
                square, and rows need not all have the same length.

        Yields:
            First the number of "XMAS" occurrences in any of the eight
            straight-line directions, then the number of "X-MAS" crosses
            (an "A" whose two diagonals each read "MAS" forwards or backwards).
        """
        grid = input.splitlines()
        yield _count_xmas(grid)
        yield _count_x_mas(grid)