----------------------------------------------------------------------
-- Metalua tutorial.
--
-- Summary: Runtime type-checking for Metalua.
--
----------------------------------------------------------------------
--
-- Copyright (c) 2006-2007, Fabien Fleutot <metalua@gmail.com>.
--
-- This software is released under the MIT Licence, see licence.txt
-- for details.
--
--------------------------------------------------------------------------------
-- 
-- This extension adds dynamic type-checking to Lua.
-- Type-checking functions are stored in a table called [typecheck]; such
-- functions take the value to check as input, and either return normally
-- or raise a typing error by calling [error (msg)].
-- 
-- Usage
-- =====
-- 
-- The user can create type-checking functions and register them in the
-- table; checkers are provided for [number], [string], [boolean] and
-- [table], along with the parametric types [list], [inter], [union] and
-- [vector].
--
-- Parametric types
-- ================
--
-- It is not uncommon for a type to take another type as a parameter, e.g.
-- [list (integer)]. Therefore, type declaration should be understood in
-- a dedicated environment, where variables are taken from the [typecheck]
-- table. This can be done by:
-- - Putting all type-checking operations in anonymous function bodies
-- - Changing these anonymous function environments with [setfenv()]
--
-- However, creating and calling such functions, with upvalues, for every
-- typecheck seems grossly inefficient, so instead we parse type-checking
-- code at compile time, to replace identifiers by indexes in the table.
-- For instance, [list (integer)] is transformed into:
-- [typecheck.list (typecheck.integer)]. This way, most of the overhead
-- is taken at compile time rather than runtime.
--
-- Syntax
-- ======
--
-- Type-checking is performed by operator '::'. It can be used:
-- - on function definition parameters: "function f (x :: number)"
-- - to give function result type: "function f (x :: list(string)) :: integer"
-- - on arbitrary expressions: "x = y :: integer + 2"
--
-------------------------------------------------------------------------------
-- TODO: 
--
-- - Local variable should be typechecked as well, i.e. their content should
--   be typechecked after each assignment operation. This requires AST walking:
--   cf walk.lua.
--
-- - It must be possible to turn off type-checking at compile time. Some finer
--   grained mechanism (e.g. disable internal type-checking while keeping it
--   for API functions) might prove useful.
--
--------------------------------------------------------------------------------

