-------------------------------------------------------------------------
-- Pattern-matching tutorial
-- (c) F. Fleutot, 2007.
--
-- This is the code corresponding to the pattern-matching tutorial
-- included in the Metalua manual. The compile-time and run-time parts
-- of it are put in a single file, which is *Bad*.
--
-- If you want to see a more advanced example, implementing most of
-- the improvements suggested in the tutorial, look at the version in
-- the standard libraries directory.
-------------------------------------------------------------------------

-{ block:

   ----------------------------------------------------------------------
   -- Convert a tested term and a list of (pattern, statement) pairs
   -- into a pattern-matching AST.
   --
   -- [tested_term] is the AST of the expression being matched.
   -- [cases] is a list of pairs { pattern_ast, block_ast }, as built by
   -- the [match ... with ... end] parser registered below.
   -- Returns a single statement AST: a [repeat ... until true] loop that
   -- tries each case in order and runs the block of the first match.
   ----------------------------------------------------------------------
   function match_parser (tested_term, cases)

      -------------------------------------------------------------------
      -- Generate a variable from a number, and remember the biggest
      -- number it ever received into [max_n].
      -------------------------------------------------------------------
      local max_n = 0
      local function var (n)
         if n > max_n then max_n = n end
         return `Id{ "v" .. n }
      end

      -------------------------------------------------------------------
      -- Accumulate bits of code representing patterns.
      -------------------------------------------------------------------
      local acc = { }
      local function accumulate (it)
         table.insert (acc, it)
      end

      -------------------------------------------------------------------
      -- Turn a pattern into a list of conditions and assignments,
      -- stored into [acc]. [n] is the depth of the subpattern within
      -- the toplevel pattern (1 for the toplevel pattern itself); the
      -- value tested at that depth is held in variable [var(n)].
      -- [pattern] is the AST of a pattern, or a subtree of that pattern
      -- when [n>1].
      -------------------------------------------------------------------
      local function parse_pattern (n, pattern)
         local v = var(n)
         if "Number" == pattern.tag or "String" == pattern.tag then
            -- Literal: the tested value must be equal to it.
            accumulate (+{ -{v} == -{pattern} })
         elseif "Id" == pattern.tag then
            -- Variable: bind it to the tested value.
            accumulate (+{stat: local -{pattern} = -{v} })
         elseif "Table" == pattern.tag then
            -- Table: the tested value must be a table, and each field
            -- must match its subpattern one level deeper. [var(n+1)] is
            -- reused for every field: each subpattern is fully processed
            -- before the next field overwrites that variable.
            accumulate (+{ type( -{v} ) == "table" } )
            local idx = 1
            for _, x in ipairs (pattern) do
               local w = var(n+1)
               local key, sub_pattern
               if x.tag=="Key" then
                  -- Explicit [key = subpattern] field.
                  key = x[1]; sub_pattern = x[2]
               else
                  -- Positional field: implicit keys 1, 2, 3...
                  key = `Number{ idx }; sub_pattern = x; idx=idx+1
               end
               accumulate (+{stat: (-{w}) = -{v} [-{key}] })
               accumulate (+{ -{w} ~= nil })
               parse_pattern (n+1, sub_pattern)
            end
         else
            error "Invalid pattern type"
         end
      end

      -------------------------------------------------------------------
      -- Turn the list of tests and assignments in [acc] into a
      -- single term of nested conditionals and assignments.
      -- [inner_term] is the AST of a block to be put into the innermost
      -- conditional, after all assignments. [n] is the index in [acc]
      -- of the term currently processed.
      --
      -- This is a recursive function, which builds the inner part of
      -- the statement first, then surrounds it with nested
      -- [if ... then ... end], [local ... = ...] and [let ... = ...]
      -- statements.
      -------------------------------------------------------------------
      local function collapse (n, inner_term)
         assert (not inner_term.tag, "collapse inner term must be a block")
         if n > #acc then return inner_term end
         local it = acc[n]
         local inside = collapse (n+1, inner_term)
         assert (not inside.tag, "collapse must produce a block")
         if "Op" == it.tag then
            -- [it] is a test, put it in an [if ... then .. end] statement.
            return +{block: if -{it} then -{inside} end }
         else
            -- [it] is a statement, just add it at the result's beginning.
            assert ("Let" == it.tag or "Local" == it.tag)
            return { it, unpack (inside) }
         end
      end

      -------------------------------------------------------------------
      -- Parse all [pattern -> block] pairs. Result goes in [body].
      -------------------------------------------------------------------
      -- Register v1 up front, so that it is declared local even when
      -- [cases] is empty (otherwise the assignment below would create a
      -- global v1).
      local tested_var = var(1)
      local body = { }
      for _, case in ipairs (cases) do
         acc = { } -- reset the accumulator
         parse_pattern (1, case[1]) -- fill [acc] with conds and lets
         local case_block = case[2]
         local last_stat = case_block[#case_block]
         -- Append a [break] so that a successful case skips the others.
         -- An empty block needs one too. A block already ending with
         -- [break] or [return] must not get one, because Lua forbids any
         -- statement after those.
         if not last_stat or
            (last_stat.tag ~= "Break" and last_stat.tag ~= "Return") then
            table.insert (case_block, `Break)
         end
         local compiled_case = collapse (1, case_block)
         for _, x in ipairs (compiled_case) do table.insert (body, x) end
      end

      local local_vars = { }
      for i = 1, max_n do table.insert (local_vars, var(i)) end

      -------------------------------------------------------------------
      -- Cases are put inside a [repeat ... until true], so that the
      -- [break] appended to each case's block jumps past the last case
      -- on success.
      -------------------------------------------------------------------
      local result = +{ stat:
         repeat
            -{ `Local{ local_vars, { } } }
            (-{tested_var}) = -{tested_term}
            -{ body }
         until true }
      return result
   end

   ----------------------------------------------------------------------
   -- Sugar coating: add the syntactic extension that makes pattern
   -- matching pleasant to read and write.
   ----------------------------------------------------------------------

   mlp.lexer:add{ "match", "with", "->" }
   mlp.block.terminators:add "|"

   mlp.stat:add{ "match", mlp.expr, "with", gg.optkeyword "|",
                 gg.list{ gg.sequence{ mlp.expr, "->", mlp.block },
                          separators  = "|",
                          terminators = "end" },
                 "end",
                 builder = |x| match_parser (x[1], x[3]) }
}

