{-# LANGUAGE
    TypeSynonymInstances
  , FlexibleInstances
  , MultiParamTypeClasses
  , FlexibleContexts
  #-}

module LText.Internal.Inference where

import LText.Internal.Expr
import LText.Internal.Types
import LText.Internal.Classes

import           Control.Monad.State
import           Control.Monad.Except
import qualified Data.Map                   as Map
import qualified Data.Set                   as Set


newtype Context = Context (Map.Map ExprVar Prenex)

remove :: Context -> String -> Context
remove (Context env) var = Context (Map.delete var env)


instance Bindable Set.Set TypeVar Context where
  fv (Context env) = fv (Map.elems env)

instance Substitutable Map.Map TypeVar Type Context where
  apply s (Context env) = Context (fmap (apply s) env)


-- | Binds free type variables as universally quantified
generalize :: Context -> Type -> Prenex
generalize env t = Prenex vars t
  where vars = Set.toList $ fv t `difference` fv env

data TIState = TIState { tiSupply :: Int
                       , tiSubst  :: Subst TypeVar Type
                       }

-- | Inference Monad
type TI m a = ExceptT String (StateT TIState m) a

runTI :: ( Monad m
         , MonadError String m
         ) => StateT TIState m a -> m a
runTI t = evalStateT t initTIState
  where initTIState = TIState { tiSupply = 0
                              , tiSubst = nullSubst
                              }


newTyVar :: ( Monad m
            , MonadState TIState m
            , MonadError String m
            ) => TypeVar -> m Type
newTyVar prefix = do
  s <- get
  put s { tiSupply = tiSupply s + 1
        }
  return $ TVar $ prefix ++ show (tiSupply s)

-- | Replaces bound type variables with free, fresh ones
instantiate :: ( Monad m
               , MonadState TIState m
               , MonadError String m
               ) => Prenex -> m Type
instantiate (Prenex vars t) = do
  nvars <- mapM newTyVar vars
  return $ apply (Map.fromList $ zip vars nvars) t

-- | Most general unifier
mgu :: ( Monad m
       , MonadState TIState m
       , MonadError String m
       ) => Type -> Type -> m (Subst TypeVar Type)
mgu (TFun l r) (TFun l' r')  = do s1 <- mgu l l'
                                  s2 <- mgu (apply s1 r) (apply s1 r')
                                  return (s1 `composeSubst` s2)
mgu (TVar u) t               = varBind u t
mgu t (TVar u)               = varBind u t
mgu TText TText              = return nullSubst
mgu t1 t2                    = throwError $ "Types do not unify: " ++ show t1 ++
                                        " vs. " ++ show t2

-- | Makes a substitution @[x -> t]@
varBind :: ( Monad m
           , MonadState TIState m
           , MonadError String m
           ) => TypeVar -> Type -> m (Subst TypeVar Type)
varBind u t | t == TVar u         = return nullSubst
            | u `Set.member` fv t = throwError $ "Occur check fails: " ++ u ++
                                             " vs. " ++ show t
            | otherwise           = return (Map.singleton u t)


-- | Type inference function
ti :: ( Monad m
      , MonadState TIState m
      , MonadError String m
      ) => Context -> Expr -> m (Subst TypeVar Type, Type)
ti (Context env) (EVar n) = case Map.lookup n env of
  Nothing     ->  throwError $ "unbound variable: " ++ n
  Just sigma  ->  do  t <- instantiate sigma
                      return (nullSubst, t)

ti env (EAbs n e) = do
  tv <- newTyVar "a"
  let Context env' = remove env n -- replace `n`'s type with a free type variable
      env''        = Context $ env' `Map.union` Map.singleton n (Prenex [] tv)
  (s1, t1) <- ti env'' e
  return (s1, TFun (apply s1 tv) t1)

ti env (EApp e1 e2) = do
  (s1, t1) <- ti env e1
  (s2, t2) <- ti (apply s1 env) e2
  tv       <- newTyVar "a"
  s3       <- mgu (apply s2 t1) (TFun t2 tv)
  return (s3 `composeSubst` s2 `composeSubst` s1, apply s3 tv)

ti env (ELet x e1 e2) = do
  (s1, t1) <- ti env e1
  let Context env' = remove env x
      t'           = generalize (apply s1 env) t1
      env''        = Context (Map.insert x t' env')
  (s2, t2) <- ti (apply s1 env'') e2
  return (s1 `composeSubst` s2, t2)

ti _ (EText _) = return (nullSubst, TText)

ti env (EConc e1 e2) = do
  (s1, t1) <- ti env e1
  (s2, t2) <- ti env e2
  s3 <- mgu t1 t2
  s4 <- mgu (apply s3 t1) TText
  return (s4 `composeSubst` s3 `composeSubst` s2 `composeSubst` s1, TText)


typeInference :: ( Monad m
                 , MonadState TIState m
                 , MonadError String m
                 ) => Context -> Expr -> m Type
typeInference env e = do
  (s, t) <- ti env e
  return (apply s t)


-- isSubtypeOf :: Type -> Type -> Bool
-- isSubtypeOf _ (TVar _) = True
-- isSubtypeOf (TFun a b) (TFun c d) | c == d = a == b
--                                           && isSubtypeOf a c
--                                           && isSubtypeOf b d
--                                   | otherwise = isSubtypeOf a c
--                                              && isSubtypeOf b d
-- isSubtypeOf TText _ = True
-- isSubtypeOf _ _ = False


test :: ( Monad m
        , MonadIO m
        ) => Expr -> m ()
test e = do
  eRes <- runExceptT $ runTI $ typeInference (Context Map.empty) e
  case eRes of
    Left err -> liftIO $ putStrLn $ "error: " ++ err
    Right t  -> let q = generalize (Context Map.empty) t in
      liftIO $ putStrLn $ show e ++ " :: " ++ show q