# Day 11: Cosmic Expansion

by @natsukagami

## Puzzle description

https://adventofcode.com/2023/day/11

### Puzzle Summary

We are given a grid of `.` and `#`. We would like to find the sum of distances between all *pairs* of `#` in the grid.
The distance between two `#` is defined as the number of vertical and horizontal steps to go from one to the other.
One caveat: each row and each column that has no `#` actually represents `k` empty rows/columns respectively.

- In part 1, `k = 2`.
- In part 2, `k = 1_000_000`.
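
As a baseline to check the faster solutions against, here is a minimal brute-force sketch that follows the summary literally: expand every coordinate, then add up the distance for every pair of `#`. The names `Galaxy`, `expandIndex` and `bruteForce` are made up for this illustration and are not part of the solution presented in this article.

```scala
// Hypothetical helper type: a `#` with its coordinates *after* expansion.
final case class Galaxy(row: Long, col: Long)

def bruteForce(input: String, k: Long): Long =
  val board = input.linesIterator.toVector
  // Indices of rows/columns that contain no `#` at all.
  val emptyRows = board.indices.filter(r => board(r).forall(_ != '#')).toSet
  val emptyCols = board.head.indices.filter(c => board.forall(row => row(c) != '#')).toSet

  // Every empty row/column before `index` counts as k instead of 1.
  def expandIndex(index: Int, empties: Set[Int]): Long =
    index + empties.count(_ < index) * (k - 1)

  val galaxies =
    for
      r <- board.indices
      c <- board(r).indices
      if board(r)(c) == '#'
    yield Galaxy(expandIndex(r, emptyRows), expandIndex(c, emptyCols))

  // Manhattan distance, summed over every unordered pair.
  val distances =
    for
      i <- galaxies.indices
      j <- 0 until i
    yield (galaxies(i).row - galaxies(j).row).abs + (galaxies(i).col - galaxies(j).col).abs
  distances.sum
```

On the example grid from the puzzle page this gives 374 for `k = 2` and 1030 for `k = 10`. It is quadratic in the number of `#`s, so it is only meant as a reference.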

## Solution Summary

We start by parsing the input into a board structure (a `Seq[String]`, with each string *representing a row*).

```scala
val board = readInput().linesIterator.toSeq
```

First, it is clear to us that the distance we are looking for is the Manhattan distance between two `#` in the grid.
Given the coordinates of two `#`s, we can write the distance formula as

```scala
case class Coord(row: Int, col: Int)

def distance(a: Coord, b: Coord) = (a.row - b.row).abs + (a.col - b.col).abs
```

Note that the distances in the row and column coordinates are *independent*: we can calculate them separately and add them back together at the end.
Furthermore, since the calculations for rows and columns are *exactly the same*, we can simply write code that deals with one coordinate and use `.transpose` on the board to flip the row/column coordinates of all points, to calculate the other coordinate.

With this fact, from now on we only need to talk about calculating the distances in the row coordinate!
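
To see that transposing really just swaps the two coordinates, here is a small sketch on a made-up 2x3 board. The helper `hashCoords` and the board `tiny` are hypothetical names used only for this illustration.

```scala
// Collect the (row, col) coordinates of every `#` on a board.
def hashCoords(board: Seq[String]): Set[(Int, Int)] =
  (for
    (line, row) <- board.zipWithIndex
    (char, col) <- line.zipWithIndex
    if char == '#'
  yield (row, col)).toSet

val tiny = Seq("#..", "..#")

// Coordinates on the original board: (0, 0) and (1, 2).
val original = hashCoords(tiny)

// `transpose` gives a Seq[Seq[Char]]; turn each column back into a String.
// Coordinates on the transposed board: (0, 0) and (2, 1).
val flipped = hashCoords(tiny.transpose.map(_.mkString))
```

Every point `(row, col)` of the original board shows up as `(col, row)` on the transposed one, so code written for rows can be reused for columns unchanged.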

### When calculating row distance, columns don't matter

Let's look at the formula for the row distance:

```scala
def rowDistance(a: Coord, b: Coord) = (a.row - b.row).abs
```

Note that `col` is never mentioned! This means we can simply treat all `#`s with the same `row` coordinate *exactly the same*!
If we have `n` points with `row = x` and `m` points with `row = y`, the *total* row distance between these two groups of points can simply be calculated as `n * m * (x - y).abs`.
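
A tiny worked example of that grouping, using made-up numbers (three points in row 1 and two points in row 4). The names `pointRows`, `pairByPair` and `grouped` are only for illustration.

```scala
// Row coordinates of five hypothetical points: three in row 1, two in row 4.
val pointRows = Vector(1, 1, 1, 4, 4)

// Pair by pair: add up the row distance over every unordered pair.
// Pairs inside the same row contribute 0.
val pairByPair =
  (for
    i <- pointRows.indices
    j <- 0 until i
  yield (pointRows(i) - pointRows(j)).abs).sum

// Grouped: n = 3 points with row = 1, m = 2 points with row = 4.
val grouped = 3 * 2 * (1 - 4).abs
```

Both values come out as 18: there are `3 * 2 = 6` cross-row pairs, each 3 rows apart.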

Since it is no longer important to keep the board as-is, we can collapse it to just the *count* of `#`s for each row.

```scala
val countByRow = board
  .map(row => row.count(_ == '#'))
  .toArray // get O(1) indexing
```

### A cubic solution: calculate for each pair of rows!

Now, since the empty rows are expanded, we cannot use the old row distance formula `(a.row - b.row).abs`.
Instead, we have to add up the *expanded* size of every row we step through on the way:

```scala
def rowDistance(xRow: Int, yRow: Int) =
  // assuming xRow < yRow.
  // Leaving each of the rows xRow, ..., yRow - 1 costs its expanded size
  // (a row that contains a `#`, such as xRow itself, costs 1).
  val distanceForOnePair = (xRow until yRow)
    .map(row =>
      if countByRow(row) == 0
      then k /* expanded */
      else 1L /* not expanded */
    )
    .sum
  // the total distance is counted for every pair
  distanceForOnePair * countByRow(xRow) * countByRow(yRow)
```

We can simply go through every pair of rows to perform this calculation:

```scala
val result = {
  for
    i <- 0 until countByRow.length
    j <- 0 until i
  yield rowDistance(j, i)
}.sum
```

This has running time complexity `O(rows^3)`, which should still suffice for the input in AoC (a grid of fewer than 150 rows and columns).
However, we can do better! Read on to see how we optimize away the redundant calculations.
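
For reference, the cubic approach can be put together into one self-contained function. This is only a sketch: the name `cubicRowDistances` and the choice to pass `k` and the counts as parameters are assumptions made for this example. It sums the expanded sizes of rows `xRow until yRow`, so two adjacent non-empty rows are exactly one step apart.

```scala
def cubicRowDistances(k: Long, counts: Array[Int]): Long =
  def rowDistance(xRow: Int, yRow: Int): Long =
    // assuming xRow < yRow: walk over rows xRow, ..., yRow - 1,
    // paying k for an empty row and 1 for any other row.
    val distanceForOnePair =
      (xRow until yRow).map(row => if counts(row) == 0 then k else 1L).sum
    // every point in xRow pairs with every point in yRow
    distanceForOnePair * counts(xRow) * counts(yRow)

  (for
    i <- counts.indices
    j <- 0 until i
  yield rowDistance(j, i)).sum
```

Calling it once with the per-row counts and once with the per-column counts, then adding the two results, gives the puzzle answer.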

### Reduce to quadratic: Memoize the calculations with prefix sums!

Let's look at the formula for `distanceForOnePair` again:

```scala
val distanceForOnePair = (xRow until yRow)
  .map(row =>
    if countByRow(row) == 0
    then k /* expanded */
    else 1L /* not expanded */
  )
  .sum
```

Note that the map function actually is a *pure* function based on the row index, and therefore we can just pre-calculate it. Not yet a
reduction in running time, but our code is clearer.

```scala
// outside of `rowDistance`...
val expandedSize: Array[Long] = countByRow.map(v => if v == 0 then k else 1L)
// inside of `rowDistance`...
val distanceForOnePair = (xRow until yRow).map(expandedSize(_)).sum
```

At this point we can leverage prefix sums to make getting a sum of a range of elements a constant operation...

```scala
// outside of `rowDistance`...
val expandedSize: Array[Long] =
  countByRow.map(v => if v == 0 then k else 1L)
// expandedSizePrefix(i) = expandedSize(0) + ... + expandedSize(i-1)
val expandedSizePrefix = expandedSize.scan(0L)(_ + _)
// inside of `rowDistance`...
val distanceForOnePair =
  expandedSizePrefix(yRow) - expandedSizePrefix(xRow)
```

And we have just lowered the running time of the solution to `O(rows^2)`, by making `rowDistance` constant-time!
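
To make the prefix sum concrete, here is a small sketch using the per-row counts of the example grid from the puzzle page with `k = 2`. The helper name `rangeSum` is hypothetical.

```scala
// Number of `#` in each row of the example grid (rows 3 and 7 are empty).
val counts = Array(1, 1, 1, 0, 1, 1, 1, 0, 1, 2)
val k = 2L

// Expanded size of each row: 1, 1, 1, 2, 1, 1, 1, 2, 1, 1
val expandedSize: Array[Long] = counts.map(c => if c == 0 then k else 1L)

// expandedSizePrefix(i) = expandedSize(0) + ... + expandedSize(i-1)
// which is: 0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12
val expandedSizePrefix: Array[Long] = expandedSize.scanLeft(0L)(_ + _)

// Sum of expandedSize(from), ..., expandedSize(until - 1) in constant time.
def rangeSum(from: Int, until: Int): Long =
  expandedSizePrefix(until) - expandedSizePrefix(from)
```

For instance, `rangeSum(2, 8)` is `10 - 2 = 8`, the same as adding `1 + 2 + 1 + 1 + 1 + 2` for rows 2 to 7 by hand.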

Here is the full code.

```scala
def solve(input: String, expand: Int) =
  val board = input.linesIterator.toSeq
  val countByRow = board
    .map(row => row.count(_ == '#'))
    .toArray // get O(1) indexing
  val countByCol = board.transpose // rotate the board!
    .map(col => col.count(_ == '#'))
    .toArray
  allRowDistances(expand, countByRow)
    + allRowDistances(expand, countByCol)
end solve

def part1(input: String) = solve(input, expand = 2)
def part2(input: String) = solve(input, expand = 1_000_000)

def allRowDistances(k: Int, counts: Array[Int]): Long =
  val expandedSize: Array[Long] = counts.map(v => if v == 0 then k else 1L)
  // expandedSizePrefix(i) = expandedSize(0) + ... + expandedSize(i-1)
  val expandedSizePrefix = expandedSize.scan(0L)(_ + _)

  def rowDistance(xRow: Int, yRow: Int): Long =
    val distanceForOnePair = expandedSizePrefix(yRow) - expandedSizePrefix(xRow)
    distanceForOnePair * counts(xRow) * counts(yRow)

  (for
    i <- 0 until counts.length
    j <- 0 until i
  yield rowDistance(j, i)).sum
```

Now, this is enough for the puzzle, as reading the input itself is already `O(rows * cols)`. But ignoring that, can we do better? Hint: yes.
Let's go on the optimization train.

### Approaching the linear summit: more prefix sums

Going further requires us to inline the definition of `rowDistance`.
Let us apply some mathematical transformations and do some equational reasoning!

```scala
val result = {
  for
    i <- 0 until counts.length
    j <- 0 until i
  yield
    counts(i) * counts(j) * (expandedSizePrefix(i) - expandedSizePrefix(j))
}.sum
```

Let's regroup the `for` loop a bit:

```scala
val result = {
  for i <- 0 until counts.length
  yield counts(i) * {
    (for j <- 0 until i yield counts(j) * expandedSizePrefix(i) - counts(j) * expandedSizePrefix(j)).sum
    /* = */ expandedSizePrefix(i) * (for j <- 0 until i yield counts(j)).sum - (for j <- 0 until i yield counts(j) * expandedSizePrefix(j)).sum
  }
}.sum
```

We can see the `for j <- 0 until i` pattern here, which means a prefix sum can be utilized again!

```scala
// countsSum(i) = counts(0) + ... + counts(i-1)
// (`scanLeft` lets us accumulate the Int counts into a Long)
val countsSum = counts.scanLeft(0L)(_ + _)
val countsTimeExpandedSizePSum = counts
  .lazyZip(expandedSizePrefix)
  .map(_.toLong * _) // multiplied together
  .scan(0L)(_ + _) // create a prefix sum

// and now we have
val result = {
  for i <- 0 until counts.length
  yield counts(i) * {
    expandedSizePrefix(i) * countsSum(i) - countsTimeExpandedSizePSum(i)
  }
}.sum
```

Voila, linear time complexity!
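
Putting the pieces of this section together, a self-contained version could look like the sketch below. The function name `linearRowDistances` and its parameters are assumptions for this example. It uses `scanLeft` throughout so that the `Int` counts accumulate directly into `Long` sums.

```scala
def linearRowDistances(k: Long, counts: Array[Int]): Long =
  val expandedSize: Array[Long] = counts.map(c => if c == 0 then k else 1L)
  // expandedSizePrefix(i) = expandedSize(0) + ... + expandedSize(i-1)
  val expandedSizePrefix = expandedSize.scanLeft(0L)(_ + _)
  // countsSum(i) = counts(0) + ... + counts(i-1)
  val countsSum = counts.scanLeft(0L)(_ + _)
  // countsTimesPrefixSum(i) = sum over j < i of counts(j) * expandedSizePrefix(j)
  val countsTimesPrefixSum = counts.indices
    .map(j => counts(j) * expandedSizePrefix(j))
    .scanLeft(0L)(_ + _)

  // One constant-time term per row: a single linear pass overall.
  counts.indices.map { i =>
    counts(i) * (expandedSizePrefix(i) * countsSum(i) - countsTimesPrefixSum(i))
  }.sum
```

With the row and column counts of the example grid, the two calls add up to 374 for `k = 2`, matching the puzzle description.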

### Linear time with Recursions, or "Sweep Line"

Now, the previous approach requires some math and a lot of arrays. Can we do it in a more Scala-like way, with some (tail)
recursion? Enter the *sweep line algorithm*.

Same idea as before: we calculate the row distance and column distance separately.
Let us rewrite the row *distance* from a point in row `j` to reach the (end of the expanded) row `i` as a recursive formula:

```scala
def distance(j: Int, i: Int): Long =
  if j == i then 0 // same row
  else if counts(i) == 0 then distance(j, i - 1) + k // i was expanded
  else distance(j, i - 1) + 1 // i was not expanded
```

From a *single* point in row `i`, the distance to *all* points in rows *before* `i` would be

```scala
def totalDistance(i: Int): Long =
  (for j <- 0 until i yield counts(j) * distance(j, i)).sum
```

Let's write this in terms of recursion on `i`! Each `/* = */` line below rewrites the expression above it.

```scala
def totalDistance(i: Int): Long =
  if i == 0 then 0
  else
    (for j <- 0 to i yield counts(j) * distance(j, i)).sum
    /* = */ (for j <- 0 to i-1 yield counts(j) * distance(j, i)).sum + counts(i) * distance(i, i) /* this part is always 0! */
    /* = */ (for j <- 0 to i-1
              yield counts(j) * {
                if j == i then 0 /* never happens */
                else if counts(i) == 0 then distance(j, i-1) + k // i was expanded
                else distance(j, i-1) + 1 // i was not expanded
              }).sum
    /* = */ (for j <- 0 to i-1
              yield counts(j) * distance(j, i-1) + counts(j) * {
                if counts(i) == 0 then k // i was expanded
                else 1 // i was not expanded
              }).sum
    /* = */ (for j <- 0 to i-1
              yield counts(j) * distance(j, i-1)
            ).sum + // this is just totalDistance(i-1)!
            (for j <- 0 to i-1
              yield counts(j) * {
                // this is independent of j!
                if counts(i) == 0 then k // i was expanded
                else 1 // i was not expanded
              }).sum
    /* = */ totalDistance(i-1) + (for j <- 0 to i-1 yield counts(j)).sum * (if counts(i) == 0 then k else 1)
```

Which is *almost* a fully linear recursive formula, except we also need to track the **total number of points** coming *before* `i`!
This is fine, we shall do it in our recursive function...
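
If you would like to convince yourself of the final formula without redoing the algebra, the following sketch compares the direct definition with the recurrence. Both function names are made up for this check, and the recurrence is deliberately left unoptimized: it recounts the points before `i` at every step.

```scala
// Direct definition: for every row j before i, add counts(j) * distance(j, i),
// where distance(j, i) is the expanded size of rows j + 1, ..., i.
def naiveTotalDistance(k: Long, counts: Vector[Int], i: Int): Long =
  def size(row: Int): Long = if counts(row) == 0 then k else 1L
  def distance(j: Int, target: Int): Long = (j + 1 to target).map(size).sum
  (0 until i).map(j => counts(j) * distance(j, i)).sum

// Recurrence from the derivation:
// totalDistance(i) = totalDistance(i-1) + (points before i) * (size of row i)
def recurrenceTotalDistance(k: Long, counts: Vector[Int], i: Int): Long =
  if i == 0 then 0L
  else
    val pointsBefore = counts.take(i).map(_.toLong).sum
    recurrenceTotalDistance(k, counts, i - 1) +
      pointsBefore * (if counts(i) == 0 then k else 1L)
```

The two functions should agree for every row index `i` of any `counts` vector.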

To calculate the row distance, we *sweep* through the points top-down.
Simulating `totalDistance`, we track `totalDistance(i)` and the sum `counts(0) + ... + counts(i-1)` as we go through the `counts` sequence (now a list!).

```scala
@scala.annotation.tailrec
def loop(
    countsSum: Long, // counts(0) + ... + counts(i-1)
    totalDistance: Long, // totalDistance(i-1)
    accum: Long, // the accumulated answer
)(
    counts: List[Int], // our count-by-row list
): Long = counts match
  case Nil => accum // done!
  case head :: tail => // head is counts(i), tail is counts(i+1 .. end)
    val newCountsSum = countsSum + head // counts(0) + ... + counts(i)
    val newTotalDist = totalDistance + countsSum * (if head == 0 then k else 1) // follow the formula!
    val distanceToPointsHere = head * newTotalDist
    loop(newCountsSum, newTotalDist, accum + distanceToPointsHere)(tail)
```

... and this is the solution presented in our repo!
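
For completeness, here is one way the loop might be wired into a full solver. This is a sketch rather than the exact code in the repository: the names `sweep` and `solveSweep` are assumptions, and `k` is passed in as a parameter instead of being captured from the surrounding scope.

```scala
def sweep(k: Long, counts: List[Int]): Long =
  @scala.annotation.tailrec
  def loop(countsSum: Long, totalDistance: Long, accum: Long)(rest: List[Int]): Long =
    rest match
      case Nil => accum
      case head :: tail =>
        // Crossing this row costs k if it is empty, 1 otherwise,
        // for every point seen so far.
        val newTotalDist = totalDistance + countsSum * (if head == 0 then k else 1L)
        loop(countsSum + head, newTotalDist, accum + head * newTotalDist)(tail)
  loop(0L, 0L, 0L)(counts)

def solveSweep(input: String, k: Long): Long =
  val board = input.linesIterator.toSeq
  val countByRow = board.map(_.count(_ == '#')).toList
  val countByCol = board.transpose.map(_.count(_ == '#')).toList
  // Rows and columns are independent: sweep each once and add the results.
  sweep(k, countByRow) + sweep(k, countByCol)
```

On the example grid from the puzzle page, `solveSweep(input, 2)` gives 374 and `solveSweep(input, 100)` gives 8410.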

### Where to expand on the problem?

Here are some ideas that I think would be interesting to look into:

- Bigger inputs: what if the space is an incredibly large grid, but the `#`s are sparse (for example, about 100000 points in a 10^9-sized grid)? Can we leverage the same technique to achieve an efficient counting algorithm?
- Non-linear *k*: what if instead of expanding empty rows/columns by a constant *k*, we expand the topmost empty row by 1, the second empty row by 2 and so on... Same with columns. Can we still keep the counting linear?
- (Squared) Euclidean distance: what if our distance is the square of the *actual* distance between the `#`s (i.e. `(a.x - b.x)^2 + (a.y - b.y)^2`)? We *should* be able to *still* keep the counting algorithm linear with some math!

## Solutions from the community

- Solution by Spamegg
- Solution by Raymond Dodge
- Solution by Rui Alves
- Solution by Thanh Le
- Solution by Seth Tisue
- Solution by jnclt
- Solution by g.berezin
- Solution by Marconi Lanna
- Solution by Philippus Baalman
- Solution by Joel Edwards
- Solution by Michael Pilquist
- Solution by Jamie Thompson
- Solution by Paweł Cembaluk
