from typing import Any, Iterator, Literal, TypeAlias, cast

from ..base import BaseSolver

CellType: TypeAlias = Literal[".", "|", "-", "\\", "/"]
Direction: TypeAlias = Literal["R", "L", "U", "D"]

# Lookup table indexed by cell character, then by the direction a beam is
# travelling when it reaches that cell.  Each entry is a pair:
#   * the moves to follow next, as (row delta, column delta, new direction);
#     two moves mean the beam splits;
#   * the directions to record on the cell as already handled.
# NOTE(review): the second direction recorded for each entry looks like the
# reverse traversal of the same cell, presumably so that a beam retracing an
# already-followed path is pruned early -- confirm before changing the table.
Mappings: dict[
    CellType,
    dict[
        Direction,
        tuple[tuple[tuple[int, int, Direction], ...], tuple[Direction, ...]],
    ],
] = {
    ".": {
        "R": (((0, +1, "R"),), ("R", "L")),
        "L": (((0, -1, "L"),), ("R", "L")),
        "U": (((-1, 0, "U"),), ("U", "D")),
        "D": (((+1, 0, "D"),), ("U", "D")),
    },
    "-": {
        "R": (((0, +1, "R"),), ("R", "L")),
        "L": (((0, -1, "L"),), ("R", "L")),
        "U": (((0, +1, "R"), (0, -1, "L")), ("U", "D")),
        "D": (((0, +1, "R"), (0, -1, "L")), ("U", "D")),
    },
    "|": {
        "U": (((-1, 0, "U"),), ("U", "D")),
        "D": (((+1, 0, "D"),), ("U", "D")),
        "R": (((-1, 0, "U"), (+1, 0, "D")), ("R", "L")),
        "L": (((-1, 0, "U"), (+1, 0, "D")), ("R", "L")),
    },
    "/": {
        "R": (((-1, 0, "U"),), ("R", "D")),
        "L": (((+1, 0, "D"),), ("L", "U")),
        "U": (((0, +1, "R"),), ("U", "L")),
        "D": (((0, -1, "L"),), ("R", "D")),
    },
    "\\": {
        "R": (((+1, 0, "D"),), ("R", "U")),
        "L": (((-1, 0, "U"),), ("L", "D")),
        "U": (((0, -1, "L"),), ("U", "R")),
        "D": (((0, +1, "R"),), ("L", "D")),
    },
}


def _count_energized(beams: list[list[tuple[Direction, ...]]]) -> int:
    """Return how many cells of *beams* hold at least one recorded direction."""
    return sum(1 for beam_row in beams for cell in beam_row if cell)


def propagate(
    layout: list[list[CellType]], start: tuple[int, int], direction: Direction
) -> list[list[tuple[Direction, ...]]]:
    """Follow a beam through *layout* and report which cells it reaches.

    Args:
        layout: Grid of cell characters, indexed as ``layout[row][col]``.
        start: ``(row, col)`` of the first cell the beam enters.
        direction: Direction the beam is travelling when it enters *start*.

    Returns:
        A grid with one entry per row of *layout* and, per row, one entry per
        column of the first row.  Each entry is the tuple of directions
        recorded for that cell; an empty tuple means no beam reached it.

    Raises:
        IndexError: If *layout* has no rows.
    """
    n_rows, n_cols = len(layout), len(layout[0])
    beams: list[list[tuple[Direction, ...]]] = [
        [() for _ in range(n_cols)] for _ in range(n_rows)
    ]

    # Used as a stack: the most recently pushed beam head is followed first.
    pending: list[tuple[tuple[int, int], Direction]] = [(start, direction)]
    while pending:
        (row, col), heading = pending.pop()

        # Beam left the grid: nothing more to follow.
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            continue
        # This cell was already handled for this heading.
        if heading in beams[row][col]:
            continue

        moves, marks = Mappings[layout[row][col]][heading]
        beams[row][col] += marks
        pending.extend(
            ((row + d_row, col + d_col), new_heading)
            for d_row, d_col, new_heading in moves
        )

    return beams


class Solver(BaseSolver):
    """Solver that counts the cells reached by beams travelling through a grid."""

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the answers for both parts, in order.

        The first value is the number of reached cells for a beam entering
        the top-left cell heading right.  The second is the largest such
        number over every beam entering from an edge cell heading inward.

        When ``self.files`` is truthy, the cells reached in the first part
        are also written to ``beams.txt`` ("#" for reached, "." otherwise).
        """
        layout: list[list[CellType]] = [
            [cast(CellType, char) for char in line] for line in input.splitlines()
        ]

        beams = propagate(layout, (0, 0), "R")

        if self.files:
            rendering = "\n".join(
                "".join("#" if cell else "." for cell in beam_row)
                for beam_row in beams
            )
            self.files.create("beams.txt", rendering.encode(), True)

        # part 1
        yield _count_energized(beams)

        # part 2
        n_rows, n_cols = len(layout), len(layout[0])
        last_row, last_col = n_rows - 1, n_cols - 1

        entry_points: list[tuple[tuple[int, int], Direction]] = []
        for row in range(n_rows):
            entry_points.extend((((row, 0), "R"), ((row, last_col), "L")))
        for col in range(n_cols):
            entry_points.extend((((0, col), "D"), ((last_row, col), "U")))

        yield max(
            _count_energized(propagate(layout, origin, heading))
            for origin, heading in entry_points
        )