diff --git a/aoc/src/dev/ctsk/aoc/days/Day11.scala b/aoc/src/dev/ctsk/aoc/days/Day11.scala index f459fdd..01b6dde 100644 --- a/aoc/src/dev/ctsk/aoc/days/Day11.scala +++ b/aoc/src/dev/ctsk/aoc/days/Day11.scala @@ -11,26 +11,28 @@ object Day11 extends Solver(11): val mod = Math.pow(10, numDigits / 2).toLong Some(n % mod, n / mod) - private val memo = MuMap.empty[(Long, Int), Long] - private def count(initial: Long, numBlinks: Int): Long = - def rec(stone: Long, blinks: Int): Long = - if blinks == 0 then return 1L - memo.getOrElseUpdate( - (stone, blinks), { - if stone == 0 then rec(1L, blinks - 1) - else - halves(stone) match - case Some(a, b) => - rec(a, blinks - 1) + rec(b, blinks - 1) - case None => rec(stone * 2024, blinks - 1) - } - ) - rec(initial, numBlinks) + private def count(initial: Seq[Long], depth: Int): Long = + def step(stones: MuMap[Long, Long]): MuMap[Long, Long] = + val next = MuMap.empty[Long, Long].withDefaultValue(0L) + for ((stone, count) <- stones) { + if stone == 0 then next(1) += count + else + halves(stone) match + case Some((a, b)) => + next(a) += count + next(b) += count + case None => + next(stone * 2024) += count + } + next + + val initMap = MuMap.from(initial.map((_, 1L))) + Seq.iterate(initMap, depth + 1)(step).last.values.sum def run(input: os.ReadablePath): (Timings, Solution) = val (pre_time, in) = timed { longs(os.read.lines(input).head) } - val (p1_time, p1_solution) = timed { in.map(v => count(v, 25)).sum } - val (p2_time, p2_solution) = timed { in.map(v => count(v, 75)).sum } + val (p1_time, p1_solution) = timed { count(in, 25) } + val (p2_time, p2_solution) = timed { count(in, 75) } ( Timings(pre_time, p1_time, p2_time),