Co-authored-by: Mikael CAPELLE <mikael.capelle@thalesaleniaspace.com> Co-authored-by: Mikaël Capelle <capelle.mikael@gmail.com> Reviewed-on: #3
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
import json
|
|
from functools import cmp_to_key
|
|
from typing import Any, Iterator, TypeAlias, cast
|
|
|
|
from ..base import BaseSolver
|
|
|
|
Packet: TypeAlias = list[int | list["Packet"]]
|
|
|
|
|
|
def compare(lhs: Packet, rhs: Packet) -> int:
    """Compare two packets using the nested int/list ordering rules.

    Items are compared position by position:

    * two integers compare by value;
    * otherwise a lone integer on the left (or, failing that, on the right)
      is wrapped in a one-element list and the two lists are compared
      recursively.

    If every shared position ties, the shorter packet orders first.

    Returns:
        A positive number when ``lhs`` orders before ``rhs``, a negative
        number when it orders after, and ``0`` when neither decides. The sign
        is the reverse of the usual ``cmp`` convention.
    """
    for left, right in zip(lhs, rhs):
        if isinstance(left, int) and isinstance(right, int):
            delta = right - left
        else:
            # Promote a lone integer so both sides can be compared as lists.
            if not isinstance(left, list):
                left = [left]  # type: ignore
            elif not isinstance(right, list):
                right = [right]  # type: ignore
            assert isinstance(left, list) and isinstance(right, list)
            delta = compare(cast(Packet, left), cast(Packet, right))

        if delta:
            return delta

    # Every shared position tied: the shorter packet comes first.
    return len(rhs) - len(lhs)
|
|
|
|
|
|
class Solver(BaseSolver):
    """Solve the nested-packet ordering puzzle (two answers)."""

    def solve(self, input: str) -> Iterator[Any]:
        """Yield the answers for both parts of the puzzle.

        ``input`` holds one JSON-encoded packet per line, with packets grouped
        in consecutive pairs (blank lines between pairs are ignored).

        Yields:
            1. The sum of the 1-based indices of the pairs already in the right
               order, i.e. those with ``compare(lhs, rhs) > 0``.
            2. The product of the 1-based positions of the divider packets
               ``[[2]]`` and ``[[6]]`` after all packets, dividers included,
               are sorted into the right order.
        """
        # Parse every non-blank line rather than splitting on blank lines, so
        # a trailing newline, extra blank lines or CRLF line endings never
        # hand an empty string to ``json.loads``.
        packets = [json.loads(line) for line in input.splitlines() if line.strip()]
        pairs = list(zip(packets[::2], packets[1::2]))

        yield sum(
            index
            for index, (lhs, rhs) in enumerate(pairs, start=1)
            if compare(lhs, rhs) > 0
        )

        dividers = [[[2]], [[6]]]
        # ``compare`` is positive when its first argument comes first (the
        # reverse of a standard cmp), so a descending sort puts the packets in
        # the right order.
        ordered = sorted(packets + dividers, key=cmp_to_key(compare), reverse=True)

        first, second = (ordered.index(divider) + 1 for divider in dividers)
        yield first * second
|