Posted on Wed 19 December 2012

A System of Measurements in Scala

The newest version of Scala, 2.10, will allow the user to execute code at compile time with macros, which are just normal Scala functions. You can get an overview of the possibilities in the official Scala Macro guide. Right now, Scala 2.10 is still under development, so you will have to use a Release Candidate to follow along.

First of all, a good system of measurements should make it easy to define values with units in code - no complicated trickery should be required. That's easy to achieve with a small utility macro:

val a = u(42, "m/s")

It simply parses the unit string and turns it into the correct types. The call above will return the number 42 with the desired unit:

val a = new MeasuredNumber[Divide[Meter, Second]](42)
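To make the idea of a unit-tagged number concrete, here is a minimal sketch of how such types could be modeled. The post does not show these definitions, so everything below is an assumption for illustration; the real definitions in the linked repository may differ.

```scala
// Hypothetical definitions, not taken from the original project.
// Units are "phantom" types: they carry no data and exist only for the type checker.
trait Meter
trait Second
trait Times[A, B]   // product of two units, e.g. m*s
trait Divide[A, B]  // quotient of two units, e.g. m/s

// A number tagged with a unit type U (invariant in U on purpose).
class MeasuredNumber[U](val n: Int) {
  override def toString = s"MeasuredNumber($n)"
}

object PhantomUnitsDemo {
  val speed: MeasuredNumber[Divide[Meter, Second]] =
    new MeasuredNumber[Divide[Meter, Second]](42)

  // The next line would be rejected by the compiler, because the units differ:
  // val wrong: MeasuredNumber[Meter] = speed
}
```

Because the unit lives only in the type parameter, mixing up units becomes a compile-time type error rather than a runtime bug.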

To actually implement this utility macro u(n, unit), we follow a few simple steps:

  • precompute all values we use (in this case, n) and store a reference to them
  • parse the type string unit and simplify it, then turn it into a tree of type expressions
  • use the simplified unit and the precomputed value to build a Tree (the data structure representing Scala code at compile time) that creates the correct instance of MeasuredNumber

Precomputation

This is the simplest step: we turn each expression into a new value definition (val a = n) and return a reference to it. We just need to make sure to use unique identifiers (that's the c.fresh() call) and to actually include the value definitions in the tree the macro returns.

val evals = ListBuffer[ValDef]()

def _precompute(value: Tree, tpe: Type): Ident = {
  val freshName = newTermName(c.fresh("eval$"))
  evals += ValDef(Modifiers(), freshName, TypeTree(tpe), value)
  Ident(freshName)
}

Unit String Parsing

This step is more complicated. First of all, we parse the provided unit string (like "m/s") with parser combinators to get a recursive structure of case classes that we can handle more easily:

import scala.util.parsing.combinator._

object UnitParser extends JavaTokenParsers {
  def parse(in: String, c: Context) = parseAll(unit, in) match {
    case Success(r, _) => reduce(r.simplify).toTree(c)
    case _ => c.abort(c.enclosingPosition, "unknown units and/or invalid format")
  }

  def toTimes(units: List[String]): Typename = units match {
    case x :: Nil => SimpleType(x)
    case x :: y :: Nil => Times(SimpleType(x), SimpleType(y))
    case x :: xs => Times(SimpleType(x), toTimes(xs))
  }

  /* either a*b*c or a*b/c*d */
  def unit: Parser[Typename] = (
      rep1sep(unitname, "*")~"/"~rep1sep(unitname, "*") ^^ { case times~"/"~divide =>
        Divide(toTimes(times), toTimes(divide)) }
    | rep1sep(unitname, "*") ^^ { case units => toTimes(units)}
    )

  def unitname: Parser[String] = (
      "m" ^^ { _ => "Meter" }
    | "s" ^^ { _ => "Second" }
    | "1"
    )
}

Note: This design is directly based on the types used to actually represent the units in the compiler (Times, Divide and SimpleType). I might move to a simpler design in the future.

You'll notice the call to reduce(r.simplify).toTree(c) in the parse function. The simplify instance method eliminates all instances of Times and Divide and turns the whole tree into a long list of either SimpleType or Inverted (for "1/s"). Then, reduce(l: List[Typename]) will eliminate any and all redundant units, so "m*s/s" will turn into "m", and turn the list back into a tree of Times classes, with at most one Divide class at the outermost level. Finally, toTree(c: Context) converts the tree from my custom case classes into a type tree understandable by the compiler.
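The simplify and reduce steps can be sketched in plain Scala without any macro machinery. The names mirror the ones used in this post, but the bodies below are assumptions written for illustration (including the choice of SimpleType("1") for a dimensionless result), not the original implementation.

```scala
// Hypothetical re-implementation for illustration only.
sealed trait Typename {
  // Flatten the tree into a list of SimpleType / Inverted factors.
  def simplify: List[Typename]
}
case class SimpleType(name: String) extends Typename {
  def simplify: List[Typename] = List(this)
}
case class Inverted(t: Typename) extends Typename {
  def simplify: List[Typename] = List(this)
}
case class Times(a: Typename, b: Typename) extends Typename {
  def simplify: List[Typename] = a.simplify ++ b.simplify
}
case class Divide(a: Typename, b: Typename) extends Typename {
  // Invert every factor of the denominator; Inverted(Inverted(n)) collapses to n.
  def simplify: List[Typename] = a.simplify ++ b.simplify.map {
    case Inverted(t) => t
    case t           => Inverted(t)
  }
}

object UnitReduce {
  // Cancel matching numerator/denominator factors, then rebuild a tree with
  // at most one Divide at the outermost level.
  def reduce(factors: List[Typename]): Typename = {
    val nums = factors.collect { case s: SimpleType => s }
    val dens = factors.collect { case Inverted(s: SimpleType) => s }
    val numLeft = nums diff dens // multiset difference: one factor cancels one factor
    val denLeft = dens diff nums
    def product(ts: List[SimpleType]): Typename =
      if (ts.isEmpty) SimpleType("1") // assumption: "1" stands for "no unit"
      else ts.reduceRight[Typename](Times(_, _))
    if (denLeft.isEmpty) product(numLeft)
    else Divide(product(numLeft), product(denLeft))
  }

  // "m*s/s" as the parser would represent it:
  val parsed: Typename =
    Divide(Times(SimpleType("Meter"), SimpleType("Second")), SimpleType("Second"))
  val reduced: Typename = reduce(parsed.simplify) // SimpleType(Meter)
}
```

Working on a flat list makes cancellation a simple multiset difference; the nested Times structure is only rebuilt at the very end.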

As an example, here's the implementation of simplify and toTree for the Divide case class. Note that we avoid redundant instances of Inverted, since Inverted(Inverted(n)) is equivalent to just n:

case class Divide[T, U](param1: Typename, param2: Typename) extends Typename {
  override def toString = s"($param1) / ($param2)"
  override def simplify = {
    param1.simplify ++ param2.simplify.map(t => t match {
      case Inverted(t) => t
      case t => Inverted(t)
      })
  }
  override def toTree(c: Context): c.universe.Tree = c.universe.AppliedTypeTree(
          c.universe.Ident(c.universe.newTypeName("Divide")),
          List(param1.toTree(c), param2.toTree(c)))
}

Building a Tree for the Compiler

Finally, we need to build a Tree that will create the correct instance of MeasuredNumber. Now that we have the correct type, this is fairly straightforward. Here, parsedUnit is the type tree obtained in the step above, and nID is the reference to the value we precomputed in the first step.

val stats = Apply(Select(New(AppliedTypeTree(
      Ident(newTypeName("MeasuredNumber")),
      List( parsedUnit )
      )), nme.CONSTRUCTOR),
    List(Ident(newTermName(nID.toString))))

c.Expr(Block(evals.toList, stats))

If you are new to macros, it can be very helpful to inspect the code generated by the compiler. There are two main functions that will help you here:

  • reify takes an expression and will generate the correctly qualified Tree for it.
  • showRaw returns a string representing a given Tree / Expr. This can be very helpful in combination with reify, if you have a general concept of the code you want but are unsure about the exact classes to use.
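As a small sketch of how the two functions combine, the snippet below uses the runtime reflection universe instead of a macro context, so it can be tried outside of a macro. It assumes Scala 2 with scala-reflect on the classpath; inside a macro you would use the equivalents from c.universe. The object and value names are made up for this example.

```scala
// Assumes Scala 2 with scala-reflect available (the REPL provides it).
import scala.reflect.runtime.universe._

object InspectTrees {
  // reify turns ordinary Scala code into an Expr wrapping its Tree.
  val expr = reify { val x = 42; x }

  // showRaw renders the underlying case-class structure of that Tree.
  val raw: String = showRaw(expr.tree)

  def main(args: Array[String]): Unit = println(raw)
}
```

The printed string names the tree classes involved (such as Block, ValDef and Literal), which you can then construct by hand in your own macro.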

Code

You can find the complete code for this post on GitHub, including two more macros that allow you to add and multiply MeasuredNumbers (with type errors in case of incompatible units!). I'll talk about those next time, but for now, here's the code: Units of Measure

Tags: programming, scala

© Julian Schrittwieser. Built using Pelican. Theme by Giulio Fidente on github.