Day 4: Ceres Search
by @bishabosha
Puzzle description
https://adventofcode.com/2024/day/4
Solution Summary
Treat the input as a 2D grid (whose elements are only X, M, A, or S).
Part 1: iterate through each point in the grid.
For each point and for each of the eight directions, try to construct the word "XMAS"
starting from the current point, while handling out of bounds.
Because each origin point is visited once there is no need to handle duplicates.
Count the total matched words.
Part 2: iterate through each point in the grid.
If the current point is 'A', then check if it is the middle of an X-MAS:
- translate the point to each of its four corners and try to construct the word "MAS" travelling in the opposite direction of the translation,
- a correct X-MAS will have two such corners for a single origin 'A'.
Again, because each origin point is visited once there is no need to handle duplicates.
Count the total 'A' where this is true.
Part 1
First, the puzzle description does not say so explicitly, but the input is a square grid, i.e. N lines of N characters.
It will help us explore the data if it is represented as a 2D space:
type Grid = IArray[IArray[Char]]
def parse(input: String): Grid =
  IArray.from(
    input.linesIterator.map(IArray.from)
  )
IArray is an efficient immutable array of fixed size, with fast random access.
e.g. the string
ABC\n
DEF\n
GHI\n
would be represented as
IArray(
  IArray('A', 'B', 'C'),
  IArray('D', 'E', 'F'),
  IArray('G', 'H', 'I')
)
In the 2D grid, you have a coordinate system where the y axis corresponds to rows, i.e. entries of the outer array, and the x axis corresponds to columns, i.e. entries of the nested arrays. The origin y=0,x=0 is at the top left.
A point y=2,x=1 (i.e. third row, second column) is then accessed with grid(2)(1), giving 'H' in this case.
Finding the XMAS
For part 1, we are tasked with finding occurrences of the string "XMAS", which can begin at any point, with the letters lying on a straight line in any of the 8 directions:
case class Dir(dy: Int, dx: Int)
val dirs = IArray(
  Dir(dy = -1, dx = 0), // up
  Dir(dy = 0, dx = 1), // right
  Dir(dy = 1, dx = 0), // down
  Dir(dy = 0, dx = -1), // left
  Dir(dy = -1, dx = 1), // up-right
  Dir(dy = 1, dx = 1), // down-right
  Dir(dy = 1, dx = -1), // down-left
  Dir(dy = -1, dx = -1) // up-left
)
In this case, we represent a direction as a vector in 2 dimensions: dy (the change in y), and dx (the change in x).
So we need a way to build a string of characters in any direction while traversing the grid.
We can do this with an Iterator, which can generate all the potential characters while traversing in a direction until we reach the edge of the grid.
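First, the test for whether a point lies inside the grid. Since a point is accessed as grid(y)(x), y is checked against the number of rows and x against the length of a row:
def boundCheck(x: Int, y: Int, grid: Grid): Boolean =
  // x indexes columns (inner arrays), y indexes rows (outer array)
  x >= 0 && x < grid(0).length && y >= 0 && y < grid.length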
Next the iterator, built with Iterator.unfold, which keeps advancing some internal state until a final state is reached:
def scanner(x: Int, y: Int, dir: Dir, grid: Grid): Iterator[Char] =
  Iterator.unfold((y, x)): (y, x) =>
    Option.when(boundCheck(x, y, grid))(grid(y)(x) -> (y + dir.dy, x + dir.dx))
The unfold method takes an initial state (y, x), and a function that produces an optional pair of the iterator's next value and the next state. In the case above, we stop when boundCheck fails; otherwise we return the letter at grid(y)(x) paired with the next position, obtained by translating the point by the direction.
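To see the unfold contract in isolation, here is a minimal sketch on a one-dimensional number line instead of the grid. The helper `walk` and its parameters are hypothetical names made up for this illustration; they are not part of the solution.

```scala
// Hypothetical helper, not part of the solution: walk a number line from
// `start` in steps of `step`, stopping once the position leaves [0, limit).
def walk(start: Int, step: Int, limit: Int): Iterator[Int] =
  Iterator.unfold(start): pos =>
    // Some((emitted value, next state)) continues; None ends the iterator.
    Option.when(pos >= 0 && pos < limit)(pos -> (pos + step))

val forwards = walk(start = 0, step = 2, limit = 7).toList   // List(0, 2, 4, 6)
val backwards = walk(start = 3, step = -1, limit = 7).toList // List(3, 2, 1, 0)
val outside = walk(start = 9, step = 1, limit = 7).toList    // List()
```

The last case mirrors what happens when a scan starts outside the grid: the bounds test fails on the first step, so the iterator is empty.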
Next, we need a way to compare the string along a direction with a target word, and exit early as soon as the match would fail. We can use the corresponds method on iterators to compare its elements pairwise with those of another collection, and exit as soon as a pair doesn't match. It also returns false when one side runs out of elements before the other, which covers scans cut short by the edge of the grid:
def scanString(target: String)(x: Int, y: Int, dir: Dir, grid: Grid): Boolean =
  scanner(x, y, dir, grid).take(target.length).corresponds(target)(_ == _)
val scanXMAS = scanString("XMAS")
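The combination of take and corresponds can be tried out on plain strings, without a grid. The sketch below uses a hypothetical helper `startsWithWord` (a name invented here) that has the same shape as scanString but reads its letters from any iterator.

```scala
// Hypothetical helper mirroring the shape of scanString, but over any
// iterator of letters, to show how take + corresponds behave together.
def startsWithWord(target: String)(letters: Iterator[Char]): Boolean =
  letters.take(target.length).corresponds(target)(_ == _)

// true: take cuts off the letters after the fourth one
val exact = startsWithWord("XMAS")("XMASXX".iterator)
// false: the third letter differs, so the comparison stops there
val mismatch = startsWithWord("XMAS")("XMXS".iterator)
// false: the letters run out early, like a scan that hits the grid edge
val tooShort = startsWithWord("XMAS")("XM".iterator)
```

Without take, the first case would be false as well, because corresponds requires both sides to have the same length.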
Now that we have a way to find a string in a grid, we should iterate through each point in the grid, by using Iterator.tabulate, and count all the possible "XMAS" strings starting from that point:
def totalXMAS(grid: Grid): Int =
  Iterator
    .tabulate(grid.size, grid.size): (y, x) =>
      dirs.count(dir => scanXMAS(x, y, dir, grid))
    .flatten
    .sum
Note that Iterator.tabulate with two arguments creates an Iterator[Iterator[T]], so use .flatten to join the results together.
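As a small illustration of the nested result, this sketch tabulates a 2 x 3 range of coordinates (the sizes are arbitrary, chosen only for the example) and then flattens it.

```scala
// A 2 x 3 range of coordinates: the outer iterator walks y, the inner walks x.
val nested: Iterator[Iterator[(Int, Int)]] =
  Iterator.tabulate(2, 3)((y, x) => (y, x))

val points = nested.flatten.toList
// List((0,0), (0,1), (0,2), (1,0), (1,1), (1,2))

// The same shape as totalXMAS: one number per point, flattened, then summed.
val total = Iterator.tabulate(2, 3)((y, x) => y + x).flatten.sum
// (0 + 1 + 2) + (1 + 2 + 3) = 9
```

The points come out row by row, so the first argument of the function plays the role of y and the second the role of x.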
Now we have the full solution for part 1:
def part1(input: String): Int =
  totalXMAS(parse(input))
Part 2
In part 2, there is a slight change. Firstly, we will only be looking for the string "MAS":
val scanMAS = scanString("MAS")
Next, in our iteration through the grid, we will only be interested in the number of positions that are 'A' and are also the center of a valid X-MAS formation. To do this, we will look for "MAS" starting in each of the four surrounding corners.
e.g. if we picture the four corners in the grid like this,
1-2
-A-
4-3
Then we will look for "MAS" starting from 1 towards 3, from 2 towards 4, from 3 towards 1, and from 4 towards 2.
To do that we will need two Dirs for each case, one to translate the current point to the corner, and then the opposite direction to scan in. Here is how you can represent that:
val UpRight = dirs(4)
val DownRight = dirs(5)
val DownLeft = dirs(6)
val UpLeft = dirs(7)
val dirsMAS = IArray(
  UpLeft -> DownRight, // 1 -> 3
  UpRight -> DownLeft, // 2 -> 4
  DownRight -> UpLeft, // 3 -> 1
  DownLeft -> UpRight // 4 -> 2
)
Then to check if a single point is the center of an X-MAS, use the following code:
def isMAS(x: Int, y: Int, grid: Grid): Boolean =
  grid(y)(x) match
    case 'A' =>
      val seen = dirsMAS.count: (transform, dir) =>
        scanMAS(x + transform.dx, y + transform.dy, dir, grid)
      seen > 1
    case _ => false
i.e. when the point is 'A', then for each of the four corners, translate the point to the corner, which will be the origin point to scan for "MAS". Then pass along the scanning direction and the grid. For a valid X-MAS, two of the corners will be origin points for the word.
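The corner check can be exercised on tiny hand-made grids. The following is a self-contained sketch that condenses the pieces described so far; the names `corners` and `isXMASCentre`, the inlined bounds test, and the three sample grids are assumptions made for this illustration (with '.' as filler, which does not occur in real puzzle input).

```scala
type Grid = IArray[IArray[Char]]
case class Dir(dy: Int, dx: Int)

def parse(input: String): Grid =
  IArray.from(input.linesIterator.map(IArray.from))

// Read letters from (x, y) along dir until the edge, then compare with target.
def scanString(target: String)(x: Int, y: Int, dir: Dir, grid: Grid): Boolean =
  val letters = Iterator.unfold((y, x)): (y, x) =>
    val inside = y >= 0 && y < grid.length && x >= 0 && x < grid(y).length
    Option.when(inside)(grid(y)(x) -> (y + dir.dy, x + dir.dx))
  letters.take(target.length).corresponds(target)(_ == _)

val UpLeft = Dir(-1, -1)
val UpRight = Dir(-1, 1)
val DownLeft = Dir(1, -1)
val DownRight = Dir(1, 1)

// (translation to a corner, direction to scan in from that corner)
val corners = IArray(
  UpLeft -> DownRight,
  UpRight -> DownLeft,
  DownRight -> UpLeft,
  DownLeft -> UpRight
)

def isXMASCentre(x: Int, y: Int, grid: Grid): Boolean =
  grid(y)(x) == 'A' && corners.count((corner, dir) =>
    scanString("MAS")(x + corner.dx, y + corner.dy, dir, grid)
  ) > 1

// Both diagonals spell "MAS": top-left to bottom-right, bottom-left to top-right.
val cross = isXMASCentre(1, 1, parse("M.S\n.A.\nM.S"))  // true
// Only one diagonal spells "MAS", so this is not an X-MAS.
val single = isXMASCentre(1, 1, parse("M.M\n.A.\n..S")) // false
// An 'A' in a grid corner: most scans start out of bounds and yield nothing.
val onEdge = isXMASCentre(0, 0, parse("AM\nMS"))        // false
```

In the first grid, the corners holding 'S' fail on the very first letter, which is why exactly two of the four scans succeed.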
Putting it all together, we then again iterate through the grid, now only counting the points where isMAS is true.
def part2(input: String): Int =
  totalMAS(parse(input))
def totalMAS(grid: Grid): Int =
  Iterator
    .tabulate(grid.size, grid.size): (y, x) =>
      if isMAS(x, y, grid) then 1 else 0
    .flatten
    .sum
Final Code
def part1(input: String): Int =
  totalXMAS(parse(input))
type Grid = IArray[IArray[Char]]
def parse(input: String): Grid =
  IArray.from(
    input.linesIterator.map(IArray.from)
  )
case class Dir(dy: Int, dx: Int)
val dirs = IArray(
  Dir(dy = -1, dx = 0), // up
  Dir(dy = 0, dx = 1), // right
  Dir(dy = 1, dx = 0), // down
  Dir(dy = 0, dx = -1), // left
  Dir(dy = -1, dx = 1), // up-right
  Dir(dy = 1, dx = 1), // down-right
  Dir(dy = 1, dx = -1), // down-left
  Dir(dy = -1, dx = -1) // up-left
)
def boundCheck(x: Int, y: Int, grid: Grid): Boolean =
  // x indexes columns (inner arrays), y indexes rows (outer array)
  x >= 0 && x < grid(0).length && y >= 0 && y < grid.length
def scanner(x: Int, y: Int, dir: Dir, grid: Grid): Iterator[Char] =
  Iterator.unfold((y, x)): (y, x) =>
    Option.when(boundCheck(x, y, grid))(grid(y)(x) -> (y + dir.dy, x + dir.dx))
def scanString(target: String)(x: Int, y: Int, dir: Dir, grid: Grid): Boolean =
  scanner(x, y, dir, grid).take(target.length).corresponds(target)(_ == _)
val scanXMAS = scanString("XMAS")
def totalXMAS(grid: Grid): Int =
  Iterator
    .tabulate(grid.size, grid.size): (y, x) =>
      dirs.count(dir => scanXMAS(x, y, dir, grid))
    .flatten
    .sum
def part2(input: String): Int =
  totalMAS(parse(input))
val scanMAS = scanString("MAS")
val UpRight = dirs(4)
val DownRight = dirs(5)
val DownLeft = dirs(6)
val UpLeft = dirs(7)
val dirsMAS = IArray(
  UpLeft -> DownRight,
  UpRight -> DownLeft,
  DownRight -> UpLeft,
  DownLeft -> UpRight
)
def isMAS(x: Int, y: Int, grid: Grid): Boolean =
  grid(y)(x) match
    case 'A' =>
      val seen = dirsMAS.count: (transform, dir) =>
        scanMAS(x + transform.dx, y + transform.dy, dir, grid)
      seen > 1
    case _ => false
def totalMAS(grid: Grid): Int =
  Iterator
    .tabulate(grid.size, grid.size): (y, x) =>
      if isMAS(x, y, grid) then 1 else 0
    .flatten
    .sum
Solutions from the community
- Solution by Jamie Thompson
- Solution by Philippus Baalman
- Solution by Henryk Česnolovič
- Solution by YannMoisan
- Solution by Raphaël Marbeck
- Solution by Spamegg
- Solution by jnclt
- Solution by scarf
- Solution by nichobi
- Solution by Maciej Gorywoda
- Solution by itsjoeoui
- Solution by Guillaume Vandecasteele
- Solution by Roland Tritsch
- Solution by Jan Boerman
- Solution by Georgi Krastev
- Solution by Joshua Portway
- Solution by Bulby
- Solution by Paweł Cembaluk