The newest version of Scala, 2.10, will allow the user to execute code at compile time with macros, which are just normal Scala functions. You can get an overview of the possibilities in the official Scala Macro guide. Right now, Scala 2.10 is still under development, so you will have to use a Release Candidate to follow along.
First of all, a good system of measurements should make it easy to define values with units in code - no complicated trickery should be required. That's easy to achieve with a small utility macro:
val a = u(42, "m/s")
It simply parses the unit string and turns it into the correct types. The call above will return the number 42 with the desired unit:
val a = new MeasuredNumber[Divide[Meter, Second]](42)
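The post does not show how MeasuredNumber and the unit types are declared. A minimal sketch, assuming empty marker traits for the units and a Double payload (only MeasuredNumber, Meter, Second, Times and Divide come from the post; everything else is hypothetical), illustrates the idea: the unit exists only as a type parameter, so the compiler can tell values with different units apart while nothing extra is stored at runtime.

```scala
// Hypothetical marker types: never instantiated, only used as type arguments.
trait Meter
trait Second
trait Times[A, B]
trait Divide[A, B]

// The unit U is a phantom type parameter: it exists only at compile time.
class MeasuredNumber[U](val value: Double) {
  // Only type-checks when both operands carry exactly the same unit type.
  def +(other: MeasuredNumber[U]): MeasuredNumber[U] =
    new MeasuredNumber[U](value + other.value)

  override def toString: String = s"MeasuredNumber($value)"
}

val speed = new MeasuredNumber[Divide[Meter, Second]](42)
val total = speed + new MeasuredNumber[Divide[Meter, Second]](8)
// speed + new MeasuredNumber[Meter](1)  // rejected by the compiler: unit mismatch
```

The + method here is a plain method that only accepts identical unit types; it is not the post's add macro.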
To actually implement this utility macro u(n, unit), we follow a few simple steps:
- precompute all values we use (in this case, n) and store a reference to them
- parse the unit string (the unit argument), simplify it, and turn it into a tree of type expressions
- use the simplified unit and the precomputed value to build a Tree (the data structure that represents Scala code at compile time) that creates the correct instance of MeasuredNumber
Precomputation
This is the simplest step: we turn each expression into a new value definition (val a = n) and return a reference to it. We just need to make sure to use unique identifiers (that's the c.fresh() call) and to actually include the value definitions in the tree the macro returns.
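import scala.collection.mutable.ListBuffer
// Inside the macro implementation, where c: Context is in scope
import c.universe._
// Value definitions collected here are prepended to the tree the macro returns
val evals = ListBuffer[ValDef]()
// Binds value to a fresh val of type tpe and returns a reference to that val
def _precompute(value: Tree, tpe: Type): Ident = {
  val freshName = newTermName(c.fresh("eval$"))  // unique identifier
  evals += ValDef(Modifiers(), freshName, TypeTree(tpe), value)
  Ident(freshName)
}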
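Binding an argument to a fresh val before using it is a common macro technique for making sure the argument expression is evaluated exactly once, no matter how often the generated code refers to it. The plain-Scala sketch below involves no macros; argument, naiveExpansion and precomputedExpansion are hypothetical names. It compares splicing an expression in twice with precomputing it the way _precompute does.

```scala
// Counts how often the "argument expression" is evaluated.
var evaluations = 0
def argument(): Double = { evaluations += 1; 42.0 }

// Naive expansion: the argument expression is spliced in at every use site.
def naiveExpansion(): Double = argument() * argument()

// Precomputed expansion: evaluate once into a val, then reuse the reference,
// mirroring the ValDef + Ident pair that _precompute produces.
def precomputedExpansion(): Double = {
  val eval1 = argument()
  eval1 * eval1
}
```

Both versions compute the same result, but the naive one runs any side effects of the argument twice.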
Unit String Parsing
This step is more complicated. First of all, we parse the provided unit string (like "m/s") with parser combinators to get a recursive structure of case classes we can handle more easily:
import scala.reflect.macros.Context
import scala.util.parsing.combinator._

object UnitParser extends JavaTokenParsers {
  // Parse, simplify and reduce the unit string, then convert it into a type tree;
  // invalid input aborts compilation with an error at the macro call site.
  def parse(in: String, c: Context): c.universe.Tree = parseAll(unit, in) match {
    case Success(r, _) => reduce(r.simplify).toTree(c)
    case _ => c.abort(c.enclosingPosition, "unknown units and/or invalid format")
  }

  // Turn a non-empty list of unit names into a right-nested chain of Times
  def toTimes(units: List[String]): Typename = units match {
    case x :: Nil => SimpleType(x)
    case x :: xs => Times(SimpleType(x), toTimes(xs))
  }

  /* either a*b*c or a*b/c*d */
  def unit: Parser[Typename] = (
    rep1sep(unitname, "*") ~ "/" ~ rep1sep(unitname, "*") ^^ {
      case times ~ "/" ~ divide => Divide(toTimes(times), toTimes(divide))
    }
    | rep1sep(unitname, "*") ^^ { units => toTimes(units) }
  )

  // Map unit abbreviations to the names of the unit types
  def unitname: Parser[String] = (
    "m" ^^ { _ => "Meter" }
    | "s" ^^ { _ => "Second" }
    | "1"
  )
}
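The combinator library is a separate module in newer Scala versions, so here is a dependency-free sketch of the same grammar (either a*b*c or a*b/c*d) that swaps parser combinators for plain string splitting. The case classes are simplified stand-ins for the post's (no type parameters, no Context), parseUnit and factors are hypothetical names, and None plays the role of c.abort.

```scala
// Simplified stand-ins for the post's case classes (no type parameters, no Context).
sealed trait Typename
case class SimpleType(name: String) extends Typename
case class Times(a: Typename, b: Typename) extends Typename
case class Divide(a: Typename, b: Typename) extends Typename

// Same mapping as unitname: "m" -> Meter, "s" -> Second, "1" is kept as-is.
val unitNames: Map[String, String] = Map("m" -> "Meter", "s" -> "Second", "1" -> "1")

// Right-nested chain of Times, as in the post's toTimes (callers never pass Nil).
def toTimes(units: List[String]): Typename = units match {
  case x :: Nil => SimpleType(x)
  case x :: xs  => Times(SimpleType(x), toTimes(xs))
  case Nil      => throw new IllegalArgumentException("empty unit list")
}

// "a*b*c" -> Some(list of mapped names), or None if any factor is unknown or empty.
def factors(s: String): Option[List[String]] = {
  val mapped = s.split("\\*", -1).toList.map(f => unitNames.get(f.trim))
  if (mapped.forall(_.isDefined)) Some(mapped.flatten) else None
}

// Either a*b*c or a*b/c*d; anything else (e.g. two slashes) is rejected with None.
def parseUnit(in: String): Option[Typename] = in.split("/", -1).toList match {
  case List(num)      => factors(num).map(toTimes)
  case List(num, den) =>
    for (n <- factors(num); d <- factors(den)) yield Divide(toTimes(n), toTimes(d))
  case _              => None
}
```

Unlike the combinator version, this sketch gives no detailed error position; it only reports whether the string matched.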
Note: This design is directly based on the types used to actually represent the units in the compiler (Times, Divide and SimpleType); I might move to a simpler design in the future.
You'll notice the call to reduce(r.simplify).toTree(c) in the parse function. The simplify instance method eliminates all instances of Times and Divide and flattens the whole tree into a list whose elements are either SimpleType or Inverted (the latter for units in the denominator, as in "1/s"). Then reduce(l: List[Typename]) cancels any redundant units, so "m*s/s" turns into "m", and turns the list back into a tree of Times classes, with at most one Divide class at the outermost level. Finally, toTree(c: Context) converts the tree from my custom case classes into a type tree the compiler understands.
As an example, here's the implementation of simplify and toTree for the Divide case class. Note that we avoid redundant instances of Inverted, since Inverted(Inverted(n)) is equivalent to just n:
case class Divide[T, U](param1: Typename, param2: Typename) extends Typename {
  override def toString = s"($param1) / ($param2)"

  // Numerator factors are kept as they are; denominator factors are inverted,
  // unwrapping factors that are already inverted, since Inverted(Inverted(n)) == n
  override def simplify = param1.simplify ++ param2.simplify.map {
    case Inverted(t) => t
    case t => Inverted(t)
  }

  // Build the type tree Divide[param1, param2]
  override def toTree(c: Context): c.universe.Tree = c.universe.AppliedTypeTree(
    c.universe.Ident(c.universe.newTypeName("Divide")),
    List(param1.toTree(c), param2.toTree(c)))
}
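The post describes reduce but does not show it. The following standalone sketch models the whole simplify, reduce and toTree pipeline on plain case classes, assuming that "1" is dimensionless and can be dropped, and rendering the result as a type string instead of a compiler Tree. Only the case class names come from the post; invert, baseName, the exponent-counting strategy inside reduce, and render are hypothetical.

```scala
// Simplified, hypothetical stand-ins for the post's case classes.
sealed trait Typename
case class SimpleType(name: String) extends Typename
case class Inverted(t: Typename) extends Typename
case class Times(a: Typename, b: Typename) extends Typename
case class Divide(a: Typename, b: Typename) extends Typename

// Inverting twice cancels out, so Inverted(Inverted(x)) never appears.
def invert(t: Typename): Typename = t match {
  case Inverted(u) => u
  case other       => Inverted(other)
}

// Flatten a unit tree into SimpleType / Inverted(SimpleType) factors.
// Assumption: "1" is dimensionless and simply dropped.
def simplify(t: Typename): List[Typename] = t match {
  case SimpleType("1") => Nil
  case s: SimpleType   => List(s)
  case Inverted(u)     => simplify(u).map(invert)
  case Times(a, b)     => simplify(a) ++ simplify(b)
  case Divide(a, b)    => simplify(a) ++ simplify(b).map(invert)
}

def baseName(t: Typename): String = t match {
  case SimpleType(n)           => n
  case Inverted(SimpleType(n)) => n
  case other                   => throw new IllegalArgumentException(s"not simplified: $other")
}

// Right-nested chain of Times; an empty list becomes the dimensionless "1".
def toTimes(ts: List[Typename]): Typename = ts match {
  case x :: Nil => x
  case x :: xs  => Times(x, toTimes(xs))
  case Nil      => SimpleType("1")
}

// Sum the exponents per unit (m*s/s => m) and put at most one Divide at the top.
def reduce(units: List[Typename]): Typename = {
  val exponents = units.map(baseName).distinct.map { name =>
    name -> units.filter(u => baseName(u) == name).map {
      case Inverted(_) => -1
      case _           => 1
    }.sum
  }
  val num = exponents.flatMap { case (n, e) => List.fill(math.max(e, 0))(SimpleType(n)) }
  val den = exponents.flatMap { case (n, e) => List.fill(math.max(-e, 0))(SimpleType(n)) }
  if (den.isEmpty) toTimes(num) else Divide(toTimes(num), toTimes(den))
}

// String form of the type that toTree would describe, e.g. Divide[Meter, Second].
def render(t: Typename): String = t match {
  case SimpleType(n) => n
  case Inverted(u)   => s"Inverted[${render(u)}]" // does not survive reduce
  case Times(a, b)   => s"Times[${render(a)}, ${render(b)}]"
  case Divide(a, b)  => s"Divide[${render(a)}, ${render(b)}]"
}
```

For example, reduce(simplify(Divide(Times(meter, second), second))) yields just the Meter type, matching the "m*s/s" case described above.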
Building a Tree for the Compiler
Finally, we need to build a Tree that creates the correct instance of MeasuredNumber. Now that we have the correct type, this is fairly straightforward. parsedUnit is the type tree obtained in the step above, and nID is the reference to the value we precomputed in the first step.
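// new MeasuredNumber[<parsedUnit>](<precomputed value>)
val stats = Apply(Select(New(AppliedTypeTree(
    Ident(newTypeName("MeasuredNumber")),
    List(parsedUnit)
  )), nme.CONSTRUCTOR),
  List(Ident(nID.name)))  // fresh Ident referring to the precomputed val
// Prepend the precomputed value definitions and wrap everything in one block
c.Expr(Block(evals.toList, stats))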
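To see what the generated tree corresponds to, it can help to write the intended expansion by hand. Assuming marker traits for the units and a Double-valued MeasuredNumber (hypothetical declarations; the post does not show them), a call such as u(40 + 2, "m/s") should produce code roughly equivalent to the block below, where eval1 stands in for whatever name c.fresh generates.

```scala
// Hypothetical declarations; the post does not show them.
trait Meter
trait Second
trait Divide[A, B]
class MeasuredNumber[U](val value: Double)

// Hand-written equivalent of Block(evals.toList, stats) for u(40 + 2, "m/s").
val a = {
  // ValDef produced by _precompute (eval1 stands in for the fresh name)
  val eval1: Int = 40 + 2
  // Apply(Select(New(AppliedTypeTree(MeasuredNumber, List(parsedUnit))), CONSTRUCTOR), List(Ident(eval1)))
  new MeasuredNumber[Divide[Meter, Second]](eval1)
}
```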
If you are new to macros, it can be very helpful to inspect the code generated by the compiler. There are two main functions that will help you here:
- reify takes an expression and generates the correctly qualified Tree for it.
- showRaw returns a string representing the raw structure of a given Tree / Expr. This is especially helpful in combination with reify, if you have a general idea of the code you want but are unsure about the exact classes to use.
Code
You can find the complete code for this post on GitHub, including two more macros that allow you to add and multiply MeasuredNumbers (with type errors in case of incompatible units!). I'll talk about those next time, but for now, here's the code: Units of Measure