{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeSynonymInstances #-}

{-# OPTIONS_GHC -Wno-name-shadowing #-}

module Nix.Type.Infer (
  Constraint(..),
  TypeError(..),
  InferError(..),
  Subst(..),
  inferTop
) where

import           Control.Applicative
import           Control.Arrow
import           Control.Monad.Catch
import           Control.Monad.Except
import           Control.Monad.Logic
import           Control.Monad.Reader
import           Control.Monad.ST
import           Control.Monad.State
import           Data.Fix
import           Data.Foldable
import qualified Data.HashMap.Lazy as M
import           Data.List (delete, find, nub, intersect, (\\))
import           Data.Map (Map)
import qualified Data.Map as Map
import           Data.Maybe (fromJust)
import           Data.STRef
import           Data.Semigroup
import qualified Data.Set as Set
import           Data.Text (Text)
import           Nix.Atoms
import           Nix.Convert
import           Nix.Eval (MonadEval(..))
import qualified Nix.Eval as Eval
import           Nix.Expr.Types
import           Nix.Expr.Types.Annotated
import           Nix.Scope
import           Nix.Thunk
import qualified Nix.Type.Assumption as As
import           Nix.Type.Env
import qualified Nix.Type.Env as Env
import           Nix.Type.Type
import           Nix.Utils

-------------------------------------------------------------------------------
-- Classes
-------------------------------------------------------------------------------

-- | Inference monad
newtype Infer s a = Infer
    { getInfer ::
        ReaderT (Set.Set TVar, Scopes (Infer s) (JThunk s))
            (StateT InferState (ExceptT InferError (ST s))) a
    }
    deriving (Functor, Applicative, Alternative, Monad, MonadPlus, MonadFix,
              MonadReader (Set.Set TVar, Scopes (Infer s) (JThunk s)),
              MonadState InferState, MonadError InferError)

-- | Inference state
newtype InferState = InferState { count :: Int }

-- | Initial inference state
initInfer :: InferState
initInfer = InferState { count = 0 }

data Constraint
    = EqConst Type Type
    | ExpInstConst Type Scheme
    | ImpInstConst Type (Set.Set TVar) Type
    deriving (Show, Eq, Ord)

newtype Subst = Subst (Map TVar Type)
  deriving (Eq, Ord, Show, Semigroup, Monoid)

class Substitutable a where
  apply :: Subst -> a -> a

instance Substitutable TVar where
  apply (Subst s) a = tv
    where t = TVar a
          (TVar tv) = Map.findWithDefault t a s

instance Substitutable Type where
  apply _ (TCon a)           = TCon a
  apply s (TSet b a)         = TSet b (M.map (apply s) a)
  apply s (TList a)          = TList (map (apply s) a)
  apply (Subst s) t@(TVar a) = Map.findWithDefault t a s
  apply s (t1 `TArr` t2)     = apply s t1 `TArr` apply s t2
  apply s (TMany ts)         = TMany (map (apply s) ts)

instance Substitutable Scheme where
  apply (Subst s) (Forall as t) = Forall as $ apply s' t
    where s' = Subst $ foldr Map.delete s as

instance Substitutable Constraint where
   apply s (EqConst t1 t2)         = EqConst (apply s t1) (apply s t2)
   apply s (ExpInstConst t sc)     = ExpInstConst (apply s t) (apply s sc)
   apply s (ImpInstConst t1 ms t2) = ImpInstConst (apply s t1) (apply s ms) (apply s t2)

instance Substitutable a => Substitutable [a] where
  apply = map . apply

instance (Ord a, Substitutable a) => Substitutable (Set.Set a) where
  apply = Set.map . apply


class FreeTypeVars a where
  ftv :: a -> Set.Set TVar

instance FreeTypeVars Type where
  ftv TCon{}         = Set.empty
  ftv (TVar a)       = Set.singleton a
  ftv (TSet _ a)     = Set.unions (map ftv (M.elems a))
  ftv (TList a)      = Set.unions (map ftv a)
  ftv (t1 `TArr` t2) = ftv t1 `Set.union` ftv t2
  ftv (TMany ts)     = Set.unions (map ftv ts)

instance FreeTypeVars TVar where
  ftv = Set.singleton

instance FreeTypeVars Scheme where
  ftv (Forall as t) = ftv t `Set.difference` Set.fromList as

instance FreeTypeVars a => FreeTypeVars [a] where
  ftv   = foldr (Set.union . ftv) Set.empty

instance (Ord a, FreeTypeVars a) => FreeTypeVars (Set.Set a) where
  ftv   = foldr (Set.union . ftv) Set.empty


class ActiveTypeVars a where
  atv :: a -> Set.Set TVar

instance ActiveTypeVars Constraint where
  atv (EqConst t1 t2)         = ftv t1 `Set.union` ftv t2
  atv (ImpInstConst t1 ms t2) = ftv t1 `Set.union` (ftv ms `Set.intersection` ftv t2)
  atv (ExpInstConst t s)      = ftv t `Set.union` ftv s

instance ActiveTypeVars a => ActiveTypeVars [a] where
  atv = foldr (Set.union . atv) Set.empty

data TypeError
  = UnificationFail Type Type
  | InfiniteType TVar Type
  | UnboundVariables [Text]
  | Ambigious [Constraint]
  | UnificationMismatch [Type] [Type]
  deriving (Eq, Show)

data InferError
  = TypeInferenceErrors [TypeError]
  | TypeInferenceAborted
  | forall s. Exception s => EvaluationError s

typeError :: MonadError InferError m => TypeError -> m ()
typeError err = throwError $ TypeInferenceErrors [err]

deriving instance Show InferError
instance Exception InferError

instance Semigroup InferError where
    x <> _ = x

instance Monoid InferError where
    mempty  = TypeInferenceAborted
    mappend = (<>)

-------------------------------------------------------------------------------
-- Inference
-------------------------------------------------------------------------------

-- | Run the inference monad
runInfer' :: Infer s a -> ST s (Either InferError a)
runInfer' = runExceptT
          . (`evalStateT` initInfer)
          . (`runReaderT` (Set.empty, emptyScopes))
          . getInfer

runInfer :: (forall s. Infer s a) -> Either InferError a
runInfer m = runST (runInfer' m)

inferType :: Env -> NExpr -> Infer s [(Subst, Type)]
inferType env ex = do
  Judgment as cs t <- infer ex
  let unbounds = Set.fromList (As.keys as) `Set.difference`
                 Set.fromList (Env.keys env)
  unless (Set.null unbounds) $
      typeError $ UnboundVariables (nub (Set.toList unbounds))
  let cs' = [ ExpInstConst t s
            | (x, ss) <- Env.toList env
            , s <- ss
            , t <- As.lookup x as]
  inferState <- get
  let eres = (`evalState` inferState) $ runSolver $ do
          subst <- solve (cs ++ cs')
          return (subst, subst `apply` t)
  case eres of
      Left errs -> throwError $ TypeInferenceErrors errs
      Right xs  -> pure xs

-- | Solve for the toplevel type of an expression in a given environment
inferExpr :: Env -> NExpr -> Either InferError [Scheme]
inferExpr env ex = case runInfer (inferType env ex) of
  Left err -> Left err
  Right xs -> Right $ map (\(subst, ty) -> closeOver (subst `apply` ty)) xs

-- | Canonicalize and return the polymorphic toplevel type.
closeOver :: Type -> Scheme
closeOver = normalize . generalize Set.empty

extendMSet :: TVar -> Infer s a -> Infer s a
extendMSet x = Infer . local (first (Set.insert x)) . getInfer

letters :: [String]
letters = [1..] >>= flip replicateM ['a'..'z']

fresh :: MonadState InferState m => m Type
fresh = do
    s <- get
    put s{count = count s + 1}
    return $ TVar $ TV (letters !! count s)

instantiate :: MonadState InferState m => Scheme -> m Type
instantiate (Forall as t) = do
    as' <- mapM (const fresh) as
    let s = Subst $ Map.fromList $ zip as as'
    return $ apply s t

generalize :: Set.Set TVar -> Type -> Scheme
generalize free t  = Forall as t
    where as = Set.toList $ ftv t `Set.difference` free

unops :: Type -> NUnaryOp -> [Constraint]
unops u1 = \case
    NNot -> [ EqConst u1 (typeFun [typeBool, typeBool]) ]
    NNeg -> [ EqConst u1 (TMany [ typeFun [typeInt,   typeInt]
                               , typeFun [typeFloat, typeFloat] ]) ]

binops :: Type -> NBinaryOp -> [Constraint]
binops u1 = \case
    NApp    -> []                -- this is handled separately

    -- Equality tells you nothing about the types, because any two types are
    -- allowed.
    NEq     -> []
    NNEq    -> []

    NGt     -> inequality
    NGte    -> inequality
    NLt     -> inequality
    NLte    -> inequality

    NAnd    -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
    NOr     -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]
    NImpl   -> [ EqConst u1 (typeFun [typeBool, typeBool, typeBool]) ]

    NConcat -> [ EqConst u1 (TMany [ typeFun [typeList,   typeList,   typeList]
                                  , typeFun [typeList,   typeNull,   typeList]
                                  , typeFun [typeNull,   typeList,   typeList]
                                  ]) ]

    NUpdate -> [ EqConst u1 (TMany [ typeFun [typeSet,    typeSet,    typeSet]
                                  , typeFun [typeSet,    typeNull,   typeSet]
                                  , typeFun [typeNull,   typeSet,    typeSet]
                                  ]) ]

    NPlus   -> [ EqConst u1 (TMany [ typeFun [typeInt,    typeInt,    typeInt]
                                  , typeFun [typeFloat,  typeFloat,  typeFloat]
                                  , typeFun [typeInt,    typeFloat,  typeFloat]
                                  , typeFun [typeFloat,  typeInt,    typeFloat]
                                  , typeFun [typeString, typeString, typeString]
                                  , typeFun [typePath,   typePath,   typePath]
                                  , typeFun [typeString, typeString, typePath]
                                  ]) ]
    NMinus  -> arithmetic
    NMult   -> arithmetic
    NDiv    -> arithmetic
  where
    inequality =
        [ EqConst u1 (TMany [ typeFun [typeInt,    typeInt,    typeBool]
                            , typeFun [typeFloat,  typeFloat,  typeBool]
                            , typeFun [typeInt,    typeFloat,  typeBool]
                            , typeFun [typeFloat,  typeInt,    typeBool]
                            ]) ]

    arithmetic =
        [ EqConst u1 (TMany [ typeFun [typeInt,    typeInt,    typeInt]
                            , typeFun [typeFloat,  typeFloat,  typeFloat]
                            , typeFun [typeInt,    typeFloat,  typeFloat]
                            , typeFun [typeFloat,  typeInt,    typeFloat]
                            ]) ]

instance MonadVar (Infer s) where
    type Var (Infer s) = STRef s

    newVar x     = Infer $ lift $ lift $ lift $ newSTRef x
    readVar x    = Infer $ lift $ lift $ lift $ readSTRef x
    writeVar x y = Infer $ lift $ lift $ lift $ writeSTRef x y
    atomicModifyVar x f = Infer $ lift $ lift $ lift $ do
        res <- snd . f <$> readSTRef x
        _ <- modifySTRef x (fst . f)
        return res

newtype JThunk s = JThunk (Thunk (Infer s) (Judgment s))

instance MonadThrow (Infer s) where
    throwM = throwError . EvaluationError

instance MonadCatch (Infer s) where
    catch m h = catchError m $ \case
        EvaluationError e ->
            maybe (error $ "Exception was not an exception: " ++ show e) h
                  (fromException (toException e))
        err -> error $ "Unexpected error: " ++ show err

instance MonadThunk (Judgment s) (JThunk s) (Infer s) where
    thunk = fmap JThunk . buildThunk

    force (JThunk t) f = catch (forceThunk t f) $ \(_ :: ThunkLoop) ->
        -- If we have a thunk loop, we just don't know the type.
        f =<< Judgment As.empty [] <$> fresh

    value = JThunk . valueRef

instance MonadEval (Judgment s) (Infer s) where
    freeVariable var = do
        tv <- fresh
        return $ Judgment (As.singleton var tv) [] tv

    -- If we fail to look up an attribute, we just don't know the type.
    attrMissing _ _ = Judgment As.empty [] <$> fresh

    evaledSym _ = pure

    evalCurPos =
        return $ Judgment As.empty [] $ TSet False $ M.fromList
            [ ("file", typePath)
            , ("line", typeInt)
            , ("col",  typeInt) ]

    evalConstant c  = return $ Judgment As.empty [] (go c)
      where
        go = \case
          NInt _   -> typeInt
          NFloat _ -> typeFloat
          NBool _  -> typeBool
          NNull    -> typeNull

    evalString      = const $ return $ Judgment As.empty [] typeString
    evalLiteralPath = const $ return $ Judgment As.empty [] typePath
    evalEnvPath     = const $ return $ Judgment As.empty [] typePath

    evalUnary op (Judgment as1 cs1 t1) = do
        tv <- fresh
        return $ Judgment as1 (cs1 ++ unops (t1 `TArr` tv) op) tv

    evalBinary op (Judgment as1 cs1 t1) e2 = do
        Judgment as2 cs2 t2 <- e2
        tv <- fresh
        return $ Judgment
            (as1 `As.merge` as2)
            (cs1 ++ cs2 ++ binops (t1 `TArr` (t2 `TArr` tv)) op)
            tv

    evalWith = Eval.evalWithAttrSet

    evalIf (Judgment as1 cs1 t1) t f = do
        Judgment as2 cs2 t2 <- t
        Judgment as3 cs3 t3 <- f
        return $ Judgment
            (as1 `As.merge` as2 `As.merge` as3)
            (cs1 ++ cs2 ++ cs3 ++ [EqConst t1 typeBool, EqConst t2 t3])
            t2

    evalAssert (Judgment as1 cs1 t1) body = do
        Judgment as2 cs2 t2 <- body
        return $ Judgment
            (as1 `As.merge` as2)
            (cs1 ++ cs2 ++ [EqConst t1 typeBool])
            t2

    evalApp (Judgment as1 cs1 t1) e2 = do
        Judgment as2 cs2 t2 <- e2
        tv <- fresh
        return $ Judgment
            (as1 `As.merge` as2)
            (cs1 ++ cs2 ++ [EqConst t1 (t2 `TArr` tv)])
            tv

    evalAbs (Param x) k = do
        tv@(TVar a) <- fresh
        ((), Judgment as cs t) <-
            extendMSet a (k (pure (Judgment (As.singleton x tv) [] tv))
                            (\_ b -> ((),) <$> b))
        return $ Judgment
            (as `As.remove` x)
            (cs ++ [EqConst t' tv | t' <- As.lookup x as])
            (tv `TArr` t)

    evalAbs (ParamSet ps variadic _mname) k = do
        js <- fmap concat $ forM ps $ \(name, _) -> do
                tv <- fresh
                pure [(name, tv)]

        let (env, tys) = (\f -> foldl' f (As.empty, M.empty) js)
                $ \(as1, t1) (k, t) ->
                    (as1 `As.merge` As.singleton k t, M.insert k t t1)
            arg   = pure $ Judgment env [] (TSet True tys)
            call  = k arg $ \args b -> (args,) <$> b
            names = map fst js

        (args, Judgment as cs t) <-
            foldr (\(_, TVar a) -> extendMSet a) call js

        ty <- TSet variadic <$> traverse (inferredType <$>) args

        return $ Judgment
            (foldl' As.remove as names)
            (cs ++ [ EqConst t' (tys M.! x)
                   | x  <- names
                   , t' <- As.lookup x as])
            (ty `TArr` t)

    evalError = throwError . EvaluationError

data Judgment s = Judgment
    { assumptions     :: As.Assumption
    , typeConstraints :: [Constraint]
    , inferredType    :: Type
    }
    deriving Show

instance FromValue (Text, DList Text) (Infer s) (Judgment s) where
    fromValueMay _ = return Nothing
    fromValue _ = error "Unused"

instance FromValue (AttrSet (JThunk s), AttrSet SourcePos) (Infer s) (Judgment s) where
    fromValueMay (Judgment _ _ (TSet _ xs)) = do
        let sing _ = Judgment As.empty []
        pure $ Just (M.mapWithKey (\k v -> value (sing k v)) xs, M.empty)
    fromValueMay _ = pure Nothing
    fromValue = fromValueMay >=> \case
        Just v  -> pure v
        Nothing -> pure (M.empty, M.empty)

instance ToValue (AttrSet (JThunk s), AttrSet SourcePos) (Infer s) (Judgment s) where
    toValue (xs, _) = Judgment
        <$> foldrM go As.empty xs
        <*> (concat <$> traverse (`force` (pure . typeConstraints)) xs)
        <*> (TSet True <$> traverse (`force` (pure . inferredType)) xs)
      where
        go x rest = force x $ \x' -> pure $ As.merge (assumptions x') rest

instance ToValue [JThunk s] (Infer s) (Judgment s) where
    toValue xs = Judgment
        <$> foldrM go As.empty xs
        <*> (concat <$> traverse (`force` (pure . typeConstraints)) xs)
        <*> (TList <$> traverse (`force` (pure . inferredType)) xs)
      where
        go x rest = force x $ \x' -> pure $ As.merge (assumptions x') rest

instance ToValue Bool (Infer s) (Judgment s) where
    toValue _ = pure $ Judgment As.empty [] typeBool

infer :: NExpr -> Infer s (Judgment s)
infer = cata Eval.eval

inferTop :: Env -> [(Text, NExpr)] -> Either InferError Env
inferTop env [] = Right env
inferTop env ((name, ex):xs) = case inferExpr env ex of
  Left err -> Left err
  Right ty -> inferTop (extend env (name, ty)) xs

normalize :: Scheme -> Scheme
normalize (Forall _ body) = Forall (map snd ord) (normtype body)
  where
    ord = zip (nub $ fv body) (map TV letters)

    fv (TVar a)    = [a]
    fv (TArr a b)  = fv a ++ fv b
    fv (TCon _)    = []
    fv (TSet _ a)  = concatMap fv (M.elems a)
    fv (TList a)   = concatMap fv a
    fv (TMany ts)  = concatMap fv ts

    normtype (TArr a b)  = TArr (normtype a) (normtype b)
    normtype (TCon a)    = TCon a
    normtype (TSet b a)  = TSet b (M.map normtype a)
    normtype (TList a)   = TList (map normtype a)
    normtype (TMany ts)  = TMany (map normtype ts)
    normtype (TVar a)    =
      case Prelude.lookup a ord of
        Just x -> TVar x
        Nothing -> error "type variable not in signature"

-------------------------------------------------------------------------------
-- Constraint Solver
-------------------------------------------------------------------------------

newtype Solver m a = Solver (LogicT (StateT [TypeError] m) a)
    deriving (Functor, Applicative, Alternative, Monad, MonadPlus,
              MonadLogic, MonadState [TypeError])

instance MonadTrans Solver where
    lift = Solver . lift . lift

instance Monad m => MonadError TypeError (Solver m) where
    throwError err = Solver $ lift (modify (err:)) >> mzero
    catchError _ _ = error "This is never used"

runSolver :: Monad m => Solver m a -> m (Either [TypeError] [a])
runSolver (Solver s) = do
    res <- runStateT (observeAllT s) []
    pure $ case res of
        (x:xs, _) -> Right (x:xs)
        (_, es)   -> Left (nub es)

-- | The empty substitution
emptySubst :: Subst
emptySubst = mempty

-- | Compose substitutions
compose :: Subst -> Subst -> Subst
Subst s1 `compose` Subst s2 =
    Subst $ Map.map (apply (Subst s1)) s2 `Map.union` s1

unifyMany :: Monad m => [Type] -> [Type] -> Solver m Subst
unifyMany [] [] = return emptySubst
unifyMany (t1 : ts1) (t2 : ts2) =
  do su1 <- unifies t1 t2
     su2 <- unifyMany (apply su1 ts1) (apply su1 ts2)
     return (su2 `compose` su1)
unifyMany t1 t2 = throwError $ UnificationMismatch t1 t2

allSameType :: [Type] -> Bool
allSameType [] = True
allSameType [_] = True
allSameType (x:y:ys) = x == y && allSameType (y:ys)

unifies :: Monad m => Type -> Type -> Solver m Subst
unifies t1 t2 | t1 == t2 = return emptySubst
unifies (TVar v) t = v `bind` t
unifies t (TVar v) = v `bind` t
unifies (TList xs) (TList ys)
    | allSameType xs && allSameType ys = case (xs, ys) of
          (x:_, y:_) -> unifies x y
          _ -> return emptySubst
    | length xs == length ys = unifyMany xs ys
-- We assume that lists of different lengths containing various types cannot
-- be unified.
unifies t1@(TList _) t2@(TList _) = throwError $ UnificationFail t1 t2
unifies (TSet True _) (TSet True _) = return emptySubst
unifies (TSet False b) (TSet True s)
    | M.keys b `intersect` M.keys s == M.keys s = return emptySubst
unifies (TSet True s) (TSet False b)
    | M.keys b `intersect` M.keys s == M.keys b = return emptySubst
unifies (TSet False s) (TSet False b)
    | null (M.keys b \\ M.keys s) = return emptySubst
unifies (TArr t1 t2) (TArr t3 t4) = unifyMany [t1, t2] [t3, t4]
unifies (TMany t1s) t2 = considering t1s >>- unifies ?? t2
unifies t1 (TMany t2s) = considering t2s >>- unifies t1
unifies t1 t2 = throwError $ UnificationFail t1 t2

bind :: Monad m => TVar -> Type -> Solver m Subst
bind a t | t == TVar a     = return emptySubst
         | occursCheck a t = throwError $ InfiniteType a t
         | otherwise       = return (Subst $ Map.singleton a t)

occursCheck ::  FreeTypeVars a => TVar -> a -> Bool
occursCheck a t = a `Set.member` ftv t

nextSolvable :: [Constraint] -> (Constraint, [Constraint])
nextSolvable xs = fromJust (find solvable (chooseOne xs))
  where
    chooseOne xs = [(x, ys) | x <- xs, let ys = delete x xs]

    solvable (EqConst{}, _)      = True
    solvable (ExpInstConst{}, _) = True
    solvable (ImpInstConst _t1 ms t2, cs) =
        Set.null ((ftv t2 `Set.difference` ms) `Set.intersection` atv cs)

considering :: [a] -> Solver m a
considering xs = Solver $ LogicT $ \c n -> foldr c n xs

solve :: MonadState InferState m => [Constraint] -> Solver m Subst
solve [] = return emptySubst
solve cs = solve' (nextSolvable cs)
  where
    solve' (EqConst t1 t2, cs) =
      unifies t1 t2 >>- \su1 ->
      solve (apply su1 cs) >>- \su2 ->
          return (su2 `compose` su1)

    solve' (ImpInstConst t1 ms t2, cs) =
      solve (ExpInstConst t1 (generalize ms t2) : cs)

    solve' (ExpInstConst t s, cs) = do
      s' <- lift $ instantiate s
      solve (EqConst t s' : cs)