import itertools
from collections import defaultdict
from typing import Any, Iterator, Literal, cast

import parse  # type: ignore

from ..base import BaseSolver


def max_change_in_happiness(happiness: dict[str, dict[str, int]]) -> int:
    """Return the best total happiness change over all circular seatings.

    ``happiness[a][b]`` is the change in happiness guest *a* experiences when
    sitting next to guest *b*.  Every adjacent pair around the table
    contributes ``happiness[a][b] + happiness[b][a]``.

    The guests are the keys of *happiness*.  With fewer than two guests
    nobody has a neighbour, so the result is 0.

    Raises:
        KeyError: if some pair of guests has no entry in *happiness*.
    """
    guests = list(happiness)
    # Nobody to sit next to: no happiness is gained or lost.
    if len(guests) < 2:
        return 0

    # Pin the first guest in place so rotations of the same circular seating
    # are not enumerated repeatedly; only the (n - 1)! orders of the rest are
    # tried.
    head, *rest = guests
    return max(
        sum(
            happiness[left][right] + happiness[right][left]
            # Pair each seat with its clockwise neighbour, wrapping the last
            # seat back around to the pinned head guest.
            for left, right in zip((head, *order), (*order, head))
        )
        for order in itertools.permutations(rest)
    )


class Solver(BaseSolver):
    def solve(self, input: str) -> Iterator[Any]:
        """Yield the answers for both parts of the seating puzzle.

        Each non-blank line of *input* must look like
        ``"<A> would gain|lose <N> happiness units by sitting next to <B>."``.

        The first value yielded is the best total change for the listed
        guests.  The second adds an extra guest ``"me"`` whose happiness
        toward, and from, every other guest is 0, and yields the best total
        again.

        Raises:
            ValueError: if a non-blank line does not match the expected form.
        """
        happiness: dict[str, dict[str, int]] = defaultdict(dict)

        for line in input.splitlines():
            # Tolerate blank lines instead of failing on a None parse result.
            if not line.strip():
                continue
            result = parse.parse(  # type: ignore
                "{} would {} {:d} happiness units by sitting next to {}.", line
            )
            if result is None:
                raise ValueError(f"unrecognised seating line: {line!r}")
            who, gain_or_lose, amount, neighbour = cast(
                tuple[str, Literal["gain", "lose"], int, str], result
            )
            happiness[who][neighbour] = (
                amount if gain_or_lose == "gain" else -amount
            )

        yield max_change_in_happiness(happiness)

        # Snapshot the guest list before touching happiness["me"], which the
        # defaultdict creates on first access.
        for guest in list(happiness):
            happiness["me"][guest] = 0
            happiness[guest]["me"] = 0

        yield max_change_in_happiness(happiness)