# -*- encoding: utf-8 -*- from __future__ import annotations import heapq import itertools import re import sys import time as time_p from collections import defaultdict from typing import FrozenSet, NamedTuple from tqdm import tqdm, trange class Pipe(NamedTuple): name: str flow: int tunnels: list[str] def __lt__(self, other: object) -> bool: return isinstance(other, Pipe) and other.name < self.name def __eq__(self, other: object) -> bool: return isinstance(other, Pipe) and other.name == self.name def __hash__(self) -> int: return hash(self.name) def __str__(self) -> str: return self.name def __repr__(self) -> str: return self.name def breadth_first_search(pipes: dict[str, Pipe], pipe: Pipe) -> dict[Pipe, int]: """ Runs a BFS from the given pipe and return the shortest distance (in term of hops) to all other pipes. """ queue = [(0, pipe_1)] visited = set() distances: dict[Pipe, int] = {} while len(distances) < len(pipes): distance, current = heapq.heappop(queue) if current in visited: continue visited.add(current) distances[current] = distance for tunnel in current.tunnels: heapq.heappush(queue, (distance + 1, pipes[tunnel])) return distances def update_with_better( node_at_times: dict[FrozenSet[Pipe], int], flow: int, flowing: FrozenSet[Pipe] ) -> None: node_at_times[flowing] = max(node_at_times[flowing], flow) def part_1( start_pipe: Pipe, max_time: int, distances: dict[tuple[Pipe, Pipe], int], relevant_pipes: FrozenSet[Pipe], ): node_at_times: dict[int, dict[Pipe, dict[FrozenSet[Pipe], int]]] = defaultdict( lambda: defaultdict(lambda: defaultdict(lambda: 0)) ) node_at_times[0] = {start_pipe: {frozenset(): 0}} for time in range(max_time): for c_pipe, nodes in node_at_times[time].items(): for flowing, flow in nodes.items(): for target in relevant_pipes: distance = distances[c_pipe, target] + 1 if time + distance >= max_time or target in flowing: continue update_with_better( node_at_times[time + distance][target], flow + sum(pipe.flow for pipe in flowing) * distance, flowing | {target}, ) update_with_better( node_at_times[max_time][c_pipe], flow + sum(pipe.flow for pipe in flowing) * (max_time - time), flowing, ) return max( flow for nodes_of_pipe in node_at_times[max_time].values() for flow in nodes_of_pipe.values() ) def part_2( start_pipe: Pipe, max_time: int, pipes: dict[str, Pipe], relevant_pipes: FrozenSet[Pipe], distances: dict[tuple[Pipe, Pipe], int], ): node_at_times: dict[ int, dict[tuple[Pipe, Pipe], dict[FrozenSet[Pipe], int]] ] = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))) node_at_times[0] = {(start_pipe, start_pipe): {frozenset(): 0}} # map node + distance to d1, d2, d3, d4 = 0, 0, 0, 0 best_flow = 0 for time in range(max_time): print( f"{time + 1:2d}/{max_time} - {best_flow:4d} - " f"{sum(map(len, node_at_times[time].values())):7d} - " f"{d1:.3f} {d2:.3f} {d3:.3f} {d4:.3f}" ) d1, d2, d3, d4 = 0, 0, 0, 0 for (c_pipe, e_pipe), nodes in node_at_times[time].items(): for flowing, flow in nodes.items(): t1 = time_p.time() c_best_flow = ( flow + sum(pipe.flow for pipe in flowing) * (max_time - time) + sum( ( pipe.flow * ( max_time - time - 1 - min(distances[c_pipe, pipe], distances[e_pipe, pipe]) ) for pipe in relevant_pipes if pipe not in flowing ), start=0, ) ) d1 += time_p.time() - t1 if c_best_flow < best_flow: continue best_flow = max( best_flow, flow + sum(pipe.flow for pipe in flowing) * (max_time - time), ) t1 = time_p.time() if flowing != relevant_pipes: for c_next_s, e_next_s in itertools.product( c_pipe.tunnels, e_pipe.tunnels ): c_next = pipes[c_next_s] e_next = pipes[e_next_s] update_with_better( node_at_times[time + 1][c_next, e_next], flow + sum(pipe.flow for pipe in flowing), flowing, ) d2 += time_p.time() - t1 t1 = time_p.time() if c_pipe in relevant_pipes and c_pipe not in flowing: for e_next_s in e_pipe.tunnels: e_next = pipes[e_next_s] update_with_better( node_at_times[time + 1][c_pipe, e_next], flow + sum(pipe.flow for pipe in flowing), flowing | {c_pipe}, ) if e_pipe in relevant_pipes and e_pipe not in flowing: for c_next_s in c_pipe.tunnels: c_next = pipes[c_next_s] update_with_better( node_at_times[time + 1][c_next, e_pipe], flow + sum(pipe.flow for pipe in flowing), flowing | {e_pipe}, ) if ( e_pipe in relevant_pipes and c_pipe in relevant_pipes and e_pipe not in flowing and c_pipe not in flowing ): update_with_better( node_at_times[time + 1][c_pipe, e_pipe], flow + sum(pipe.flow for pipe in flowing), flowing | {c_pipe, e_pipe}, ) update_with_better( node_at_times[max_time][c_pipe, e_pipe], flow + sum(pipe.flow for pipe in flowing) * (max_time - time), flowing, ) d3 += time_p.time() - t1 return max( flow for nodes_of_pipe in node_at_times[max_time].values() for flow in nodes_of_pipe.values() ) # === MAIN === lines = sys.stdin.read().splitlines() pipes: dict[str, Pipe] = {} for line in lines: r = re.match( R"Valve ([A-Z]+) has flow rate=([0-9]+); tunnels? leads? to valves? (.+)", line, ) assert r g = r.groups() pipes[g[0]] = Pipe(g[0], int(g[1]), g[2].split(", ")) # compute distances from one valve to any other distances: dict[tuple[Pipe, Pipe], int] = {} for pipe_1 in pipes.values(): distances.update( { (pipe_1, pipe_2): distance for pipe_2, distance in breadth_first_search(pipes, pipe_1).items() } ) # valves with flow relevant_pipes = frozenset(pipe for pipe in pipes.values() if pipe.flow > 0) # 1651, 1653 print(part_1(pipes["AA"], 30, distances, relevant_pipes)) # 1707, 2223 print(part_2(pipes["AA"], 26, pipes, relevant_pipes, distances))