"""Advent of Code 2023, day 7 ("Camel Cards").

Reads lines of ``<hand> <bid>`` from standard input and prints the total
winnings for part 1 (plain rules) and part 2 (``J`` acts as a joker).
"""

import sys
from collections import Counter, defaultdict


class HandTypes:
    """Hand categories; a larger value is a stronger hand."""

    HIGH_CARD = 0
    ONE_PAIR = 1
    TWO_PAIR = 2
    THREE_OF_A_KIND = 3
    FULL_HOUSE = 4
    FOUR_OF_A_KIND = 5
    FIVE_OF_A_KIND = 6


# Mapping from number of different cards + highest count of a card to the hand
# type, for five-card hands. Only two cases are ambiguous on the number of
# distinct cards alone: with 2 distinct cards the hand is a full house (3+2)
# unless one card appears 4 times, and with 3 distinct cards it is two pair
# (2+2+1) unless one card appears 3 times.
LEN_LAST_TO_TYPE: dict[int, dict[int, int]] = {
    1: defaultdict(lambda: HandTypes.FIVE_OF_A_KIND),
    2: defaultdict(lambda: HandTypes.FULL_HOUSE, {4: HandTypes.FOUR_OF_A_KIND}),
    3: defaultdict(lambda: HandTypes.TWO_PAIR, {3: HandTypes.THREE_OF_A_KIND}),
    4: defaultdict(lambda: HandTypes.ONE_PAIR),
    5: defaultdict(lambda: HandTypes.HIGH_CARD),
}

# Card labels from weakest to strongest for each part of the puzzle.
PART_1_ORDER = "23456789TJQKA"
PART_2_ORDER = "J23456789TQKA"


def extract_key(hand: str, values: dict[str, int], joker: str = "0") -> tuple[int, ...]:
    """Return a sort key ranking *hand* by type first, then card by card.

    :param hand: a five-card hand such as ``"KTJJT"``.
    :param values: strength of each card label; must cover every card in *hand*.
    :param joker: label treated as a wildcard. The default ``"0"`` never occurs
        in a hand, so it means "no joker".
    :returns: ``(hand_type, value(card_0), ..., value(card_4))``.
    :raises ValueError: if *joker* is empty (``str.count("")`` would count
        ``len(hand) + 1`` jokers and yield a wrong type).
    :raises KeyError: if a card of *hand* is missing from *values*.
    """
    if not joker:
        raise ValueError("joker must be a non-empty card label")
    # Counts of the non-joker cards, in increasing order. The 'or' part handles
    # a hand made only of jokers (e.g. JJJJJ): one "card" seen 0 times.
    counts = sorted(Counter(hand.replace(joker, "")).values()) or [0]
    # Jokers always join the most frequent card, which gives the strongest type.
    hand_type = LEN_LAST_TO_TYPE[len(counts)][counts[-1] + hand.count(joker)]
    return (hand_type,) + tuple(values[card] for card in hand)


def parse_cards(text: str) -> list[tuple[str, int]]:
    """Parse ``<hand> <bid>`` lines into ``(hand, bid)`` pairs, skipping blank lines."""
    return [(t[0], int(t[1])) for line in text.splitlines() if (t := line.split())]


def total_winnings(cards: list[tuple[str, int]], order: str, joker: str = "0") -> int:
    """Return the sum of ``rank * bid`` after ranking *cards* weakest-first.

    :param cards: ``(hand, bid)`` pairs; not modified.
    :param order: every card label, from weakest to strongest.
    :param joker: wildcard label passed to :func:`extract_key` (``"0"``: none).
    """
    values = {card: value for value, card in enumerate(order)}
    ranked = sorted(cards, key=lambda cv: extract_key(cv[0], values=values, joker=joker))
    return sum(rank * bid for rank, (_, bid) in enumerate(ranked, start=1))


def main() -> None:
    """Read the puzzle input from stdin and print both answers."""
    cards = parse_cards(sys.stdin.read())

    # part 1
    answer_1 = total_winnings(cards, PART_1_ORDER)
    print(f"answer 1 is {answer_1}")

    # part 2
    answer_2 = total_winnings(cards, PART_2_ORDER, joker="J")
    print(f"answer 2 is {answer_2}")


if __name__ == "__main__":
    main()