from typing import Any, Iterator

import numpy as np
from sympy import solve, symbols

from ..base import BaseSolver


def _cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return the scalar cross product ``a x b`` of 2-D vectors.

    Broadcasts over leading axes, so ``a`` of shape ``(2,)`` and ``b`` of
    shape ``(k, 2)`` give a result of shape ``(k,)``.  Written out by hand
    because ``np.cross`` on 2-element vectors is deprecated since NumPy 2.0.
    """
    return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]


def _parse_hailstones(text: str) -> tuple[np.ndarray, np.ndarray]:
    """Parse ``"px, py, pz @ vx, vy, vz"`` lines into two ``(n, 3)`` int64 arrays.

    Blank lines are ignored.  Components are split on ``","`` only; ``int``
    tolerates the surrounding (possibly repeated) spaces.

    Returns:
        ``(positions, velocities)``, one row per non-blank input line.
    """
    positions, velocities = [], []
    for line in text.splitlines():
        if not line.strip():
            continue
        pos, _, vel = line.partition("@")
        positions.append([int(c) for c in pos.split(",")])
        velocities.append([int(c) for c in vel.split(",")])
    return (
        np.array(positions, dtype=np.int64),
        np.array(velocities, dtype=np.int64),
    )


def _count_future_crossings(
    positions: np.ndarray, velocities: np.ndarray, low: int, high: int
) -> int:
    """Count unordered pairs whose XY paths cross inside ``[low, high]^2``.

    Only the first two coordinates are used (z is ignored).  For stones
    ``p + t*r`` and ``q + u*s`` the crossing parameters are
    ``t = ((q - p) x s) / (r x s)`` and ``u = ((q - p) x r) / (r x s)``;
    a crossing counts only when both ``t >= 0`` and ``u >= 0`` (i.e. it is
    not in either stone's past).  Parallel pairs (``r x s == 0``) are skipped.
    """
    pos2d, vel2d = positions[:, :2], velocities[:, :2]
    count = 0
    for i, (p, r) in enumerate(zip(pos2d, vel2d)):
        # Compare stone i against every later stone in one vectorized step.
        q, s = pos2d[i + 1 :], vel2d[i + 1 :]
        rs = _cross2d(r, s)

        nonparallel = rs != 0
        q, s, rs = q[nonparallel], s[nonparallel], rs[nonparallel]

        qp = q - p
        t = _cross2d(qp, s) / rs
        u = _cross2d(qp, r) / rs

        future = (t >= 0) & (u >= 0)
        crossing = p + t[future][:, np.newaxis] * r
        inside = np.all((low <= crossing) & (crossing <= high), axis=1)
        count += int(inside.sum())
    return count


def _rock_position_sum(positions: np.ndarray, velocities: np.ndarray) -> int:
    """Return ``x + y + z`` of the rock start that hits every hailstone.

    For each hailstone ``i`` hit at time ``t_i``::

        p_i + t_i * v_i == p0 + t_i * v0      (3 scalar equations: x, y, z)

    The rock adds 6 unknowns (``p0``, ``v0``) and each hailstone adds one
    time, so 3 hailstones give 9 equations in exactly 9 unknowns.
    """
    n = 3
    x, y, z, vx, vy, vz, *ts = symbols(
        "x y z vx vy vz " + " ".join(f"t{i}" for i in range(n))
    )
    rock = ((x, vx), (y, vy), (z, vz))
    equations = [
        coord + t * speed - int(p) - t * int(v)
        for t, pos, vel in zip(ts, positions[:n], velocities[:n])
        for (coord, speed), p, v in zip(rock, pos, vel)
    ]
    solution = solve(equations, [x, y, z, vx, vy, vz, *ts], dict=True)[0]
    return int(solution[x] + solution[y] + solution[z])


class Solver(BaseSolver):
    """Solver yielding the two hailstone answers in order (part 1, part 2)."""

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the part 1 and part 2 answers for the given puzzle text.

        Args:
            input: One hailstone per line, ``"px, py, pz @ vx, vy, vz"``.

        Yields:
            Part 1: number of pairs whose future XY paths cross inside the
            test area.  Part 2: sum of the rock's starting coordinates.
        """
        positions, velocities = _parse_hailstones(input)

        # part 1 -- a small input (<= 10 stones) uses the smaller test area
        if len(positions) <= 10:
            low, high = 7, 27
        else:
            low, high = 200000000000000, 400000000000000
        yield _count_future_crossings(positions, velocities, low, high)

        # part 2
        yield _rock_position_sum(positions, velocities)