
Day 16: Packet Decoder

by @tgodzik

Puzzle description

https://adventofcode.com/2021/day/16

Part 1: You've got mail!

It seems that we can split our problem into two parts. First, we need to parse the example into structures that we can later use to calculate our results.

Let's start with defining the data structures to use:

enum Packet(version: Int, typeId: Int):
  case Literal(version: Int, value: Long) extends Packet(version, 4)
  case Operator(version: Int, typeId: Int, exprs: List[Packet]) extends Packet(version, typeId)

Packet.Literal represents simple literal packets that contain only a value. We use Long for the value because, judging by previous Advent of Code puzzles, the numbers may not fit into an Int.

Packet.Operator will represent all the other operators that can contain other packets.
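To illustrate why Long is the safer choice, here is a small sketch that parses a 36-bit literal (nine 4-bit groups) with both parsers. The names fitsInInt, wideLiteral and asLong are made up for this example and are not part of the solution.

```scala
import scala.util.Try

// Hypothetical helper: true when the binary string can be parsed as an Int
def fitsInInt(bits: String): Boolean =
  Try(Integer.parseInt(bits, 2)).isSuccess

// A literal made of nine 4-bit groups of ones: 36 bits in total
val wideLiteral: String = "1" * 36

// parseLong accepts this value, while parseInt would throw a NumberFormatException
val asLong: Long = java.lang.Long.parseLong(wideLiteral, 2)
```

Whether the real puzzle input contains such wide literals is not guaranteed, but using Long costs nothing and avoids the problem.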

Now we need to map our input to these structures.

Let's start by mapping the hexadecimal input to a list of binary digits (as chars), which is easier to analyze when checking for packets:

val hexadecimalMapping =
  Map(
    '0' -> "0000",
    '1' -> "0001",
    '2' -> "0010",
    '3' -> "0011",
    '4' -> "0100",
    '5' -> "0101",
    '6' -> "0110",
    '7' -> "0111",
    '8' -> "1000",
    '9' -> "1001",
    'A' -> "1010",
    'B' -> "1011",
    'C' -> "1100",
    'D' -> "1101",
    'E' -> "1110",
    'F' -> "1111"
  )

def parse(input: String) =
  val byteInput = input.toList.flatMap(hex => hexadecimalMapping(hex).toCharArray)
  ...

For the example input D2FE28, this produces the same bits that we have seen in the puzzle description:

110100101111111000101000
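As an alternative to the lookup table, the same conversion can be sketched with the standard library's radix parsing. This is only an illustration: hexDigitToBits and hexToBits are hypothetical names, and unlike the Map above this version also accepts lowercase digits.

```scala
// Hypothetical helper: converts one hex digit to its 4-bit binary string
def hexDigitToBits(hex: Char): String =
  val value = Integer.parseInt(hex.toString, 16)
  // pad on the left with zeros so every digit yields exactly 4 bits
  value.toBinaryString.reverse.padTo(4, '0').reverse

// Hypothetical helper: converts a whole hex string to a list of '0'/'1' chars
def hexToBits(input: String): List[Char] =
  input.toList.flatMap(hex => hexDigitToBits(hex).toList)
```

Calling hexToBits("D2FE28").mkString gives the bit string shown above. The explicit Map is arguably easier to read, so this is a matter of taste.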

With that in place, we can start defining our function for decoding packets. We first parse the elements common to all packets: the version and the type ID. Based on the type ID, we decide how to parse the rest of the data.

type BinaryData = List[Char]

// helper function to read binary data 01101 to decimal 13
def toInt(chars: BinaryData): Int =
  Integer.parseInt(chars.mkString, 2)

// helper function to read binary data 01101 to decimal 13, but in a Long format
def toLong(chars: BinaryData): Long =
  java.lang.Long.parseLong(chars.mkString, 2)

// placeholders, defined later in the article
def readLiteralBody(tail: BinaryData, numAcc: BinaryData): (Long, BinaryData) = ???
def readOperatorBody(input: BinaryData): (List[Packet], BinaryData) = ???

def decodePacket(packet: BinaryData): (Packet, BinaryData) =
  // the first 3 bits are the version, the next 3 bits are the type ID
  val (versionBits, rest) = packet.splitAt(3)
  val version = toInt(versionBits)
  val (typeBits, body) = rest.splitAt(3)
  val tpe = toInt(typeBits)

  tpe match
    case 4 =>
      val (value, remaining) = readLiteralBody(body, Nil)
      (Packet.Literal(version, value), remaining)
    case otherTpe =>
      val (values, remaining) = readOperatorBody(body)
      (Packet.Operator(version, otherTpe, values), remaining)
  end match
end decodePacket

We use the function splitAt, which splits the input into the part that we need (for example, 3 bits for the version) and the rest of the packet data. This way we can read the version and typeId, pattern match on the latter and use the proper logic for reading in each case. We can then create our new structures. We also defined a helper type BinaryData, since we will be using it throughout the puzzle. What remains is defining readLiteralBody and readOperatorBody. You might notice that each function returns additional BinaryData: these are the bits that were not consumed, which we will later pass on to decode further packets recursively. We'll get back to it.
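To make the header parsing concrete, here is a minimal sketch that only extracts the version and type ID from the example bits. readHeader is a hypothetical helper written for this illustration, not part of the solution.

```scala
// Hypothetical helper: reads the 3-bit version and the 3-bit type ID,
// returning them together with the untouched body bits
def readHeader(bits: List[Char]): (Int, Int, List[Char]) =
  val (versionBits, rest) = bits.splitAt(3)
  val (typeBits, body) = rest.splitAt(3)
  val version = Integer.parseInt(versionBits.mkString, 2)
  val typeId = Integer.parseInt(typeBits.mkString, 2)
  (version, typeId, body)
```

For the bits of D2FE28, readHeader returns version 6 and type ID 4, so the remaining 18 bits would be handed to the literal parser.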

Let's start with the first undefined function, readLiteralBody. In the description we read that the body of a literal consists of 5-bit segments, where the last segment starts with 0 and all the others start with 1. The remaining 4 bits of each segment are concatenated to construct the number. We can create a recursive function that will handle it perfectly!

@tailrec
def readLiteralBody(tail: BinaryData, numAcc: BinaryData): (Long, BinaryData) =
  val (num, rest) = tail.splitAt(5)
  if num(0) == '1' then readLiteralBody(rest, numAcc.appendedAll(num.drop(1)))
  else
    val bits = numAcc.appendedAll(num.drop(1))
    (toLong(bits), rest)
end readLiteralBody

In each step we read 5 bits from the input and check if we should finish. If the first bit is 0, this is the last segment: we append its 4 value bits and return the result together with the unread input. If the first bit is 1, we append the 4 value bits to the accumulator and repeat the step on the remaining bits.
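The same idea can also be sketched without explicit recursion, using grouped to cut the body into 5-bit segments. This is a different technique from the tail-recursive function above, shown only for comparison. literalValue is a hypothetical name, and the sketch assumes well-formed input that contains a terminating segment.

```scala
// Hypothetical alternative to readLiteralBody based on grouped(5)
def literalValue(body: List[Char]): (Long, List[Char]) =
  val groups = body.grouped(5).toList
  // segments starting with '1' continue; the first one starting with '0' is the last
  val count = groups.indexWhere(_.head == '0') + 1
  // drop the marker bit of every segment and keep its 4 value bits
  val valueBits = groups.take(count).flatMap(_.tail)
  (java.lang.Long.parseLong(valueBits.mkString, 2), body.drop(count * 5))
```

For the body of D2FE28 (101111111000101000), the segments are 10111, 11110 and 00101, which gives 011111100101, or 2021 in decimal, with 000 left over.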

The harder part will be defining readOperatorBody since we know that operator packets can contain other packets and those packets can also be operators! This means we will need to apply a recursive approach:

def readOperatorBody(current: BinaryData): (List[Packet], BinaryData) =
  val (lenId, rest) = current.splitAt(1)

  // reads packets until exactly `remaining` bits have been consumed
  @tailrec
  def readMaxBits(
      current: BinaryData,
      remaining: Int,
      acc: List[Packet]
  ): (List[Packet], BinaryData) =
    if remaining == 0 then (acc, current)
    else
      val (newExpr, rest) = decodePacket(current)
      readMaxBits(rest, remaining - (current.size - rest.size), acc :+ newExpr)

  // reads exactly `remaining` packets
  @tailrec
  def readMaxPackets(
      current: BinaryData,
      remaining: Int,
      acc: List[Packet]
  ): (List[Packet], BinaryData) =
    if remaining == 0 then (acc, current)
    else
      val (newExpr, rest) = decodePacket(current)
      readMaxPackets(rest, remaining - 1, acc :+ newExpr)

  // read based on length
  if lenId(0) == '0' then
    val (size, packets) = rest.splitAt(15)
    readMaxBits(packets, toInt(size), Nil)
  // read based on number of packets
  else
    val (size, packets) = rest.splitAt(11)
    readMaxPackets(packets, toInt(size), Nil)
  end if
end readOperatorBody

In the above function we first check the first bit of the operator body, which tells us how we should check the rest of the body.

  • if the bit is 0 it means that the next 15 bits can be turned into a number, that will define how many of the further bits are the subpackets of the operator.

  • if the bit is 1 it means that the next 11 bits can be turned into a number, that will define how many subpackets should belong to the operator.
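The two cases above can be sketched in isolation as a small function that only decodes the length information. LengthSpec and readLengthSpec are hypothetical names introduced for this illustration; the solution itself does not define them.

```scala
// Hypothetical type describing how the sub-packets are delimited
enum LengthSpec:
  case TotalBits(bits: Int)     // length type ID 0: 15-bit total length
  case PacketCount(count: Int)  // length type ID 1: 11-bit packet count

// Hypothetical helper: reads the length type ID and the number that follows it
def readLengthSpec(body: List[Char]): (LengthSpec, List[Char]) =
  body match
    case '0' :: rest =>
      val (size, packets) = rest.splitAt(15)
      (LengthSpec.TotalBits(Integer.parseInt(size.mkString, 2)), packets)
    case _ :: rest =>
      val (size, packets) = rest.splitAt(11)
      (LengthSpec.PacketCount(Integer.parseInt(size.mkString, 2)), packets)
    case Nil =>
      throw IllegalArgumentException("empty operator body")
```

For example, a body starting with 0 followed by 000000000011011 describes sub-packets that span 27 bits in total.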

We defined two recursive helper functions, readMaxBits and readMaxPackets. Each one first checks its stopping condition (all the declared bits were read, or all the declared packets were read) and otherwise reads a new packet by calling decodePacket recursively. In readMaxBits, the number of bits consumed by a packet is the difference between the input size before and after decoding it. At the end, both functions return the list of packets that we put into the operator packet, and the remaining input that we might need to check for more packets.
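The bit-budget bookkeeping can be tried out on its own with a simplified sketch that only understands literal packets. Both readLiteralPacket and readWithinBudget are hypothetical helpers for this example. They also prepend to the accumulator and reverse it at the end, a common alternative to appending with :+ on a List.

```scala
import scala.annotation.tailrec

// Hypothetical literal-only decoder: skips the 6 header bits, then reads 5-bit segments
def readLiteralPacket(bits: List[Char]): (Long, List[Char]) =
  @tailrec
  def loop(tail: List[Char], acc: List[Char]): (Long, List[Char]) =
    val (group, rest) = tail.splitAt(5)
    val newAcc = acc ++ group.drop(1)
    if group.head == '1' then loop(rest, newAcc)
    else (java.lang.Long.parseLong(newAcc.mkString, 2), rest)
  loop(bits.drop(6), Nil)

// Reads literal packets until `budget` bits have been consumed
@tailrec
def readWithinBudget(
    bits: List[Char],
    budget: Int,
    acc: List[Long] = Nil
): (List[Long], List[Char]) =
  if budget <= 0 then (acc.reverse, bits)
  else
    val (value, rest) = readLiteralPacket(bits)
    // bits used by this packet = input size before minus input size after
    val consumed = bits.size - rest.size
    readWithinBudget(rest, budget - consumed, value :: acc)
```

With the 27 sub-packet bits from the puzzle's 38006F45291200 example (11010001010 followed by 0101001000100100), this reads the literals 10 and 20 and leaves any trailing padding untouched.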

This should already allow us to create a full structure and what remains is adding a function that can add up all the versions. We can add that function to the Packet enum and sum it all recursively.

  def versionSum: Int =
    this match
      case Literal(version, _) => version
      case Operator(version, _, exprs) => version + exprs.map(_.versionSum).sum

That's it! We should now be able to solve part 1.
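As a quick sanity check of the recursive sum, here is a self-contained sketch with a hand-built tree. DemoPacket is a simplified, hypothetical copy of the enum without the parent parameters. The versions 4, 1, 5 and 6 follow the nesting described for the puzzle's 8A004A801A8002F478 example, while the type IDs and the literal value are placeholders.

```scala
// Hypothetical, simplified version of the Packet enum
enum DemoPacket:
  case Literal(version: Int, value: Long)
  case Operator(version: Int, typeId: Int, exprs: List[DemoPacket])

  def versionSum: Int = this match
    case Literal(version, _)         => version
    case Operator(version, _, exprs) => version + exprs.map(_.versionSum).sum

// Operator (v4) -> operator (v1) -> operator (v5) -> literal (v6)
val nested: DemoPacket =
  DemoPacket.Operator(4, 2, List(
    DemoPacket.Operator(1, 2, List(
      DemoPacket.Operator(5, 2, List(DemoPacket.Literal(6, 15L)))
    ))
  ))
```

Here nested.versionSum adds 4 + 1 + 5 + 6, which matches the version sum of 16 given in the puzzle description.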

Full solution

package day16

import scala.util.Using
import scala.io.Source
import scala.annotation.tailrec

@main def part1(): Unit =
  println(s"The solution is ${part1(readInput())}")

def readInput(): String =
  Using.resource(Source.fromFile("input/day16"))(_.mkString)

val hexadecimalMapping =
  Map(
    '0' -> "0000",
    '1' -> "0001",
    '2' -> "0010",
    '3' -> "0011",
    '4' -> "0100",
    '5' -> "0101",
    '6' -> "0110",
    '7' -> "0111",
    '8' -> "1000",
    '9' -> "1001",
    'A' -> "1010",
    'B' -> "1011",
    'C' -> "1100",
    'D' -> "1101",
    'E' -> "1110",
    'F' -> "1111"
  )

enum Packet(version: Int, typeId: Int):
  case Literal(version: Int, value: Long) extends Packet(version, 4)
  case Operator(version: Int, typeId: Int, exprs: List[Packet]) extends Packet(version, typeId)
  def versionSum: Int =
    this match
      case Literal(version, _) => version
      case Operator(version, _, exprs) => version + exprs.map(_.versionSum).sum

type BinaryData = List[Char]

// helper function to read binary data 01101 to decimal 13
def toInt(chars: BinaryData): Int =
  Integer.parseInt(chars.mkString, 2)

// helper function to read binary data 01101 to decimal 13, but in a Long format
def toLong(chars: BinaryData): Long =
  java.lang.Long.parseLong(chars.mkString, 2)

@tailrec
def readLiteralBody(tail: BinaryData, numAcc: BinaryData): (Long, BinaryData) =
  val (num, rest) = tail.splitAt(5)
  if num(0) == '1' then readLiteralBody(rest, numAcc.appendedAll(num.drop(1)))
  else
    val bits = numAcc.appendedAll(num.drop(1))
    (toLong(bits), rest)
end readLiteralBody

def readOperatorBody(current: BinaryData): (List[Packet], BinaryData) =
  val (lenId, rest) = current.splitAt(1)

  @tailrec
  def readMaxBits(
      current: BinaryData,
      remaining: Int,
      acc: List[Packet]
  ): (List[Packet], BinaryData) =
    if remaining == 0 then (acc, current)
    else
      val (newExpr, rest) = decodePacket(current)
      readMaxBits(rest, remaining - (current.size - rest.size), acc :+ newExpr)

  @tailrec
  def readMaxPackets(
      current: BinaryData,
      remaining: Int,
      acc: List[Packet]
  ): (List[Packet], BinaryData) =
    if remaining == 0 then (acc, current)
    else
      val (newExpr, rest) = decodePacket(current)
      readMaxPackets(rest, remaining - 1, acc :+ newExpr)

  lenId match
    // read based on length
    case List('0') =>
      val (size, packets) = rest.splitAt(15)
      readMaxBits(packets, toInt(size), Nil)

    // read based on number of packets
    case _ =>
      val (size, packets) = rest.splitAt(11)
      readMaxPackets(packets, toInt(size), Nil)
  end match
end readOperatorBody

def decodePacket(input: BinaryData): (Packet, BinaryData) =
  val (versionBits, rest) = input.splitAt(3)
  val version = toInt(versionBits)
  val (typeBits, body) = rest.splitAt(3)
  val tpe = toInt(typeBits)

  tpe match
    case 4 =>
      val (value, remaining) = readLiteralBody(body, Nil)
      (Packet.Literal(version, value), remaining)
    case otherTpe =>
      val (values, remaining) = readOperatorBody(body)
      (Packet.Operator(version, otherTpe, values), remaining)
  end match
end decodePacket

def parse(input: String) =
  val number = input.toList.flatMap(hex => hexadecimalMapping(hex).toCharArray)
  val (operator, _) = decodePacket(number)
  operator

def part1(input: String) =
  val packet = parse(input)
  packet.versionSum

Part 2: The Elven calculus

Turns out that operator packets are actual mathematical operators and we can use the type ID to distinguish them!

We need to improve our structure to better show the different mathematical operators. For that we define additional enum cases instead of a single Operator case.

enum Packet(version: Int, typeId: Int):
  case Sum(version: Int, exprs: List[Packet]) extends Packet(version, 0)
  case Product(version: Int, exprs: List[Packet]) extends Packet(version, 1)
  case Minimum(version: Int, exprs: List[Packet]) extends Packet(version, 2)
  case Maximum(version: Int, exprs: List[Packet]) extends Packet(version, 3)
  case Literal(version: Int, literalValue: Long) extends Packet(version, 4)
  case GreaterThan(version: Int, lhs: Packet, rhs: Packet) extends Packet(version, 5)
  case LesserThan(version: Int, lhs: Packet, rhs: Packet) extends Packet(version, 6)
  case Equals(version: Int, lhs: Packet, rhs: Packet) extends Packet(version, 7)

We will also need to modify the way we create these operators:

So instead of

      val (values, remaining) = readOperatorBody(body)
      (Packet.Operator(version, otherTpe, values), remaining)

we will need to write:

      val (values, remaining) = readOperatorBody(body)
      otherTpe match
        case 0 => (Packet.Sum(version, values), remaining)
        case 1 => (Packet.Product(version, values), remaining)
        case 2 => (Packet.Minimum(version, values), remaining)
        case 3 => (Packet.Maximum(version, values), remaining)
        case 5 => (Packet.GreaterThan(version, values(0), values(1)), remaining)
        case 6 => (Packet.LesserThan(version, values(0), values(1)), remaining)
        case 7 => (Packet.Equals(version, values(0), values(1)), remaining)

This makes our structure accurately reflect the mathematical computation that is constructed from the packets. The last remaining step is to create a function that will evaluate the expression. We can do it similarly to the versionSum function in the previous part:

  def value: Long =
    this match
      case Sum(version, exprs) => exprs.map(_.value).sum
      case Product(version, exprs) => exprs.map(_.value).reduce(_ * _)
      case Minimum(version, exprs) => exprs.map(_.value).min
      case Maximum(version, exprs) => exprs.map(_.value).max
      case Literal(version, value) => value
      case GreaterThan(version, lhs, rhs) => if lhs.value > rhs.value then 1 else 0
      case LesserThan(version, lhs, rhs) => if lhs.value < rhs.value then 1 else 0
      case Equals(version, lhs, rhs) => if lhs.value == rhs.value then 1 else 0
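
To see the evaluation at work without any parsing, here is a trimmed-down sketch with hand-built expression trees. Expr and its cases are hypothetical stand-ins for the Packet enum (versions and some operators are left out), and it uses product where the code above uses reduce(_ * _).

```scala
// Hypothetical, trimmed-down expression tree (versions omitted for brevity)
enum Expr:
  case Num(n: Long)
  case Sum(exprs: List[Expr])
  case Product(exprs: List[Expr])
  case GreaterThan(lhs: Expr, rhs: Expr)

  def value: Long = this match
    case Num(n)                => n
    case Sum(exprs)            => exprs.map(_.value).sum
    case Product(exprs)        => exprs.map(_.value).product
    case GreaterThan(lhs, rhs) => if lhs.value > rhs.value then 1L else 0L

// Trees mirroring three examples from the puzzle description
val onePlusTwo = Expr.Sum(List(Expr.Num(1), Expr.Num(2)))
val sixTimesNine = Expr.Product(List(Expr.Num(6), Expr.Num(9)))
val fiveGreaterThanFifteen = Expr.GreaterThan(Expr.Num(5), Expr.Num(15))
```

These evaluate to 3, 54 and 0 respectively, in line with the results the puzzle lists for C200B40A82, 04005AC33890 and F600BC2D8F.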

Full solution

package day16

import scala.util.Using
import scala.io.Source
import scala.annotation.tailrec

@main def part1(): Unit =
  println(s"The solution is ${part1(readInput())}")

@main def part2(): Unit =
  println(s"The solution is ${part2(readInput())}")

def readInput(): String =
  Using.resource(Source.fromFile("input/day16"))(_.mkString)

val hexadecimalMapping =
  Map(
    '0' -> "0000",
    '1' -> "0001",
    '2' -> "0010",
    '3' -> "0011",
    '4' -> "0100",
    '5' -> "0101",
    '6' -> "0110",
    '7' -> "0111",
    '8' -> "1000",
    '9' -> "1001",
    'A' -> "1010",
    'B' -> "1011",
    'C' -> "1100",
    'D' -> "1101",
    'E' -> "1110",
    'F' -> "1111"
  )

/*
 * Structures for all possible operators
 */
enum Packet(version: Int, typeId: Int):
  case Sum(version: Int, exprs: List[Packet]) extends Packet(version, 0)
  case Product(version: Int, exprs: List[Packet]) extends Packet(version, 1)
  case Minimum(version: Int, exprs: List[Packet]) extends Packet(version, 2)
  case Maximum(version: Int, exprs: List[Packet]) extends Packet(version, 3)
  case Literal(version: Int, literalValue: Long) extends Packet(version, 4)
  case GreaterThan(version: Int, lhs: Packet, rhs: Packet) extends Packet(version, 5)
  case LesserThan(version: Int, lhs: Packet, rhs: Packet) extends Packet(version, 6)
  case Equals(version: Int, lhs: Packet, rhs: Packet) extends Packet(version, 7)

  def versionSum: Int =
    this match
      case Sum(version, exprs) => version + exprs.map(_.versionSum).sum
      case Product(version, exprs) => version + exprs.map(_.versionSum).sum
      case Minimum(version, exprs) => version + exprs.map(_.versionSum).sum
      case Maximum(version, exprs) => version + exprs.map(_.versionSum).sum
      case Literal(version, value) => version
      case GreaterThan(version, lhs, rhs) => version + lhs.versionSum + rhs.versionSum
      case LesserThan(version, lhs, rhs) => version + lhs.versionSum + rhs.versionSum
      case Equals(version, lhs, rhs) => version + lhs.versionSum + rhs.versionSum

  def value: Long =
    this match
      case Sum(version, exprs) => exprs.map(_.value).sum
      case Product(version, exprs) => exprs.map(_.value).reduce(_ * _)
      case Minimum(version, exprs) => exprs.map(_.value).min
      case Maximum(version, exprs) => exprs.map(_.value).max
      case Literal(version, value) => value
      case GreaterThan(version, lhs, rhs) => if lhs.value > rhs.value then 1 else 0
      case LesserThan(version, lhs, rhs) => if lhs.value < rhs.value then 1 else 0
      case Equals(version, lhs, rhs) => if lhs.value == rhs.value then 1 else 0
end Packet

type BinaryData = List[Char]

inline def toInt(chars: BinaryData): Int =
  Integer.parseInt(chars.mkString, 2)

inline def toLong(chars: BinaryData): Long =
  java.lang.Long.parseLong(chars.mkString, 2)

@tailrec
def readLiteralBody(tail: BinaryData, numAcc: BinaryData): (Long, BinaryData) =
  val (num, rest) = tail.splitAt(5)
  if num(0) == '1' then readLiteralBody(rest, numAcc.appendedAll(num.drop(1)))
  else
    val bits = numAcc.appendedAll(num.drop(1))
    (toLong(bits), rest)
end readLiteralBody

def readOperatorBody(current: BinaryData): (List[Packet], BinaryData) =
  val (lenId, rest) = current.splitAt(1)

  @tailrec
  def readMaxBits(
      current: BinaryData,
      remaining: Int,
      acc: List[Packet]
  ): (List[Packet], BinaryData) =
    if remaining == 0 then (acc, current)
    else
      val (newExpr, rest) = decodePacket(current)
      readMaxBits(rest, remaining - (current.size - rest.size), acc :+ newExpr)

  @tailrec
  def readMaxPackets(
      current: BinaryData,
      remaining: Int,
      acc: List[Packet]
  ): (List[Packet], BinaryData) =
    if remaining == 0 then (acc, current)
    else
      val (newExpr, rest) = decodePacket(current)
      readMaxPackets(rest, remaining - 1, acc :+ newExpr)

  lenId match
    // read based on length
    case List('0') =>
      val (size, packets) = rest.splitAt(15)
      readMaxBits(packets, toInt(size), Nil)

    // read based on number of packets
    case _ =>
      val (size, packets) = rest.splitAt(11)
      readMaxPackets(packets, toInt(size), Nil)
  end match
end readOperatorBody

def decodePacket(packet: BinaryData): (Packet, BinaryData) =
  val (versionBits, rest) = packet.splitAt(3)
  val version = toInt(versionBits)
  val (typeBits, body) = rest.splitAt(3)
  val tpe = toInt(typeBits)

  tpe match
    case 4 =>
      val (value, remaining) = readLiteralBody(body, Nil)
      (Packet.Literal(version, value), remaining)
    case otherTpe =>
      val (values, remaining) = readOperatorBody(body)
      otherTpe match
        case 0 => (Packet.Sum(version, values), remaining)
        case 1 => (Packet.Product(version, values), remaining)
        case 2 => (Packet.Minimum(version, values), remaining)
        case 3 => (Packet.Maximum(version, values), remaining)
        case 5 => (Packet.GreaterThan(version, values(0), values(1)), remaining)
        case 6 => (Packet.LesserThan(version, values(0), values(1)), remaining)
        case 7 => (Packet.Equals(version, values(0), values(1)), remaining)
  end match
end decodePacket

def parse(input: String) =
  val number = input.toList.flatMap(hex => hexadecimalMapping(hex).toCharArray)
  val (operator, _) = decodePacket(number)
  operator

def part1(input: String) =
  val packet = parse(input)
  packet.versionSum

def part2(input: String) =
  val packet = parse(input)
  packet.value
end part2

You might have noticed that we had to slightly modify the versionSum function to work with our new structure.
