Day 12: Hot Springs
by @mbovel
Puzzle description
https://adventofcode.com/2023/day/12
Scaffold
Let's create a folder with a file day12.scala to hold the core of our code.
We start by writing two examples and defining three functions countAll, countRow, and count that we will implement later:
//> using scala 3.3.1

import scala.io.Source

/** The example puzzle from the problem description. */
val examplePuzzle = IArray(
  "???.### 1,1,3",
  ".??..??...?##. 1,1,3",
  "?#?#?#?#?#?#?#? 1,3,1,6",
  "????.#...#... 4,1,1",
  "????.######..#####. 1,6,5",
  "###???????? 3,2,1"
)

/** Our personal puzzle input. */
val personalPuzzle = Source.fromFile("input.txt").mkString.trim()

/** Entry point for part 1. */
@main def part1(): Unit = println(countAll(personalPuzzle))

/** Sums `countRow` over all rows in `input`. */
def countAll(input: String): Long = ???

/** Counts all of the different valid arrangements of
  * operational and broken springs in the given row.
  */
def countRow(input: String): Long = ???

/** Helper recursive function for `countRow` that does
  * the actual work.
  *
  * @param input
  *   the remaining input to process
  * @param ds
  *   a list of the numbers of damaged springs remaining to be placed
  */
def count(input: List[Char], ds: List[Int]): Long = ???
Thanks to scala-cli we can run this file with:
$ scala-cli -M part1 .
Tests
In the same folder, we create a file day12.test.scala to hold our tests. We write one test for each individual row of the example from the instructions, a test for the whole example, and a test for our personal puzzle input:
//> using scala 3.3.1
//> using test.dep org.scalameta::munit::1.0.0-M10

class Day12Test extends munit.FunSuite:
  test("example row 1"):
    assertEquals(countRow(examplePuzzle(0)), 1L)
  test("example row 2"):
    assertEquals(countRow(examplePuzzle(1)), 4L)
  test("example row 3"):
    assertEquals(countRow(examplePuzzle(2)), 1L)
  test("example row 4"):
    assertEquals(countRow(examplePuzzle(3)), 1L)
  test("example row 5"):
    assertEquals(countRow(examplePuzzle(4)), 4L)
  test("example row 6"):
    assertEquals(countRow(examplePuzzle(5)), 10L)
  test("example"):
    assertEquals(countAll(examplePuzzle.mkString("\n")), 21L)
  test("puzzle input"):
    assertEquals(countAll(personalPuzzle), 7118L)
We can run the tests with:
$ scala-cli test .
Part 1
Implementation of countAll and countRow
countAll and countRow can be implemented concisely using split and map:
def countAll(input: String): Long = input.split("\n").map(countRow).sum

def countRow(input: String): Long =
  val Array(conditions, damagedCounts) = input.split(" ")
  count(
    conditions.toList,
    damagedCounts.split(",").map(_.toInt).toList
  )
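To make the hand-off to count concrete, here is a small sketch. It uses a hypothetical parseRow helper that is not part of the solution. The helper performs the same parsing as countRow, but returns the two lists instead of passing them on:

```scala
// Hypothetical helper, for illustration only: it performs the same
// parsing as countRow, but returns the arguments that countRow would
// pass to count.
def parseRow(input: String): (List[Char], List[Int]) =
  val Array(conditions, damagedCounts) = input.split(" ")
  (conditions.toList, damagedCounts.split(",").map(_.toInt).toList)

// For the first example row:
//   springs == List('?', '?', '?', '.', '#', '#', '#')
//   groups  == List(1, 1, 3)
val (springs, groups) = parseRow("???.### 1,1,3")
```

The conditions become a list of characters to walk through, and the damaged counts become the list ds of group sizes still to be placed.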
Character-level implementation of count
For our first implementation, we'll iterate through the input string character by character, and we'll use an additional parameter d to keep track of the number of consecutive damaged springs seen so far:
/** Helper recursive function for `countRow` that does the actual work.
  *
  * @param input
  *   the remaining input to process
  * @param ds
  *   a list of the numbers of damaged springs remaining to be placed
  * @param d
  *   the number of consecutive damaged springs seen so far
  */
def count(input: List[Char], ds: List[Int], d: Int = 0): Long =
  // We've reached the end of the input.
  if input.isEmpty then
    // This is a valid arrangement if there are no sequences of
    // damaged springs left to place (ds.isEmpty) and we're
    // not currently in a sequence of damaged springs (d == 0).
    if ds.isEmpty && d == 0 then 1L
    // This is also a valid arrangement if there is one sequence
    // of damaged springs left to place (ds.length == 1) and its
    // size is d (ds.head == d).
    else if ds.length == 1 && ds.head == d then 1L
    // Otherwise, this is not a valid arrangement.
    else 0L
  else
    def operationalCase() =
      // If we're not currently in a sequence of damaged springs,
      // then we can consume an operational spring.
      if d == 0 then count(input.tail, ds, 0)
      // We are currently in a sequence of damaged springs,
      // which this operational spring ends. If the length
      // of the damaged sequence is the expected one, then we can
      // continue with the next damaged sequence.
      else if !ds.isEmpty && ds.head == d then
        count(input.tail, ds.tail, 0)
      // Otherwise, this is not a valid arrangement.
      else 0L
    def damagedCase() =
      // If no damaged springs are expected, then this is not a valid
      // arrangement.
      if ds.isEmpty then 0L
      // Optimization: if d already equals the expected length, one more
      // damaged spring would make the sequence too long, so there is
      // no need to recurse.
      else if d == ds.head then 0L
      // Otherwise, we can consume a damaged spring.
      else count(input.tail, ds, d + 1)
    input.head match
      // If we encounter a question mark, this position can have
      // either an operational or a damaged spring.
      case '?' => operationalCase() + damagedCase()
      // If we encounter a dot, this position has an operational
      // spring.
      case '.' => operationalCase()
      // If we encounter a hash, this position has a damaged spring.
      case '#' => damagedCase()
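Because count is subtle, it can help to cross-check it against a naive reference on small rows. The following sketch is not part of the original solution, and bruteForce and groupLengths are hypothetical names. It expands every ? into both . and #, then counts the candidates whose runs of # have exactly the expected lengths. It takes time exponential in the number of ?, so it is only usable on short rows such as the examples:

```scala
// Hypothetical reference implementation, not part of the original solution.

// Lengths of the maximal runs of '#' in a string made only of '.' and '#'.
def groupLengths(springs: String): List[Int] =
  springs.split("\\.+").filter(_.nonEmpty).map(_.length).toList

// Counts arrangements by enumerating every way to replace each '?'.
def bruteForce(row: String): Long =
  val Array(conditions, damagedCounts) = row.split(" ")
  val expected = damagedCounts.split(",").map(_.toInt).toList
  // Start from the empty prefix and extend it one character at a time;
  // each '?' doubles the number of candidates.
  val candidates = conditions.foldLeft(List("")) { (prefixes, c) =>
    if c == '?' then prefixes.flatMap(p => List(p + ".", p + "#"))
    else prefixes.map(p => p + c)
  }
  candidates.count(s => groupLengths(s) == expected).toLong
```

A test could then check that countRow(row) == bruteForce(row) for every row of examplePuzzle.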
Counting calls
The implementation above is correct, but it has an exponential running time: it can call itself twice for each ? in the input, so the number of calls can grow exponentially with the number of unknown springs.
To demonstrate this, we will add a counter ops that counts the number of calls to count:
var ops = 0

def count(input: List[Char], ds: List[Int], d: Int = 0): Long =
  ops += 1
  // ... same as before ...
And consider the following example puzzle in addition to our two existing examples:
val slowPuzzleSize = 16
val slowPuzzle =
  ("??." * slowPuzzleSize) + " " + ("1," * (slowPuzzleSize - 1)) + "1"
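Why does slowPuzzle have exactly 65536 arrangements? A ??. block can hold at most one group of size 1, and there are as many groups as blocks. So every block holds exactly one damaged spring, in either of its two ? positions, independently of the other blocks. The sketch below spells this out. slowRow and slowRowArrangements are hypothetical names generalizing slowPuzzle to n blocks:

```scala
// Hypothetical generalization of slowPuzzle to n blocks (n >= 1).
def slowRow(n: Int): String =
  ("??." * n) + " " + ("1," * (n - 1)) + "1"

// Each "??." block holds exactly one damaged spring, in one of two
// positions, so there are 2^n arrangements in total.
def slowRowArrangements(n: Int): Long = 1L << n
```

With n = 16, this gives 2^16 = 65536. The naive recursion has to enumerate every one of these combinations.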
To see how many times count is called for our example puzzles, we add the following function:
@main def countOps(): Unit =
  val puzzles =
    IArray(
      ("example", examplePuzzle.mkString("\n")),
      ("personal", personalPuzzle),
      ("slow", slowPuzzle)
    )
  for (name, input) <- puzzles do
    ops = 0
    val start = System.nanoTime()
    val result = countAll(input)
    val end = System.nanoTime()
    val elapsed = (end - start) / 1_000_000
    println(f"$name%8s: $result%5d ($ops%9d calls, $elapsed%4d ms)")
Running this code gives us the following output:
example: 21 ( 305 calls, 24 ms)
personal: 7118 ( 149712 calls, 37 ms)
slow: 65536 (172186881 calls, 1415 ms)
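The measurement pattern in countOps can be factored into a small helper. This is a hypothetical refactoring, not code from the solution. It has the same accuracy limits as the inline version, since it takes a single measurement without any JIT warm-up:

```scala
// Hypothetical helper: evaluates `body` once and returns its result
// together with the elapsed wall-clock time in milliseconds.
def timed[A](body: => A): (A, Long) =
  val start = System.nanoTime()
  val result = body
  val elapsedMs = (System.nanoTime() - start) / 1_000_000
  (result, elapsedMs)
```

Inside countOps, the three timing lines would then become val (result, elapsed) = timed(countAll(input)).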
Memoization
Many of the calls to count are redundant: they are made with the same arguments as previous calls. A quick way to improve both the algorithmic complexity and the performance of this function is to memoize it: we can cache the results of previous calls to count and reuse them when the same arguments are passed again.
We can use a mutable.Map to store the results of previous calls. We use tuples containing the arguments (input, ds, d) as keys, and we use the getOrElseUpdate method to either retrieve the cached result or compute it and store it in the map.
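The same pattern can be seen in isolation on a smaller recursive function: a getOrElseUpdate front end in front of an uncached body that counts its own calls. The sketch below applies it to Fibonacci numbers, which are not part of the puzzle; fib, fibUncached, fibCache and fibOps are hypothetical names. Without the cache, fib(n) would make an exponential number of calls. With it, fibUncached runs exactly once per value of n:

```scala
import scala.collection.mutable

// Hypothetical illustration of memoization, unrelated to the puzzle itself.
var fibOps = 0
val fibCache = mutable.Map.empty[Int, Long]

// Cached front end: returns the stored result, or computes and stores it.
def fib(n: Int): Long =
  fibCache.getOrElseUpdate(n, fibUncached(n))

// Uncached body: counts its own invocations and recurses through `fib`.
def fibUncached(n: Int): Long =
  fibOps += 1
  if n < 2 then n.toLong else fib(n - 1) + fib(n - 2)
```

Calling fib(50) returns 12586269025 and leaves fibOps at 51, one uncached call for each n from 0 to 50.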
Here is the memoized version of count:
import scala.collection.mutable

val cache = mutable.Map.empty[(List[Char], List[Int], Int), Long]

private def count(input: List[Char], ds: List[Int], d: Int = 0): Long =
  cache.getOrElseUpdate((input, ds, d), countUncached(input, ds, d))

var ops = 0

def countUncached(input: List[Char], ds: List[Int], d: Int = 0): Long =
  ops += 1
  // ... same as before ...
Running countOps again, we now get the following output:
example: 21 ( 169 calls, 32 ms)
personal: 7118 ( 38382 calls, 74 ms)
slow: 65536 ( 679 calls, 1 ms)
That's much better! The number of calls is lower for all three puzzles, and the running time of the pathological slow puzzle drops from 1415 ms to 1 ms.
The number of operations is a good primary metric here, because it is deterministic, stable across runs, and a good proxy for the complexity of the algorithm.
We also measure the actual running time of the function as an indicator, but a naive measurement like this is not accurate and can vary a lot between runs. For a more accurate measurement, one could use a benchmarking library such as JMH (for example through the sbt-jmh plugin) or ScalaMeter.
Part 2
Oh, there is a second part to this puzzle!
Implementation
The only change needed to implement the second part is to unfold the input rows before counting them. We add the unfoldRow function to do that, and call it from countAllUnfolded:
/** Entry point for part 2 */
@main def part2(): Unit =
  println(countAllUnfolded(personalPuzzle))

def countAllUnfolded(input: String): Long =
  input.split("\n").map(unfoldRow).map(countRow).sum

def unfoldRow(input: String): String =
  val Array(conditions, damagedCounts) =
    input.split(" ")
  val conditionsUnfolded =
    (0 until 5).map(_ => conditions).mkString("?")
  val damagedCountUnfolded =
    (0 until 5).map(_ => damagedCounts).mkString(",")
  f"$conditionsUnfolded $damagedCountUnfolded"
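As a worked example of unfolding, the puzzle description unfolds the row .# 1 into .#?.#?.#?.#?.# 1,1,1,1,1. The sketch below is an equivalent formulation of unfoldRow that uses List.fill instead of mapping over a range. This variant is an illustration; the behavior should be identical:

```scala
// Equivalent formulation of unfoldRow, using List.fill instead of
// mapping over a range.
def unfoldRow(input: String): String =
  val Array(conditions, damagedCounts) = input.split(" ")
  // Five copies of the conditions, separated by '?'.
  val conditionsUnfolded = List.fill(5)(conditions).mkString("?")
  // Five copies of the damaged counts, separated by ','.
  val damagedCountsUnfolded = List.fill(5)(damagedCounts).mkString(",")
  s"$conditionsUnfolded $damagedCountsUnfolded"

// unfoldRow(".# 1") == ".#?.#?.#?.#?.# 1,1,1,1,1"
```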
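Executing part2 with my personal puzzle input runs in ~800 ms on my machine, and countUncached is called 681,185 times (the last line below comes from adding the unfolded personal input to the puzzles of countOps):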
example: 21 ( 169 calls, 31 ms)
personal: 7118 ( 38382 calls, 74 ms)
slow: 65536 ( 679 calls, 1 ms)
personal unfolded: 7030194981795 (681185 calls, 815 ms)
Can we do better?
Group-level implementation of count
Our first implementation of count works. Recursing character by character through the input string looks like a natural way to solve this problem. But we can simplify the implementation and improve its performance by considering groups instead of individual characters.
To know whether a group of ds.head damaged springs can start at a given position, we can take the next ds.head characters of the input and check two things. First, they must all be able to be damaged springs (i.e. none of them is a .). Second, the following character, if there is one, must be able to be an operational spring (i.e. it is not a #).
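The check described above can be sketched as a standalone predicate. canPlaceGroup is a hypothetical name; the count2 implementation below performs the same tests inline, in its damagedCase:

```scala
// Hypothetical helper: can a group of `n` damaged springs start at the
// head of `input`?
def canPlaceGroup(input: List[Char], n: Int): Boolean =
  val (group, rest) = input.splitAt(n)
  group.length == n &&               // enough characters remain
  !group.contains('.') &&            // every one of them can be '#'
  (rest.isEmpty || rest.head != '#') // the group ends right after them
```

For example, ### followed by . can hold a group of 3. In ?#?#, the first three characters cannot form a group of 3, because a # follows them.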
import scala.collection.mutable

extension (b: Boolean) private inline def toLong: Long =
  if b then 1L else 0L

val cache2 = mutable.Map.empty[(List[Char], List[Int]), Long]

private def count2(input: List[Char], ds: List[Int]): Long =
  cache2.getOrElseUpdate((input, ds), count2Uncached(input, ds))

def count2Uncached(input: List[Char], ds: List[Int]): Long =
  ops += 1
  // We've seen all expected damaged sequences. The arrangement
  // is therefore valid only if the input does not contain
  // damaged springs.
  if ds.isEmpty then input.forall(_ != '#').toLong
  // The input is empty but we expected some damaged springs,
  // so this is not a valid arrangement.
  else if input.isEmpty then 0L
  else
    def operationalCase(): Long =
      // We consume this spring as operational, together with all
      // immediately following operational springs.
      count2(input.tail.dropWhile(_ == '.'), ds)
    def damagedCase(): Long =
      // If the length of the input is less than the expected
      // length of the damaged sequence, then this is not a
      // valid arrangement.
      if input.length < ds.head then 0L
      else
        // Split the input into a group of length ds.head and
        // the rest.
        val (group, rest) = input.splitAt(ds.head)
        // If the group contains any operational springs, then
        // this is not a group of damaged springs, so this
        // is not a valid arrangement.
        if !group.forall(_ != '.') then 0L
        // If the rest of the input is empty, then this is a
        // valid arrangement only if the damaged sequence is
        // the last one expected.
        else if rest.isEmpty then ds.tail.isEmpty.toLong
        // If we now have a damaged spring, then this is not
        // the end of a damaged sequence as expected, and
        // therefore not a valid arrangement.
        else if rest.head == '#' then 0L
        // Otherwise, we can continue with the rest of the
        // input and the next expected damaged sequence.
        else count2(rest.tail, ds.tail)
    input.head match
      case '?' => operationalCase() + damagedCase()
      case '.' => operationalCase()
      case '#' => damagedCase()
I find this implementation simpler and easier to understand than the previous one. Do you agree?
It naturally results in fewer calls, and the running time improves:
example: 21 ( 69 calls, 36 ms)
personal: 7118 ( 12356 calls, 74 ms)
slow: 65536 ( 404 calls, 1 ms)
personal unfolded: 7030194981795 (235829 calls, 497 ms)
Local cache
We implemented memoization by using a global mutable map. What happens if we instead use a local, distinct map for each call to count2, that is, one map per row?
import scala.collection.mutable

def count2(input: List[Char], ds: List[Int]): Long =
  val cache2 = mutable.Map.empty[(List[Char], List[Int]), Long]

  def count2Cached(input: List[Char], ds: List[Int]): Long =
    cache2.getOrElseUpdate((input, ds), count2Uncached(input, ds))

  def count2Uncached(input: List[Char], ds: List[Int]): Long =
    // ... same as before ...
    // (but calling count2Cached instead of count2)

  count2Cached(input, ds)
Even though this results in more calls to count2Uncached, it actually improves the performance of the unfolded version, down to ~400 ms on my machine:
example: 21 ( 71 calls, 32 ms)
personal: 7118 ( 18990 calls, 67 ms)
slow: 65536 ( 425 calls, 3 ms)
personal unfolded: 7030194981795 ( 260272 calls, 407 ms)
Simplify cache keys
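The cache now lives only for the duration of a single call to count2. Every input and ds the cache sees is therefore a suffix of the original lists, and a suffix is uniquely identified by its length. We can just use the lengths of these lists as keys: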
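Why are suffixes enough? Within one call to count2, the recursion only ever drops elements from the front of input and ds (through tail, dropWhile and splitAt). Every list it reaches is therefore a suffix, and a list has exactly one suffix of each length. The illustration below uses the standard tails method on the first example row:

```scala
// Illustration: the suffixes of a list, and their lengths.
val springs = "???.###".toList
val suffixes = springs.tails.toList // from springs itself down to Nil

// Each suffix has a different length, so the length identifies it:
// the suffix of length k is springs.drop(springs.length - k).
val lengths = suffixes.map(_.length) // List(7, 6, 5, 4, 3, 2, 1, 0)
```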
import scala.collection.mutable

def count2(input: List[Char], ds: List[Int]): Long =
  val cache2 = mutable.Map.empty[(Int, Int), Long]

  def count2Cached(input: List[Char], ds: List[Int]): Long =
    val key = (input.length, ds.length)
    cache2.getOrElseUpdate(key, count2Uncached(input, ds))

  // ... def count2Uncached as before

  count2Cached(input, ds)
Which further reduces the running time of the unfolded version to ~320 ms on my machine:
example: 21 ( 71 calls, 33 ms)
personal: 7118 ( 18990 calls, 66 ms)
slow: 65536 ( 425 calls, 0 ms)
personal unfolded: 7030194981795 ( 260272 calls, 320 ms)
Simplify the cache structure
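Our cache key is now just a pair of integers, so we don't need a Map: an Array can do the job just as well, if we flatten the pair (input.length, ds.length) into the single index input.length * dim2 + ds.length.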
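Flattening the pair into one index works because input.length ranges over 0 until dim1 and ds.length ranges over 0 until dim2. The formula input.length * dim2 + ds.length therefore gives every pair its own slot in 0 until dim1 * dim2, which is the row-major layout of a dim1 by dim2 table. The hypothetical sketch below checks this exhaustively for given dimensions. For the row ???.### 1,1,3, dim1 = 8 and dim2 = 4, so the cache has 32 slots:

```scala
// Hypothetical check of the row-major flattening used by the Array cache.
def flatKey(inputLength: Int, dsLength: Int, dim2: Int): Int =
  inputLength * dim2 + dsLength

// True if every pair (i, j) with 0 <= i < dim1 and 0 <= j < dim2 maps to
// a distinct key in 0 until dim1 * dim2.
def keysAreDistinctAndInRange(dim1: Int, dim2: Int): Boolean =
  val keys =
    for
      i <- 0 until dim1
      j <- 0 until dim2
    yield flatKey(i, j, dim2)
  keys.distinct.length == keys.length &&
  keys.forall(k => k >= 0 && k < dim1 * dim2)
```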
def count2(input: List[Char], ds: List[Int]): Long =
  val dim1 = input.length + 1
  val dim2 = ds.length + 1
  // -1 marks entries that have not been computed yet
  // (valid counts are never negative).
  val cache = Array.fill(dim1 * dim2)(-1L)

  def count2Cached(input: List[Char], ds: List[Int]): Long =
    val key = input.length * dim2 + ds.length
    val result = cache(key)
    if result == -1L then
      val result = count2Uncached(input, ds)
      cache(key) = result
      result
    else result

  // ... def count2Uncached as before

  count2Cached(input, ds)
This reduces the running time of the unfolded version down to ~200 ms on my machine:
example: 21 ( 71 calls, 27 ms)
personal: 7118 ( 18990 calls, 47 ms)
slow: 65536 ( 425 calls, 0 ms)
personal unfolded: 7030194981795 ( 260272 calls, 201 ms)
Inline helper functions
We have used helper functions to structure the implementation of count2. To avoid the overhead of calling them, we can use Scala 3's inline keyword.
After adding the inline modifier to count2Cached, operationalCase and damagedCase, the running time of the unfolded version is reduced to ~140 ms on my machine:
example: 21 ( 71 calls, 28 ms)
personal: 7118 ( 18990 calls, 50 ms)
slow: 65536 ( 425 calls, 0 ms)
personal unfolded: 7030194981795 ( 260272 calls, 137 ms)
Further optimizations
- Using a different data structure for the input and the damaged sequence, for example an IArray instead of a List, and indexing into it instead of using the splitAt, tail or head methods, would probably improve the performance further, but would be more verbose and less idiomatic.
- Parallelizing countAllUnfolded did not result in any performance improvement on my machine. It might on larger inputs.
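The first idea above could look like the following sketch. It is a hypothetical variant written for illustration, not code from the solution, and it has not been benchmarked here. Positions i and j into IArrays replace the List suffixes, and the cache is indexed directly by (i, j). Unlike count2, the operational case advances one character at a time instead of skipping runs of dots:

```scala
// Hypothetical index-based variant of count2, for illustration only.
// go(i, j) counts the arrangements of springs from position i onward
// for the groups from index j onward.
def countIndexed(springs: IArray[Char], groups: IArray[Int]): Long =
  val n = springs.length
  val m = groups.length
  // One slot per (i, j) pair; -1 marks values not computed yet.
  val cache = Array.fill((n + 1) * (m + 1))(-1L)

  // True if no spring from `start` onward is known to be damaged.
  def noDamagedFrom(start: Int): Boolean =
    (start until n).forall(k => springs(k) != '#')

  // True if some spring in [start, stop) is known to be operational.
  def hasOperational(start: Int, stop: Int): Boolean =
    (start until stop).exists(k => springs(k) == '.')

  def go(i: Int, j: Int): Long =
    val key = i * (m + 1) + j
    if cache(key) != -1L then cache(key)
    else
      val result =
        if j == m then (if noDamagedFrom(i) then 1L else 0L)
        else if i == n then 0L
        else
          // Treat springs(i) as operational, if it can be.
          val operational =
            if springs(i) == '#' then 0L else go(i + 1, j)
          // Start a group of groups(j) damaged springs at i, if possible.
          val end = i + groups(j)
          val damaged =
            if springs(i) == '.' || end > n || hasOperational(i, end) then 0L
            else if end == n then (if j + 1 == m then 1L else 0L)
            else if springs(end) == '#' then 0L
            else go(end + 1, j + 1)
          operational + damaged
      cache(key) = result
      result

  go(0, 0)
```

A row can be converted with IArray.from(conditions.toList) and IArray.from(damagedCounts.split(",").map(_.toInt).toList).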
Can you think of other optimizations that could improve the performance of this code without sacrificing readability?
Final Code
/** Entry point for part 1. */
def part1(input: String): Unit =
  println(countAll(input))

/** Sums `countRow` over all rows in `input`. */
def countAll(input: String): Long =
  input.split("\n").map(countRow).sum

/** Counts all of the different valid arrangements
  * of operational and broken springs in the given row.
  */
def countRow(input: String): Long =
  val Array(conditions, damagedCounts) = input.split(" ")
  count2(
    conditions.toList,
    damagedCounts.split(",").map(_.toInt).toList
  )

extension (b: Boolean) private inline def toLong: Long =
  if b then 1L else 0L

def count2(input: List[Char], ds: List[Int]): Long =
  val dim1 = input.length + 1
  val dim2 = ds.length + 1
  val cache = Array.fill(dim1 * dim2)(-1L)

  inline def count2Cached(input: List[Char], ds: List[Int]): Long =
    val key = input.length * dim2 + ds.length
    val result = cache(key)
    if result == -1L then
      val result = count2Uncached(input, ds)
      cache(key) = result
      result
    else result

  def count2Uncached(input: List[Char], ds: List[Int]): Long =
    // We've seen all expected damaged sequences.
    // The arrangement is therefore valid only if the
    // input does not contain damaged springs.
    if ds.isEmpty then input.forall(_ != '#').toLong
    // The input is empty but we expected some damaged springs,
    // so this is not a valid arrangement.
    else if input.isEmpty then 0L
    else
      inline def operationalCase(): Long =
        // Operational case: we can consume all operational
        // springs to get to the next choice.
        count2Cached(input.tail.dropWhile(_ == '.'), ds)
      inline def damagedCase(): Long =
        // If the length of the input is less than the expected
        // length of the damaged sequence, then this is not a
        // valid arrangement.
        if input.length < ds.head then 0L
        else
          // Split the input into a group of length ds.head and
          // the rest.
          val (group, rest) = input.splitAt(ds.head)
          // If the group contains any operational springs, then
          // this is not a group of damaged springs, so this
          // is not a valid arrangement.
          if !group.forall(_ != '.') then 0L
          // If the rest of the input is empty, then this is a
          // valid arrangement only if the damaged sequence
          // is the last one expected.
          else if rest.isEmpty then ds.tail.isEmpty.toLong
          // If we now have a damaged spring, then this is not the
          // end of a damaged sequence as expected, and therefore
          // not a valid arrangement.
          else if rest.head == '#' then 0L
          // Otherwise, we can continue with the rest of the input
          // and the next expected damaged sequence.
          else count2Cached(rest.tail, ds.tail)
      input.head match
        case '?' => operationalCase() + damagedCase()
        case '.' => operationalCase()
        case '#' => damagedCase()

  count2Cached(input, ds)
end count2

/** Entry point for part 2 */
def part2(input: String): Unit =
  println(countAllUnfolded(input))

def countAllUnfolded(input: String): Long =
  input.split("\n").map(unfoldRow).map(countRow).sum

def unfoldRow(input: String): String =
  val Array(conditions, damagedCounts) =
    input.split(" ")
  val conditionsUnfolded =
    (0 until 5).map(_ => conditions).mkString("?")
  val damagedCountUnfolded =
    (0 until 5).map(_ => damagedCounts).mkString(",")
  f"$conditionsUnfolded $damagedCountUnfolded"
Solutions from the community
- Solution by Rui Alves
- Solution by @natsukagami
- Solution by Marconi Lanna
- Solution by Spamegg
- Solution by Thanh Le
- Solution by Seth Tisue
- Solution by g.berezin
- Solution by Jamie Thompson
- Solution by Paweł Cembaluk