diff --git a/src/holt59/aoc/2024/day4.py b/src/holt59/aoc/2024/day4.py index 64188be..656ed2e 100644 --- a/src/holt59/aoc/2024/day4.py +++ b/src/holt59/aoc/2024/day4.py @@ -1,45 +1,30 @@ +import itertools as it import sys -import numpy as np +lines = sys.stdin.read().strip().splitlines() +n = len(lines) -lines = np.array(list(map(list, sys.stdin.read().strip().splitlines()))) -print(lines) -n = lines.shape[0] - -answer_1 = 0 -for i in range(n): - line = "".join(lines[i, :]) - answer_1 += line.count("XMAS") + line.count("SAMX") - - column = "".join(lines[:, i]) - answer_1 += column.count("XMAS") + column.count("SAMX") - - diag = "".join(np.diagonal(lines, i)) - answer_1 += diag.count("XMAS") + diag.count("SAMX") - - diag = "".join(np.diagonal(lines[::-1, :], i)) - answer_1 += diag.count("XMAS") + diag.count("SAMX") - - if i != 0: - diag = "".join(np.diagonal(lines, -i)) - answer_1 += diag.count("XMAS") + diag.count("SAMX") - - diag = "".join(np.diagonal(lines[::-1, :], -i)) - answer_1 += diag.count("XMAS") + diag.count("SAMX") +answer_1 = sum( + line.count("XMAS") + line.count("SAMX") + for i in range(n) + for ri, rk, ro, ci, ck, cm in ( + (1, 0, 0, 0, 1, n), + (0, 1, 0, 1, 0, n), + (0, 1, 0, 1, 1, n - i), + (0, -1, -1, 1, 1, n - i), + (1, 1, 0, 0, 1, n - i if i != 0 else 0), + (-1, -1, -1, 0, 1, n - i if i != 0 else 0), + ) + if ( + line := "".join(lines[ri * i + rk * k + ro][ci * i + ck * k] for k in range(cm)) + ) +) answer_2 = sum( - "".join( - ( - lines[i - 1, j - 1], - lines[i - 1, j + 1], - lines[i + 1, j - 1], - lines[i + 1, j + 1], - ) - ) + lines[i][j] == "A" + and "".join(lines[i + di][j + dj] for di, dj in it.product((-1, 1), (-1, 1))) in {"MSMS", "SSMM", "MMSS", "SMSM"} - for i in range(1, n - 1) - for j in range(1, n - 1) - if lines[i, j] == "A" + for i, j in it.product(range(1, n - 1), range(1, n - 1)) ) print(f"answer 1 is {answer_1}")