(**************************************************************
 *  CS101: Homework 4 --- Type checker and interpreter        *
 **************************************************************)

open Cs101_hw4_p2_types

(* Typing environment: maps each bound variable to its lambda-calculus type. *)
type type_environment = lc_type VarMap.t

(*
 * Lambda-calculus type inference.
 * Returns the type of expression [e] in the typing environment [gamma]
 * (a VarMap from variables to their types).
 * Raises LcError if the input is not well-typed, and Invalid_argument for
 * the expression forms (units, references) that are not implemented yet.
 *)
let rec lc_get_type gamma e =
   match e with
   | ExpInt _ ->
      TyInt
   | ExpVar v ->
      begin
         try VarMap.find v gamma
         with Not_found -> raise (LcError (ErrUnknownVariable v))
      end
   | ExpBinop (_, e1, e2) ->
      (*
       * Type the operands with explicit, sequenced lets: OCaml leaves the
       * evaluation order of tuple components unspecified, so typing them
       * inside a tuple would make the reported error compiler-dependent
       * when both operands are ill-typed. Left-to-right matches ExpApply.
       *)
      let t1 = lc_get_type gamma e1 in
      let t2 = lc_get_type gamma e2 in
      begin match t1, t2 with
      | TyInt, TyInt -> TyInt
      | TyInt, _ -> raise (LcError (ErrTypeMismatch (e2, t2, TyInt)))
      | _, _ -> raise (LcError (ErrTypeMismatch (e1, t1, TyInt)))
      end
   | ExpApply (e1, e2) ->
      let t1 = lc_get_type gamma e1 in
      let t2 = lc_get_type gamma e2 in
      begin match t1 with
      | TyFun (t_in, t_out) ->
         if t2 = t_in then t_out
         else raise (LcError (ErrTypeMismatch (e2, t2, t_in)))
      | _ -> raise (LcError (ErrTypeFunApp (e1, t1)))
      end
   | ExpLambda (v, t, body) ->
      (* The parameter's declared type is visible while typing the body. *)
      let ty_res = lc_get_type (VarMap.add v t gamma) body in
      TyFun (t, ty_res)
   | _ -> raise (Invalid_argument "Units, references: not implemented")

(*
 * Lambda-calculus type checker.
 * Types [e] starting from the empty environment, so any free variable
 * is reported as an unknown variable.
 * Raises LcError if the input is not well-typed.
 *)
let lc_type_check e =
   lc_get_type VarMap.empty e

(* Maps a binary-operator constructor to the integer function it denotes. *)
let decode_binop op =
   match op with
   | OpPlus -> ( + )
   | OpMinus -> ( - )
   | OpTimes -> ( * )

(*
 * Big-step evaluator.
 * [state] is the reference store threaded through evaluation and [gamma]
 * maps variables to their runtime values. Returns a record holding the
 * final store ([refs]) and the resulting value ([value]).
 * Assumes [e] has already passed the type checker; if that assumption is
 * violated, evaluation fails with InternalError (or Not_found for an
 * unbound variable).
 * Raises Invalid_argument for units and references (not implemented yet).
 *)
let rec lc_get_value state (gamma : value_environment) e =
   match e with
   | ExpInt i ->
      { refs = state; value = ValInt i }
   | ExpLambda (v, _, body) ->
      (* The closure captures the defining environment. *)
      { refs = state; value = ValClosure (gamma, v, body) }
   | ExpVar v ->
      (*
       * The expression has passed the type checker,
       * so we can assume that v is bound in gamma.
       *)
      { refs = state; value = VarMap.find v gamma }
   | ExpBinop (op, e1, e2) ->
      let r1 = lc_get_value state gamma e1 in
      let r2 = lc_get_value r1.refs gamma e2 in
      begin match r1.value, r2.value with
      | ValInt i1, ValInt i2 ->
         { refs = r2.refs; value = ValInt (decode_binop op i1 i2) }
      | _ -> raise (InternalError "Non-int expressions in a binop")
      end
   | ExpApply (e1, e2) ->
      let r1 = lc_get_value state gamma e1 in
      let r2 = lc_get_value r1.refs gamma e2 in
      (* The function being applied is e1's value; r2 is the argument. *)
      begin match r1.value with
      | ValClosure (closure_env, x, body) ->
         lc_get_value r2.refs (VarMap.add x r2.value closure_env) body
      | _ ->
         raise (InternalError "First exp of application isn't a function.")
      end
   | _ -> raise (Invalid_argument "Units, references: not implemented")

(*
 * Lambda-calculus interpreter.
 * Type-checks [e] first, then evaluates it starting from an empty store and
 * an empty environment. The result is a record pairing the final store
 * ([refs]) with the computed value ([value]).
 * Raises LcError if the input is not well-typed.
 *)
let lc_evaluate e =
   let (_ : lc_type) = lc_type_check e in
   let initial_state = Location.empty_state in
   lc_get_value initial_state VarMap.empty e
