import itertools as it import sys from typing import TypeAlias NodeType: TypeAlias = tuple[tuple[int, int], tuple[int, int]] EdgesType: TypeAlias = dict[NodeType, tuple[NodeType, set[tuple[int, int]]]] ROTATE = {(-1, 0): (0, 1), (0, 1): (1, 0), (1, 0): (0, -1), (0, -1): (-1, 0)} START_NODE: NodeType = ((-2, -2), (-1, 0)) FINAL_POS: tuple[int, int] = (-1, -1) def move( lines: list[str], pos: tuple[int, int], dir: tuple[int, int] ) -> tuple[tuple[int, int] | None, set[tuple[int, int]]]: n_rows, n_cols = len(lines), len(lines[0]) row, col = pos marked: set[tuple[int, int]] = set() final_pos: tuple[int, int] | None = None while True: marked.add((row, col)) if not (0 <= row + dir[0] < n_rows and 0 <= col + dir[1] < n_cols): final_pos = None break if lines[row + dir[0]][col + dir[1]] != ".": final_pos = (row, col) break row += dir[0] col += dir[1] return final_pos, marked def compute_graph(lines: list[str], start_node: NodeType): n_rows, n_cols = len(lines), len(lines[0]) edges: EdgesType = {} start_pos, start_dir = start_node end_pos, marked = move(lines, start_pos, start_dir) assert end_pos is not None edges[START_NODE] = ((end_pos, start_dir), marked) for row, col in it.product(range(n_rows), range(n_cols)): if lines[row][col] != "#": continue for start_pos, start_dir in ( ((row - 1, col), (1, 0)), ((row + 1, col), (-1, 0)), ((row, col - 1), (0, 1)), ((row, col + 1), (0, -1)), ): if 0 <= start_pos[0] < n_rows and 0 <= start_pos[1] < n_cols: end_pos, marked = move(lines, start_pos, ROTATE[start_dir]) edges[start_pos, start_dir] = ( (end_pos or FINAL_POS, ROTATE[start_dir]), marked, ) return edges def is_loop(lines: list[str], edges: EdgesType, position: tuple[int, int]): row, col = position current_node = START_NODE found: set[NodeType] = set() while current_node[0] != FINAL_POS and current_node not in found: found.add(current_node) target_node, edge_marked = edges[current_node] if (row, col) in edge_marked: # need to break the edge target_dir = target_node[1] end_pos, _ = move( lines, (row - target_dir[0], col - target_dir[1]), ROTATE[target_dir] ) current_node = (end_pos or FINAL_POS, ROTATE[target_dir]) else: current_node = target_node return current_node in found def print_grid( lines: list[str], marked: set[tuple[int, int]], current_pos: tuple[int, int] | None ): chars = list(map(list, lines)) for i, j in marked: chars[i][j] = "X" if current_pos: chars[current_pos[0]][current_pos[1]] = "T" for line in chars: print("".join(line)) print() # read lines lines = sys.stdin.read().splitlines() # find and delete original position start_pos = next( (i, j) for i, row in enumerate(lines) for j, col in enumerate(row) if col == "^" ) lines[start_pos[0]] = lines[start_pos[0]].replace("^", ".") # compute edges from the map edges = compute_graph(lines, (start_pos, (-1, 0))) # part 1 marked: set[tuple[int, int]] = set() current_node = START_NODE while current_node[0] != FINAL_POS: current_node, current_marked = edges[current_node] marked = marked.union(current_marked) answer_1 = len(marked) print(f"answer 1 is {answer_1}") answer_2 = sum(is_loop(lines, edges, pos) for pos in marked if pos != start_pos) print(f"answer 2 is {answer_2}")