import sys import numpy as np lines = sys.stdin.read().splitlines() data = np.array([[c == "#" for c in line] for line in lines]) rows = {c for c in range(data.shape[0]) if not data[c, :].any()} columns = {c for c in range(data.shape[1]) if not data[:, c].any()} galaxies_y, galaxies_x = np.where(data) # type: ignore def compute_total_distance(expansion: int) -> int: distances: list[int] = [] for g1 in range(len(galaxies_y)): x1, y1 = int(galaxies_x[g1]), int(galaxies_y[g1]) for g2 in range(g1 + 1, len(galaxies_y)): x2, y2 = int(galaxies_x[g2]), int(galaxies_y[g2]) dx = sum( 1 + (expansion - 1) * (x in columns) for x in range(min(x1, x2), max(x1, x2)) ) dy = sum( 1 + (expansion - 1) * (y in rows) for y in range(min(y1, y2), max(y1, y2)) ) distances.append(dx + dy) return sum(distances) # part 1 answer_1 = compute_total_distance(2) print(f"answer 1 is {answer_1}") # part 2 answer_2 = compute_total_distance(1000000) print(f"answer 2 is {answer_2}")