-{ block:

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--
-- Meta-stage: code generators and syntax extensions 
--
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

--------------------------------------------------------------------------
-- Rewrite, in place, every identifier found in the type expression [ast]
-- into a lookup in the [typecheck] table: [foo] becomes [typecheck.foo].
-- Type declarations thus get their own namespace without having to
-- change function environments at runtime.
--
-- Non-table values (leaf payloads such as strings) are ignored; every
-- other node is walked recursively through its array part.
--------------------------------------------------------------------------
local function index_id_transformer (ast)
   if type (ast) ~= "table" then return end
   if ast.tag ~= "Id" then
      -- Not an identifier: rewrite the children, leaving this node alone.
      table.iforeach (index_id_transformer, ast)
      return
   end
   -- Overwrite the node's contents in place ([<-]), so that the parent
   -- node, which still references this table, now sees the index.
   local name = ast[1]
   ast <- `Index{ `Id "typecheck", `String{ name } }
end

--------------------------------------------------------------------------
-- Insert a type-checking function call on [term] before returning
-- [term]'s value. Only legal in an expression context.
--
-- [term]     expression AST whose value must be checked.
-- [ast_type] type expression AST, e.g. `Call{ `Id "list", `Id "number" }.
--
-- Returns a `Stat expression which stores [term] in a fresh local (so
-- that [term] is evaluated exactly once), calls the checker on that
-- local, then yields the local's value.
--
-- [ast_type] is deep-copied before its identifiers are rewritten into
-- [typecheck] indexes. The same type AST can reach this function several
-- times, for instance once per [return] of a typed function. Rewriting it
-- in place each time would turn [number] into [typecheck.number], then
-- into [typecheck.typecheck.number], and so on. Copying also keeps the
-- caller's AST untouched.
--------------------------------------------------------------------------
local function insert_test (term, ast_type)
   local checker = table.deep_copy (ast_type)
   index_id_transformer (checker)
   -- Always go through a temporary: even for atomic terms, this avoids
   -- sharing the [term] node between two places in the generated AST.
   local var = mlp.gensym()
   return `Stat{ { `Local{ {var}, {term} }; `Call{ checker, var } }, var }
end

--------------------------------------------------------------------------
-- Walk the AST [ast_val] and wrap the value of every [return] statement
-- with a check against the type AST [ast_type].
-- Returns that belong to nested function literals are left alone, since
-- they are not returns of the typed function. Raises a compile-time error
-- on any [return] that does not return exactly one value.
--------------------------------------------------------------------------
local function typed_return_transformer (ast_val, ast_type)
   if type (ast_val) ~= "table" then return end

   local tag = ast_val.tag
   if tag == "Function" then
      -- Nested function: its returns are typed (or not) on their own.
      return
   elseif tag == "Return" then
      if #ast_val ~= 1 then
         error "Only single-value returns can be typechecked"
      end
      ast_val[1] = insert_test (ast_val[1], ast_type)
   else
      for _, child in ipairs (ast_val) do
         typed_return_transformer (child, ast_type)
      end
   end
end

--------------------------------------------------------------------------
-- Parse the typechecking tests in a function definition, and add the
-- corresponding tests at the beginning of the function's body.
--
-- [x] is the raw result of the [mlp.func_val] sequence:
--   x[1]: list of { param_id, param_type } pairs, where param_type is
--         falsy when the parameter has no "::" annotation;
--   x[2]: return type AST, falsy when absent;
--   x[3]: the function body block, which is modified in place.
--
-- Returns a plain `Function{ params, body } AST. The body starts with one
-- checker call per typed parameter, in parameter order. When a return type
-- is given, the body must end with a [return], and every [return] of this
-- function (not of nested functions) gets a type check. A compile-time
-- error is raised if that final [return] is missing, including when the
-- body is completely empty.
--------------------------------------------------------------------------
local function func_val_builder (x)
   local typed_params, opt_type, body = x[1], x[2], x[3]
   local untyped_params, checks = { }, { }

   -- Collect the bare parameter names, and one checker call statement
   -- per annotated parameter.
   for i, y in ipairs (typed_params) do
      local param, param_type = y[1], y[2]
      untyped_params[i] = param
      if param_type then
         index_id_transformer (param_type)
         checks[#checks+1] = `Call{ param_type, param }
      end
   end

   -- Prepend the checks to the body, preserving parameter order.
   for i = #checks, 1, -1 do table.insert (body, 1, checks[i]) end

   -- Handle the returned type.
   if opt_type then
      -- An empty body has no last statement: [body[#body]] is nil there,
      -- and indexing it would crash with an unrelated error message.
      local last = body[#body]
      if not last or last.tag ~= "Return" then
         error "Final return is mandatory in typed functions"
      end
      typed_return_transformer (body, opt_type)
   end

   return `Function{ untyped_params, body }
end

--------------------------------------------------------------------------
-- Updated function definition parser, which accepts typed vars as
-- parameters.
--
-- This replaces [mlp.func_val]. It parses a function value from its
-- opening "(" to its closing "end":
--
--    "(" [ id ["::" expr] {"," id ["::" expr]} ] ")" ["::" expr] block "end"
--
-- [func_val_builder] receives { params, return_type, body }. Each param is
-- { id, type }, where type is falsy when no "::" annotation was given.
--
-- NOTE(review): parameters are parsed with [mlp.id] only, so a vararg
-- [...] parameter does not appear to be accepted here. Confirm this before
-- relying on varargs in functions parsed by this rule.
--------------------------------------------------------------------------
mlp.func_val = gg.sequence{
   "(", gg.list{ gg.sequence{ mlp.id, gg.onkeyword{ "::", mlp.expr } },
      terminators = ")", separators  = "," },
   ")",  gg.onkeyword{ "::", mlp.expr }, mlp.block, "end", 
   builder = func_val_builder }

-- Register "::" with the lexer, so that it is produced as a single token.
mlp.lexer:add "::"

--------------------------------------------------------------------------
-- Register "::" as an infix expression operator. [e :: t] builds, through
-- [insert_test], an expression that evaluates [e] once, runs the checker
-- described by [t] on it, and yields [e]'s value.
--------------------------------------------------------------------------
local function typecheck_infix_builder (x)
   local term, checked_type = x[1], x[3]
   return insert_test (term, checked_type)
end
mlp.expr.infix:add{ "::", prec=100, builder = typecheck_infix_builder }

} -- end of meta-stage, back to normal programming

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--
-- Standard type checkers.
--
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
-- Create the global [typecheck] table. Its metatable's __index is the
-- thread's global environment ([getfenv (0)]), so a name that is not a
-- registered checker is looked up among the globals instead. This allows
-- writing things like [ n=3; x :: vector(n) ]: [n] is compiled to
-- [typecheck.n], which falls back to the global [n].
--------------------------------------------------------------------------------
typecheck = setmetatable ({ }, { __index = getfenv (0) })

--------------------------------------------------------------------------------
-- Built-in types
--
-- Register one checker per atomic Lua type name in [typecheck]. A checker
-- returns normally when [type (val)] matches its name. Otherwise it raises
-- "<typename> expected" at error level 2, so the reported position is the
-- code that invoked the checker rather than this library.
--------------------------------------------------------------------------------
local atomic_types = { "number", "string", "boolean", "table" }
for _, typename in ipairs (atomic_types) do
   typecheck[typename] = function (val)
      if type (val) ~= typename then error (typename .. " expected", 2) end
   end
end

--------------------------------------------------------------------------------
-- [list (elem_type)] builds a checker which accepts a table whose sequence
-- part (indices 1, 2, ... up to the first nil, as walked by [ipairs]) only
-- contains values accepted by [elem_type].
--------------------------------------------------------------------------------
typecheck.list = function (elem_type)
   return function (val :: table)
      for _, elem in ipairs (val) do
         elem_type (elem)
      end
   end
end

--------------------------------------------------------------------------------
-- [inter (x, y)] builds a checker which accepts a value only if both [x]
-- and [y] accept it. [x] runs first, and its error, if any, propagates.
--------------------------------------------------------------------------------
typecheck.inter = function (first, second)
   return function (val)
      first (val)
      second (val)
   end
end

--------------------------------------------------------------------------------
-- [union (x, y)] builds a checker which accepts a value of type either [x]
-- or [y]. [x] is tried first under [pcall], and its error is discarded.
-- Only if [x] rejects the value is [y] run, and [y]'s error propagates.
--------------------------------------------------------------------------------
typecheck.union = function (first, second)
   return function (val)
      local accepted = pcall (first, val)
      if accepted then return end
      second (val)
   end
end

--------------------------------------------------------------------------------
-- [vector (size)] builds a checker which accepts a table whose length,
-- as given by the [#] operator, equals [size]. Note that [#] is only
-- well defined on tables without nil holes.
-- Element types are not checked. Combine with [inter] for that, as in:
-- [inter (vector(3), list (number))]
--------------------------------------------------------------------------------
typecheck.vector = function (size)
   return function (val :: table)
      local actual = #val
      if actual == size then return end
      error "bad vector size"
   end
end

--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
--
-- Code sample
--
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

function sum (x :: list(number)) :: number
   -- [x] is checked on entry to be a list of numbers, and the final
   -- result is checked to be a number before it is returned.
   local total = 0
   for _, value in ipairs (x) do
      total = total + value
   end
   return total
end

-- Run the sample. The data and the result are locals, so the demo does not
-- leak globals. The data is not called [list], which would be easy to
-- confuse with the [list] type constructor used in [sum]'s signature.
local numbers = { 1, 2, 3, 4, 5 }
local list_sum = sum (numbers)
print ("Sum value:", list_sum)