118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
import heapq
|
|
import math
|
|
from collections import defaultdict
|
|
from typing import Any, Iterator
|
|
|
|
from ..base import BaseSolver
|
|
|
|
|
|
class Solver(BaseSolver):
    """Shortest-time path finder for a walled grid swept by wrapping winds.

    The input is a character grid.  Its first and last lines are walls that
    each contain one "." opening; the lines in between are interior rows
    wrapped in a one-character border.  Interior cells hold "." (free) or one
    of ">", "<", "^", "v" (a wind heading east, west, north or south).  Each
    wind advances one cell per step and wraps around inside the interior.
    """

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the answers for both parts of the puzzle.

        Args:
            input: Raw map text.  Blank lines are ignored.  (The parameter
                name shadows the builtin but is kept so existing callers,
                including keyword callers, keep working.)

        Yields:
            First the minimum number of steps (waiting counts as a step) to
            get from the opening in the first line to the opening in the last
            line; then the minimum total for going there, back, and there
            again.

        Raises:
            ValueError: If the map has no interior, a wall line has no
                opening, or the goal cannot be reached.  Because this is a
                generator, the error surfaces on iteration, not on the call.
        """
        # Drop blank lines so a leading/trailing newline cannot masquerade as
        # a wall row.
        lines = [text for raw in input.splitlines() if (text := raw.strip())]
        if len(lines) < 3 or len(lines[0]) < 3:
            raise ValueError("map needs at least one interior row and column")

        n_rows, n_cols = len(lines) - 2, len(lines[0]) - 2
        # Vertical winds repeat every n_rows steps and horizontal ones every
        # n_cols steps, so the whole wind pattern repeats every lcm of both.
        CYCLE = math.lcm(n_rows, n_cols)

        # Initial wind positions in interior coordinates (border stripped):
        # east/west are indexed by row and hold columns, north/south are
        # indexed by column and hold rows.
        east_winds: list[set[int]] = [set() for _ in range(n_rows)]
        west_winds: list[set[int]] = [set() for _ in range(n_rows)]
        north_winds: list[set[int]] = [set() for _ in range(n_cols)]
        south_winds: list[set[int]] = [set() for _ in range(n_cols)]
        for i, row in enumerate(lines[1:-1]):
            # Clamp to n_cols so an over-long row cannot index past the
            # per-column tables.
            for j, cell in enumerate(row[1:-1][:n_cols]):
                if cell == ">":
                    east_winds[i].add(j)
                elif cell == "<":
                    west_winds[i].add(j)
                elif cell == "^":
                    north_winds[j].add(i)
                elif cell == "v":
                    south_winds[j].add(i)

        def opening_column(wall: str) -> int:
            """Return the interior column of the first "." in a wall line."""
            # Search excludes the first and last character (the corners).
            col = wall.find(".", 1, len(wall) - 1)
            if col < 0:
                raise ValueError(f"no opening found in wall line {wall!r}")
            return col - 1

        def is_blocked(y: int, x: int, t: int) -> bool:
            """Return True if a wind occupies interior cell (y, x) at time t.

            Instead of moving the winds, the cell is shifted backwards along
            each wind direction by t and compared with the initial positions.
            """
            return (
                (y - t) % n_rows in south_winds[x]
                or (y + t) % n_rows in north_winds[x]
                or (x + t) % n_cols in west_winds[y]
                or (x - t) % n_cols in east_winds[y]
            )

        # Wait first, then up, down, left, right (order kept from the
        # original implementation).
        moves = ((0, 0), (-1, 0), (1, 0), (0, -1), (0, 1))

        def run(
            start: tuple[int, int], start_cycle: int, end: tuple[int, int]
        ) -> tuple[int, int]:
            """A* search from ``start`` to ``end``.

            Args:
                start: Departure cell; it lies outside the interior and is
                    always allowed (no bounds or wind check).
                start_cycle: Wind phase (time modulo ``CYCLE``) at departure.
                end: Target cell, likewise outside the interior.

            Returns:
                ``(steps, arrival_cycle)`` for the first arrival at ``end``.

            Raises:
                ValueError: If the search space is exhausted without
                    reaching ``end``.
            """

            def heuristic(y: int, x: int) -> int:
                # Manhattan distance: one step moves at most one cell.
                return abs(end[0] - y) + abs(end[1] - x)

            # Heap entries: (distance + heuristic, distance, (position, cycle)).
            queue = [(heuristic(start[0], start[1]), 0, (start, start_cycle))]
            visited: set[tuple[tuple[int, int], int]] = set()

            while queue:
                _, distance, ((y, x), cycle) = heapq.heappop(queue)
                if ((y, x), cycle) in visited:
                    continue
                visited.add(((y, x), cycle))

                if (y, x) == end:
                    return distance, cycle

                n_cycle = (cycle + 1) % CYCLE
                for dy, dx in moves:
                    ty, tx = y + dy, x + dx

                    if (ty, tx) == end:
                        # The goal is one step away, so its priority equals
                        # this node's own; no other move from here can beat it.
                        heapq.heappush(
                            queue, (distance + 1, distance + 1, ((ty, tx), n_cycle))
                        )
                        break

                    if ((ty, tx), n_cycle) in visited:
                        continue

                    if (ty, tx) != start:
                        if not (0 <= ty < n_rows and 0 <= tx < n_cols):
                            continue
                        if is_blocked(ty, tx, n_cycle):
                            continue

                    heapq.heappush(
                        queue,
                        (
                            heuristic(ty, tx) + distance + 1,
                            distance + 1,
                            ((ty, tx), n_cycle),
                        ),
                    )

            # Raise explicitly: a StopIteration escaping here would be turned
            # into an opaque RuntimeError by the enclosing generator.
            raise ValueError(f"no route from {start} to {end}")

        # The openings sit one row above / below the interior.
        start = (-1, opening_column(lines[0]))
        end = (n_rows, opening_column(lines[-1]))

        forward_1, cycle_1 = run(start, 0, end)
        yield forward_1

        # Each leg departs at the wind phase at which the previous leg arrived.
        return_1, cycle_2 = run(end, cycle_1, start)
        forward_2, _ = run(start, cycle_2, end)
        yield forward_1 + return_1 + forward_2
|