-------------------------------------------------------------------------
-- Pattern-matching tutorial   --   (c) F. Fleutot, 2007.
--
-- This is the code corresponding to the pattern-matching tutorial included
-- in the Metalua manual. Its compile-time and run-time parts are put
-- in a single file, which is *Bad* practice.
--
-- If you want to see a more advanced example, implementing most of
-- the improvements suggested in the tutorial, look at the version in the
-- standard libraries directory.
-------------------------------------------------------------------------

-{ block:

   ----------------------------------------------------------------------
   -- Convert a tested term and a list of (pattern, statement) pairs
   -- into a pattern-matching AST.
   --
   -- [tested_term] is the AST of the expression being matched.
   -- [cases] is a list of pairs whose [1] is a pattern AST and whose [2]
   -- is a block AST (as built by the [match ... with ... end] syntax
   -- below). Returns a single statement AST: a [repeat ... until true]
   -- loop which tries each case in order and leaves the loop after the
   -- first case whose pattern matches.
   ----------------------------------------------------------------------
   function match_parser (tested_term, cases)

      -------------------------------------------------------------------
      -- Generate a variable from a number, and remember the biggest number
      -- it ever received into [max_n].
      -------------------------------------------------------------------
      local max_n = 0
      local function var (n) 
         if n > max_n then max_n = n end
         return `Id{ "v" .. n } 
      end
      
      -------------------------------------------------------------------
      -- Accumulate bits of code representing patterns
      -------------------------------------------------------------------
      local acc = { }
      local function accumulate (it) table.insert (acc, it) end

      -------------------------------------------------------------------
      -- Turn a pattern into a list of conditions and assignments,
      -- stored into [acc]. [n] is the depth of the subpattern into the
      -- toplevel pattern; the term tested at that depth is held in the
      -- variable [var(n)]. [pattern] is the AST of a pattern, or a
      -- subtree of that pattern when [n>1].
      -------------------------------------------------------------------
      local function parse_pattern (n, pattern)
         local v = var(n)
         if "Number" == pattern.tag or "String" == pattern.tag then
            -- Literal: the tested term must be equal to it.
            accumulate (+{ -{v} == -{pattern} })
         elseif "Id" == pattern.tag then
            -- Variable: bind it to the tested term.
            accumulate (+{stat: local -{pattern} = -{v} })
         elseif "Table" == pattern.tag then
            accumulate (+{ type( -{v} ) == "table" } )
            -- Positional sub-patterns get consecutive integer keys from 1;
            -- [Key] entries (e.g. the [tag] of `Op{...}) use their own key.
            local idx = 1
            for _, x in ipairs (pattern) do
               -- All siblings reuse var(n+1): each sibling's tests end up
               -- nested inside the previous sibling's ones by [collapse].
               local w = var(n+1)
               local key, sub_pattern
               if x.tag=="Key" 
               then key = x[1];           sub_pattern = x[2]
               else key = `Number{ idx }; sub_pattern = x; idx=idx+1 end
               accumulate (+{stat: (-{w}) = -{v} [-{key}] })
               accumulate (+{ -{w} ~= nil })
               parse_pattern (n+1, sub_pattern)
            end
         else error "Invalid pattern type" end
      end

      -------------------------------------------------------------------
      -- Turn a list of tests and assignments into [acc] into a
      -- single term of nested conditionals and assignments.
      -- [inner_term] is the AST of a term to be put into the innermost
      -- conditional, after all assignments. [n] is the index in [acc]
      -- of the term currently parsed.
      -- 
      -- This is a recursive function, which builds the inner part of
      -- the statement first, then surrounds it with nested 
      -- [if ... then ... end], [local ... = ...] and [let ... = ...]
      -- statements.
      -------------------------------------------------------------------
      local function collapse (n, inner_term)
         assert (not inner_term.tag, "collapse inner term must be a block")
         if n > #acc then return inner_term end
         local it = acc[n]
         local inside = collapse (n+1, inner_term)
         assert (not inside.tag, "collapse must produce a block")
         if "Op" == it.tag then 
            -- [it] is a test, put it in an [if ... then .. end] statement
            return +{block: if -{it} then -{inside} end }
         else 
            -- [it] is a statement, just add it at the result's beginning.
            assert ("Let" == it.tag or "Local" == it.tag)
            return { it, unpack (inside) }
         end
      end
      
      -------------------------------------------------------------------
      -- parse all [pattern -> block] pairs. Result goes in [body].
      -------------------------------------------------------------------
      local body = { }
      for _, case in ipairs (cases) do
         acc = { } -- reset the accumulator
         parse_pattern (1, case[1]) -- fill [acc] with conds and lets
         -- Every case body must leave the [repeat] loop on success, so
         -- that later cases are not tried. This includes empty bodies,
         -- which would otherwise fall through to the next case.
         local case_body = case[2]
         local last_stat = case_body[#case_body]
         if not last_stat or
            (last_stat.tag ~= "Break" and last_stat.tag ~= "Return") then
            table.insert (case_body, `Break)
         end
         local compiled_case = collapse (1, case_body)
         for _, x in ipairs (compiled_case) do table.insert (body, x) end
      end
      
      -- The tested term is always stored in v1. Allocate it before
      -- collecting the locals, so that it is declared [local] even when
      -- [cases] is empty (otherwise it would leak as a global).
      local tested_var = var(1)
      local local_vars = { }
      for i = 1, max_n do table.insert (local_vars, var(i))  end
      
      -------------------------------------------------------------------
      -- cases are put inside a [repeat until true], so that the [break]
      -- appended at the end of each case body will jump after the last
      -- case on success.
      -------------------------------------------------------------------
      local result = +{ stat: 
         repeat
            -{ `Local{ local_vars, { } } }
            (-{tested_var}) = -{tested_term}
            -{ body }
         until true }
      return result
   end
   
   ----------------------------------------------------------------------
   -- Sugar coating: add the syntactic extension that makes pattern
   -- matching pleasant to read and write:
   --   match <expr> with [|] <pattern> -> <block> { | <pattern> -> <block> } end
   ----------------------------------------------------------------------
   
   mlp.lexer:add{ "match", "with", "->" }
   -- "|" must end a case body, so that the next case can start.
   mlp.block.terminators:add "|"
   
   mlp.stat:add{
      "match", mlp.expr, "with",
      gg.optkeyword "|",
      gg.list{ 
         gg.sequence{ mlp.expr, "->", mlp.block },
         separators  = "|",
         terminators = "end" },
      "end",
      -- x[1] is the tested expr, x[2] the optional leading "|",
      -- x[3] the list of { pattern, block } cases.
      builder = |x| match_parser (x[1], x[3]) }  

} -- End of compile-time stage

-------------------------------------------------------------------------
-- Test: an evaluator for pre-parsed arithmetic expressions.
-------------------------------------------------------------------------

-- Binary operator implementations, indexed by the operator's AST tag.
binopfunc = {
   Add = |x, y| x + y;
   Sub = |x, y| x - y;
   Mul = |x, y| x * y;
   Div = |x, y| x / y }

-- Values of the free variables appearing in [test_term].
values = { a = 3, b = 5, c = 7 }

-------------------------------------------------------------------------
-- We rely on the metalua quote operator to build the term, which will
-- look like `Op{ `Add, `Op{ `Add, `Op{ `Add, `Number 1, `Op{ `Mul, ... }}}}
-- That's a bit hackish, but it will do for a quick & dirty test case.
-------------------------------------------------------------------------
test_term = +{ 1 + 2*a + 4*b + 6*c }

function eval(t)
   -- Recursively evaluate an arithmetic AST: numbers evaluate to
   -- themselves, identifiers are looked up in [values], and binary
   -- operators are dispatched through [binopfunc] by their tag.
   -- Anything else is reported, and the function then returns nothing.
   match t with
   | `Number{ n }        -> return n
   | `Id{ name }         -> return values[name]
   | `Op{ op, lhs, rhs } -> return binopfunc[op.tag](eval(lhs), eval(rhs))
   | _                   -> print("Can't evaluate")
   end
end

-- 1 + 2*3 + 4*5 + 6*7 == 69
local evaluated = eval(test_term)
assert(evaluated == 69, "Test failed")
print("Test passed")