99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
import heapq
|
|
import math
|
|
import sys
|
|
from collections import defaultdict
|
|
|
|
# Raw input, one string per line; the first/last line and the first/last
# character of every line are treated as border and excluded from the grid.
lines = sys.stdin.read().splitlines()


def _scan_interior(grid):
    """Return ``{(row, col, char)}`` for every interior character that is not "."."""
    found = set()
    for row, text in enumerate(grid[1:-1]):
        for col, char in enumerate(text[1:-1]):
            if char != ".":
                found.add((row, col, char))
    return found


# Coordinates are 0-based and relative to the interior (border stripped).
winds = _scan_interior(lines)

n_rows = len(lines) - 2
n_cols = len(lines[0]) - 2
# Least common multiple of both interior dimensions, used as the modulus for
# the time/"cycle" component of the search state.
CYCLE = math.lcm(n_rows, n_cols)


def _lanes(char, by_row):
    """Group the initial positions of all ``char`` winds by lane.

    With ``by_row`` true the result is indexed by row and each set holds
    columns; otherwise it is indexed by column and each set holds rows.
    """
    lanes = [set() for _ in range(n_rows if by_row else n_cols)]
    for row, col, found_char in winds:
        # Only cells inside the n_rows x n_cols interior are indexed.
        if found_char != char or row >= n_rows or col >= n_cols:
            continue
        if by_row:
            lanes[row].add(col)
        else:
            lanes[col].add(row)
    return lanes


# Horizontal winds: one set of starting columns per row.
east_winds = _lanes(">", True)
west_winds = _lanes("<", True)
# Vertical winds: one set of starting rows per column.
north_winds = _lanes("^", False)
south_winds = _lanes("v", False)
|
|
|
|
|
|
def run(
    start: tuple[int, int], start_cycle: int, end: tuple[int, int]
) -> tuple[dict[tuple[int, int], dict[int, int]], int]:
    """Find the fewest steps needed to get from ``start`` to ``end``.

    Runs an A* search over ``(position, cycle)`` states, where ``cycle`` is
    the step counter modulo the module-level ``CYCLE``. Each step either
    waits in place or moves one cell up, down, left or right. A cell may be
    entered only if it is ``start``, ``end``, or an in-grid cell that none of
    the module-level ``east_winds`` / ``west_winds`` / ``north_winds`` /
    ``south_winds`` tables place a wind on at the arrival cycle.

    Args:
        start: ``(row, col)`` to depart from. It is exempt from the bounds
            and wind checks, so it may lie outside the grid.
        start_cycle: cycle value at departure.
        end: ``(row, col)`` to reach; also exempt from bounds/wind checks.

    Returns:
        ``(distances, steps)`` where ``distances[pos][cycle]`` is the step
        count at which each expanded state was settled and ``steps`` is the
        step count at which ``end`` was reached.

    Raises:
        ValueError: if ``end`` cannot be reached from ``start``.
    """
    # Normalise once so that list arguments compare equal to the tuples built
    # during the search and can be used as dictionary keys.
    start = (start[0], start[1])
    end = (end[0], end[1])

    def heuristic(y: int, x: int) -> int:
        # Manhattan distance: a lower bound, since a step moves at most one cell.
        return abs(end[0] - y) + abs(end[1] - x)

    # Heap entries: (distance + heuristic, distance, (position, cycle)).
    queue = [(heuristic(*start), 0, (start, start_cycle))]
    visited: set[tuple[tuple[int, int], int]] = set()
    distances: dict[tuple[int, int], dict[int, int]] = defaultdict(dict)

    while queue:
        _, distance, (pos, cycle) = heapq.heappop(queue)

        if (pos, cycle) in visited:
            continue
        visited.add((pos, cycle))
        distances[pos][cycle] = distance

        if pos == end:
            return distances, distance

        y, x = pos
        # Every neighbour is reached one step later, so compute this once.
        n_cycle = (cycle + 1) % CYCLE

        for dy, dx in (0, 0), (-1, 0), (1, 0), (0, -1), (0, 1):
            ty = y + dy
            tx = x + dx

            if (ty, tx) == end:
                # The target is adjacent: no other neighbour can do better.
                heapq.heappush(queue, (distance + 1, distance + 1, ((ty, tx), n_cycle)))
                break

            if ((ty, tx), n_cycle) in visited:
                continue

            if (ty, tx) != start:
                if ty < 0 or tx < 0 or ty >= n_rows or tx >= n_cols:
                    continue
                # A wind that started `n_cycle` cells upstream (wrapping
                # around the grid) occupies this cell on arrival.
                if (ty - n_cycle) % n_rows in south_winds[tx]:
                    continue
                if (ty + n_cycle) % n_rows in north_winds[tx]:
                    continue
                if (tx + n_cycle) % n_cols in west_winds[ty]:
                    continue
                if (tx - n_cycle) % n_cols in east_winds[ty]:
                    continue

            heapq.heappush(
                queue,
                (heuristic(ty, tx) + distance + 1, distance + 1, ((ty, tx), n_cycle)),
            )

    raise ValueError(f"no route from {start} to {end}")
|
|
|
|
|
|
# Entry point: one row above the grid, at the first "." of the first input
# line (its first and last characters are skipped). Column is interior-relative.
start = (-1, next(col for col, cell in enumerate(lines[0][1:-1]) if cell == "."))
# Exit point: one row below the grid, located the same way on the last line.
end = (n_rows, next(col for col, cell in enumerate(lines[-1][1:-1]) if cell == "."))

# First crossing, departing at cycle 0.
distances_1, forward_1 = run(start, 0, end)
print(f"answer 1 is {forward_1}")

# Each later leg departs at the cycle recorded on arrival of the previous leg.
distances_2, return_1 = run(end, next(iter(distances_1[end])), start)
distances_3, forward_2 = run(start, next(iter(distances_2[start])), end)
print(f"answer 2 is {forward_1 + return_1 + forward_2}")
|