-- End of compile-time stage

-------------------------------------------------------------------------
-- Test: an evaluator for pre-parsed arithmetic expressions.
-------------------------------------------------------------------------

-- Binary operator implementations, indexed by the tag of the operator
-- found as the first child of an `Op node.
local binopfunc = {
   Add = |x, y| x+y;
   Sub = |x, y| x-y;
   Mul = |x, y| x*y;
   Div = |x, y| x/y }

-- Values bound to the free variables appearing in the test term.
local values = { a=3; b=5; c=7 }

-------------------------------------------------------------------------
-- We rely on the metalua quote operator to build the term, which will
-- look like `Op{ `Add, `Op{ `Add, `Op{ `Add, `Number 1, `Op{ `Mul, ... }}}}
-- That's a bit hackish, but that will do for a quick&dirty test case.
-------------------------------------------------------------------------
local test_term = +{ 1 + 2*a + 4*b + 6*c }

-------------------------------------------------------------------------
-- Evaluate an arithmetic AST:
-- * `Op nodes are evaluated through [binopfunc], keyed by the operator
--   tag;
-- * `Number nodes evaluate to their value;
-- * `Id nodes are looked up in [values].
-- Anything else prints a diagnostic and the function returns nothing.
-- Declared local so that the recursive calls resolve to it without
-- creating a global.
-------------------------------------------------------------------------
local function eval(t)
   match t with
   | `Op{ op, a, b } -> return binopfunc[op.tag](eval(a), eval(b))
   | `Number{ n }    -> return n
   | `Id{ v }        -> return values[v]
   | _               -> print("Can't evaluate")
   end
end

-- 1 + 2*3 + 4*5 + 6*7 = 69
assert (eval(test_term) == 69, "Test failed")
print "Test passed"