diff --git a/2023/day5.py b/2023/day5.py index 147120b..33d84c5 100644 --- a/2023/day5.py +++ b/2023/day5.py @@ -1,7 +1,5 @@ import sys -from collections import defaultdict -from dataclasses import dataclass -from typing import Literal, TypeAlias +from typing import Sequence MAP_ORDER = [ "seed", @@ -16,89 +14,116 @@ MAP_ORDER = [ lines = sys.stdin.read().splitlines() -seeds: list[int] = [] +# mappings from one category to another, each list contains +# ranges stored as (source, target, length), ordered by start and +# completed to have no "hole" maps: dict[tuple[str, str], list[tuple[int, int, int]]] = {} # parsing index = 2 while index < len(lines): - l0 = lines[index] - p1, _, p2 = l0.split("-") - p2 = p2.split()[0].strip() + p1, _, p2 = lines[index].split()[0].split("-") + # extract the existing ranges from the file - we store as (source, target, length) + # whereas the file is in order (target, source, length) index += 1 - maps[p1, p2] = [] + values: list[tuple[int, int, int]] = [] while index < len(lines) and lines[index]: n1, n2, n3 = lines[index].split() - maps[p1, p2].append((int(n1), int(n2), int(n3))) + values.append((int(n2), int(n1), int(n3))) index += 1 + # sort by source value + values.sort() + + # add a 'fake' interval starting at 0 if missing + if values[0][0] != 0: + values.insert(0, (0, 0, values[0][0])) + + # fill gaps between intervals + for i in range(len(values) - 1): + next_start = values[i + 1][0] + end = values[i][0] + values[i][2] + if next_start != end: + values.insert( + i + 1, + (end, end, next_start - end), + ) + + # add an interval covering values up to at least 2**32 at the end + last_start, _, last_length = values[-1] + values.append((last_start + last_length, last_start + last_length, 2**32)) + + assert all(v1[0] + v1[2] == v2[0] for v1, v2 in zip(values[:-1], values[1:])) + assert values[0][0] == 0 + assert values[-1][0] + values[-1][-1] >= 2**32 + + maps[p1, p2] = values index += 1 -def find_location(seed: int) -> int: - value = seed - for map1, map2 in zip(MAP_ORDER[:-1], MAP_ORDER[1:]): - for target, start, length in maps[map1, map2]: - if value >= start and value < start + length: - value = target + (value - start) - break - return value - - def find_range( values: tuple[int, int], map: list[tuple[int, int, int]] ) -> list[tuple[int, int]]: + """ + Given an input range, use the given mapping to find the corresponding list of + ranges in the target domain. + """ r_start, r_length = values ranges: list[tuple[int, int]] = [] - print(r_start, r_length) - for target, start, length in map: - # start is in the range - if start <= r_start and r_start < start + length: - if r_start + r_length < start + length: - ranges.append( - (target + r_start - start, r_length) - ) - else: - ranges.append( - (target + r_start - start, length - (r_start - start)) - ) - elif start < r_start: - if r_start + r_length < start + length: - ranges.append( - (target + r_start - start, target + r_start - start + r_length) - ) - elif start >= r_start and r_start >= start - if r_start <= start and start < start + length: - print(start, length, target) - if r_start + r_length < start + length: - ranges.append( - (target + (start - r_start), target + (start - r_start) + length) - ) - else: - ranges.append((target + (start - r_start), target + length)) + # find index of the first and last intervals in map that overlaps the input + # interval + index_start, index_end = -1, -1 + + for index_start, (start, _, length) in enumerate(map): + if start <= r_start and start + length > r_start: + break + + for index_end, (start, _, length) in enumerate( + map[index_start:], start=index_start + ): + if r_start + r_length >= start and r_start + r_length < start + length: + break + + assert index_start >= 0 and index_end >= 0 + + # special case if one interval contains everything + if index_start == index_end: + start, target, length = map[index_start] + ranges.append((target + r_start - start, r_length)) + else: + # add the start interval part + start, target, length = map[index_start] + ranges.append((target + r_start - start, start + length - r_start)) + + # add all intervals between the first and last (excluding both) + index = index_start + 1 + while index < index_end: + start, target, length = map[index] + ranges.append((target, length)) + index += 1 + + # add the last interval + start, target, length = map[index_end] + ranges.append((target, r_start + r_length - start)) return ranges -# part 1 -seeds = [int(s) for s in lines[0].split(":")[1].strip().split()] -answer_1 = min(find_location(seed) for seed in seeds) +def find_location_ranges(seeds: Sequence[tuple[int, int]]) -> Sequence[tuple[int, int]]: + for map1, map2 in zip(MAP_ORDER[:-1], MAP_ORDER[1:]): + seeds = [s2 for s1 in seeds for s2 in find_range(s1, maps[map1, map2])] + return seeds + + +# part 1 - use find_range() with range of length 1 +seeds_p1 = [(int(s), 1) for s in lines[0].split(":")[1].strip().split()] +answer_1 = min(start for start, _ in find_location_ranges(seeds_p1)) print(f"answer 1 is {answer_1}") -# part 2 +# # part 2 parts = lines[0].split(":")[1].strip().split() seeds_p2 = [(int(s), int(e)) for s, e in zip(parts[::2], parts[1::2])] - -for seed in range(seeds_p2[0][0], seeds_p2[0][0] + seeds_p2[0][1]): - print(seed, find_location(seed)) -print("---") - -seeds_p2 = [seeds_p2[0]] -for map1, map2 in zip(MAP_ORDER[:-1], MAP_ORDER[1:]): - seeds_p2 = [s2 for s1 in seeds_p2 for s2 in find_range(s1, maps[map1, map2])] -print(seeds_p2) - -answer_2 = ... +answer_2 = min(start for start, _ in find_location_ranges(seeds_p2)) print(f"answer 2 is {answer_2}")