import heapq
import math
from collections import defaultdict
from typing import Any, Iterator

from ..base import BaseSolver


class Solver(BaseSolver):
    """Shortest-time routing across a walled grid swept by wrapping winds."""

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the shortest crossing time, then the time for three legs.

        The first and last non-blank lines of ``input`` are border rows that
        each contain one ``"."`` opening; the first and last character of
        every line are border columns.  Interior cells hold ``"."`` or one
        of the wind markers ``">"``, ``"<"``, ``"^"``, ``"v"``.  Each wind
        advances one cell per step in its direction and wraps around inside
        the interior.

        Args:
            input: The textual map, one grid row per line.  Blank lines are
                ignored.

        Yields:
            First the minimum number of steps from the top opening to the
            bottom opening, then the minimum total for going there, back to
            the top opening, and to the bottom opening again.

        Raises:
            ValueError: If the map has no interior, a border row lacks an
                opening, or a leg of the trip cannot be completed.
        """
        # Ignore blank lines so a leading/trailing empty line cannot be
        # mistaken for a border row.
        lines = [row for row in (raw.strip() for raw in input.splitlines()) if row]
        if len(lines) < 3 or len(lines[0]) < 3:
            raise ValueError("map must contain at least one interior row and column")

        n_rows, n_cols = len(lines) - 2, len(lines[0]) - 2
        # Horizontal winds repeat every n_cols steps and vertical ones every
        # n_rows steps, so the whole wind pattern repeats with this period.
        period = math.lcm(n_rows, n_cols)

        # Initial wind positions in interior coordinates.  Horizontal winds
        # are grouped per row (sets of columns), vertical winds per column
        # (sets of rows).
        east_winds: list[set[int]] = [set() for _ in range(n_rows)]
        west_winds: list[set[int]] = [set() for _ in range(n_rows)]
        north_winds: list[set[int]] = [set() for _ in range(n_cols)]
        south_winds: list[set[int]] = [set() for _ in range(n_cols)]
        for i, row in enumerate(lines[1:-1]):
            # Strip the border columns; ignore anything beyond the width
            # fixed by the first line.
            for j, cell in enumerate(row[1:-1][:n_cols]):
                if cell == ">":
                    east_winds[i].add(j)
                elif cell == "<":
                    west_winds[i].add(j)
                elif cell == "^":
                    north_winds[j].add(i)
                elif cell == "v":
                    south_winds[j].add(i)

        def opening(border_row: str) -> int:
            """Return the interior column of the first "." in a border row."""
            col = border_row.find(".", 1, len(border_row) - 1)
            if col < 0:
                raise ValueError(f"no opening found in border row {border_row!r}")
            return col - 1

        def windy(y: int, x: int, t: int) -> bool:
            """Tell whether interior cell (y, x) holds a wind at step ``t``.

            Each wind is traced back to where it must have started in order
            to be on (y, x) after ``t`` steps.
            """
            return (
                (y - t) % n_rows in south_winds[x]
                or (y + t) % n_rows in north_winds[x]
                or (x + t) % n_cols in west_winds[y]
                or (x - t) % n_cols in east_winds[y]
            )

        def travel(
            origin: tuple[int, int], origin_cycle: int, goal: tuple[int, int]
        ) -> tuple[int, int]:
            """A* search from ``origin`` to ``goal``.

            Args:
                origin: Starting cell (an opening just outside the interior).
                origin_cycle: Wind phase (step count modulo the period) at
                    departure.
                goal: Target cell (the other opening).

            Returns:
                ``(steps, arrival_cycle)`` for a fastest trip.

            Raises:
                ValueError: If every reachable state was explored without
                    reaching ``goal``.
            """

            def estimate(y: int, x: int) -> int:
                # Manhattan distance: a lower bound on the remaining steps.
                return abs(goal[0] - y) + abs(goal[1] - x)

            # Heap entries: (steps + estimate, steps, position, wind phase).
            frontier = [(estimate(*origin), 0, origin, origin_cycle)]
            seen: set[tuple[tuple[int, int], int]] = set()

            while frontier:
                _, steps, (y, x), cycle = heapq.heappop(frontier)
                if ((y, x), cycle) in seen:
                    continue
                seen.add(((y, x), cycle))

                if (y, x) == goal:
                    return steps, cycle

                next_cycle = (cycle + 1) % period
                # Waiting in place is a legal move alongside the four steps.
                for dy, dx in (0, 0), (-1, 0), (1, 0), (0, -1), (0, 1):
                    ty, tx = y + dy, x + dx
                    target = (ty, tx)

                    if target == goal:
                        # The state just popped had the smallest priority
                        # (steps + 1) in the heap, so nothing can arrive
                        # earlier than steps + 1.
                        return steps + 1, next_cycle

                    if (target, next_cycle) in seen:
                        continue

                    # The origin lies outside the interior: it is exempt
                    # from the bounds test and is never hit by a wind.
                    if target != origin:
                        if not (0 <= ty < n_rows and 0 <= tx < n_cols):
                            continue
                        if windy(ty, tx, next_cycle):
                            continue

                    heapq.heappush(
                        frontier,
                        (estimate(ty, tx) + steps + 1, steps + 1, target, next_cycle),
                    )

            raise ValueError(f"no route from {origin} to {goal}")

        start = (-1, opening(lines[0]))
        end = (n_rows, opening(lines[-1]))

        there, phase = travel(start, 0, end)
        yield there

        # Each further leg departs at the wind phase the previous one
        # arrived in.
        back, phase = travel(end, phase, start)
        there_again, _ = travel(start, phase, end)
        yield there + back + there_again