{-# LANGUAGE CPP             #-}
{-# LANGUAGE TupleSections   #-}
{-# LANGUAGE BangPatterns    #-}

module Infernu.InferState
       where

import           Data.Foldable              (msum)
import           Control.Monad              (foldM, forM, forM_, liftM2, when)
import           Control.Monad.Trans        (lift)
import           Control.Monad.Trans.Either (EitherT (..), left, runEitherT, bimapEitherT)
import           Control.Monad.Trans.State  (StateT (..), evalStateT, get, put, modify, mapStateT)
import qualified Data.Graph.Inductive      as Graph
    

import           Data.Functor.Identity      (Identity (..), runIdentity)
import qualified Data.Map.Lazy              as Map
-- import           Data.Map.Lazy              (Map)
import           Data.Maybe                 (fromMaybe)
import qualified Data.Set                   as Set
import           Data.Set                   (Set)
import           Prelude                    hiding (foldr, sequence, mapM)


import           Infernu.Prelude
import           Infernu.Pretty
import           Infernu.Types
import           Infernu.Log
import qualified Infernu.Builtins.TypeClasses
    
-- | The inference monad: mutable 'InferState' (fresh-name source, main substitution, variable
-- schemes, named types, pending unifications, classes) layered over an 'EitherT' that aborts the
-- computation with the first 'TypeError'.
type Infer = StateT InferState (EitherT TypeError Identity)

-- | The initial inference state: the built-in type classes, and no pending unifications, named
-- types, variable instances, variable schemes or substitution. The name source's last issued
-- name is 2, so the first 'fresh' name is 3.
emptyInferState :: InferState
emptyInferState =
  InferState
    { classes      = Map.fromList Infernu.Builtins.TypeClasses.typeClasses
    , pendingUni   = Set.empty
    , namedTypes   = Map.empty
    , varInstances = Graph.empty
    , varSchemes   = Map.empty
    , mainSubst    = nullSubst
    , nameSource   = NameSource 2
    }

-- | Runs an inference computation from the given initial state. Returns the computation's result
-- (the final state is discarded) or the first 'TypeError' it raised.
runInferWith :: InferState -> Infer a -> Either TypeError a
runInferWith ns inf = runIdentity (runEitherT (evalStateT inf ns))
    -- NOTE: disabled code, apparently meant to also resolve the pending unifications:
    -- where inf' = do res <- inf
    --                 unis <- getPendingUnifications
    --                 forM unis $ \((a, scheme), t) ->
    --                                 do inst <- instantiate scheme --  >>= assertNoPred
    --                                    unify a inst t

-- | Runs a nested inference computation starting from a snapshot of the current state.
-- State changes made by the nested computation are discarded. Its failure is returned as a
-- 'Left' value instead of aborting the enclosing computation.
runSubInfer :: Infer a -> Infer (Either TypeError a)
runSubInfer a = (`runInferWith` a) <$> get

-- | Returns the whole current inference state.
getState :: Infer InferState
getState = get

-- | Replaces the whole inference state.
setState :: InferState -> Infer ()
setState newState = put newState

-- | Runs an inference computation starting from 'emptyInferState'.
runInfer :: Infer a -> Either TypeError a
runInfer inference = runInferWith emptyInferState inference

-- | Allocates a fresh type variable name. Increments the name source's last issued name and
-- returns the new value.
fresh :: Infer TVarName
fresh = do
  st <- get
  let src  = nameSource st
      next = lastName src + 1
  put st { nameSource = src { lastName = next } }
  return next

-- | Allocates a fresh variable id, drawn from the same name source as 'fresh'.
freshVarId :: Infer VarId
freshVarId = fmap VarId fresh

-- | Aborts inference with a 'TypeError' carrying the given source location and message.
throwError :: Source -> String -> Infer a
throwError src msg = lift (left (TypeError src msg))

-- | Unwraps a 'Just' value. On 'Nothing', runs the given fallback action instead (callers in this
-- module pass a 'throwError').
failWith :: Maybe a -> Infer a -> Infer a
failWith action err = maybe err return action

-- | Like 'failWith', but the 'Maybe' is produced by an inference action, which is run first.
failWithM :: Infer (Maybe a) -> Infer a -> Infer a
failWithM action err = action >>= \result -> failWith result err

-- | Transforms the error raised by a failing inference computation. Successful results are left
-- untouched.
mapError :: (TypeError -> TypeError) -> Infer a -> Infer a
mapError f = mapStateT (bimapEitherT f id)
           
-- | Looks up the type scheme stored in the inference state for a variable id.
getVarSchemeByVarId :: VarId -> Infer (Maybe TypeScheme)
getVarSchemeByVarId varId = do
  schemes <- varSchemes <$> get
  return (Map.lookup varId schemes)

-- | Looks up the variable id bound to a name in the type environment.
getVarId :: EVarName -> TypeEnv -> Maybe VarId
getVarId name env = Map.lookup name env

-- | Looks up the type scheme of a named variable.
-- Fails with an "Unbound variable" type error if the name is not in the environment. Yields
-- 'Nothing' when the name is bound but no scheme is stored for its id.
getVarScheme :: Source -> EVarName -> TypeEnv -> Infer (Maybe TypeScheme)
getVarScheme a n env =
  maybe (throwError a $ "Unbound variable: '" ++ show n ++ "'") getVarSchemeByVarId (getVarId n env)

-- | Stores @scheme@ under the given variable id in the inference state. Returns the environment
-- extended with the binding @n -> varId@.
setVarScheme :: TypeEnv -> EVarName -> TypeScheme -> VarId -> Infer TypeEnv
setVarScheme env n scheme varId = do
  let logMsg = "Inserting scheme for " ++ pretty n ++ ": " ++ pretty scheme
      insertScheme is = is { varSchemes = trace logMsg (Map.insert varId scheme (varSchemes is)) }
  modify insertScheme
  return (Map.insert n varId env)

-- | Binds a name to a scheme under a freshly allocated variable id (see 'setVarScheme').
addVarScheme :: TypeEnv -> EVarName -> TypeScheme -> Infer TypeEnv
addVarScheme env n scheme =
  freshVarId >>= setVarScheme env n scheme . tracePretty ("-- '" ++ pretty n ++ "' = varId")

-- | Adds an entry to the set of pending unifications kept in the inference state.
addPendingUnification :: (Source, Type, (ClassName, Set TypeScheme)) -> Infer ()
addPendingUnification ts = modify $ \is -> is { pendingUni = Set.insert ts $ pendingUni is }

-- | Returns the set of pending unifications currently recorded in the inference state.
getPendingUnifications :: Infer (Set (Source, Type, (ClassName, Set TypeScheme)))
getPendingUnifications = pendingUni <$> get

-- | Replaces the set of pending unifications.
setPendingUnifications :: (Set (Source, Type, (ClassName, Set TypeScheme))) -> Infer ()
setPendingUnifications ts = modify $ \is -> is { pendingUni = ts }
        
----------------------------------------------------------------------


-- | Records a relation between type variables @x@ and @y@ in the variable-instance graph, via
-- 'addEquivalence'. 'instantiateScheme' uses it to link a quantified variable to the fresh name
-- it was instantiated with.
addVarInstance :: TVarName -> TVarName -> Infer ()
addVarInstance x y = modify updateGraph
  where updateGraph is = is { varInstances = tracePretty "updated equivs" (addEquivalence x y (varInstances is)) }

-- | Collects the free type variables of every scheme referenced by the type environment.
-- Variable ids with no stored scheme contribute nothing.
getFreeTVars :: TypeEnv -> Infer (Set TVarName)
getFreeTVars env = foldM addFromVar Set.empty (Map.elems env)
  where
    addFromVar acc varId = do
      mScheme <- getVarSchemeByVarId varId
      let tvs    = maybe Set.empty freeTypeVars mScheme
          traced = tracePretty (" collected from " ++ pretty varId ++ " free type variables: ") tvs
      return (Set.union acc traced)

-- | Registers a named (recursive) type under the given id. Stores both the named type itself and
-- its scheme; any previous entry with the same id is overwritten.
addNamedType :: TypeId -> Type -> TypeScheme -> Infer ()
addNamedType tid t scheme = do
  traceLog ("===> Introducing named type: " ++ pretty tid ++ " => " ++ pretty scheme)
  modify $ \is -> is { namedTypes = Map.insert tid (t, scheme) $ namedTypes is }

-- | Compares schemes up to alpha equivalence including named type constructors equivalence (TCons
-- TName...).
-- 
-- >>> let mkNamedType tid ts = Fix $ TCons (TName (TypeId tid)) ts
--  
-- >>> areEquivalentNamedTypes (mkNamedType 0 [], schemeEmpty (Fix $ TBody TNumber)) (mkNamedType 1 [], schemeEmpty (Fix $ TBody TString))
-- False
-- >>> areEquivalentNamedTypes (mkNamedType 0 [], schemeEmpty (mkNamedType 0 [])) (mkNamedType 1 [], schemeEmpty (mkNamedType 1 []))
-- True
-- >>> :{
--     areEquivalentNamedTypes (mkNamedType 0 [], schemeEmpty (Fix $ TFunc [Fix $ TBody TNumber] (mkNamedType 0 [])))
--                             (mkNamedType 1 [], schemeEmpty (Fix $ TFunc [Fix $ TBody TNumber] (mkNamedType 1 [])))
-- :}
-- True
--  
-- >>> :{
--     areEquivalentNamedTypes (mkNamedType 0 [Fix $ TBody $ TVar 10], TScheme [10] (qualEmpty $ Fix $ TFunc [Fix $ TBody $ TVar 10] (mkNamedType 0 [Fix $ TBody $ TVar 10])))
--                             (mkNamedType 1 [Fix $ TBody $ TVar 11], TScheme [11] (qualEmpty $ Fix $ TFunc [Fix $ TBody $ TVar 11] (mkNamedType 1 [Fix $ TBody $ TVar 11])))
-- :}
-- True
areEquivalentNamedTypes :: (Type, TypeScheme) -> (Type, TypeScheme) -> Bool
areEquivalentNamedTypes (t1, s1) (t2, s2) =
  -- Schemes that quantify over different numbers of variables cannot be alpha-equivalent.
  -- Without this check, 'zip' below would silently truncate the renaming and leave some of the
  -- first scheme's quantified variables un-renamed.
  length vars1 == length vars2
    && s2 == s2 { schemeType = applySubst renaming $ replaceFixQual (unFix t1) (unFix t2) $ schemeType s1 }
  where vars1 = schemeVars s1
        vars2 = schemeVars s2
        -- Rename each quantified variable of s1 to the corresponding (same position) one of s2.
        renaming = foldr (\(x, y) s -> singletonSubst x (Fix $ TBody $ TVar y) `composeSubst` s) nullSubst $ zip vars1 vars2

-- | Returns a 'TQual' with every occurrence of the @src@ type replaced by the @dest@ type. The
-- replacement is applied both to the qualified type and to the type of each predicate.
replaceFixQual :: (Functor f, Eq (f (Fix f))) => f (Fix f) -> f (Fix f) -> TQual (Fix f) -> TQual (Fix f)
replaceFixQual src dest (TQual preds t) = TQual (map replaceInPred preds) (replaceHere t)
    where replaceHere = replaceFix src dest
          replaceInPred p = p { predType = replaceHere (predType p) } -- TODO needs some lens goodness
                
-- | Checks whether the type variable @n1@ appears in the given type /only/ as a direct parameter
-- of recursive named types (a 'TName' constructor or a 'TRowRec' row tail).
--
-- * Returns @Nothing@ if @n1@ appears anywhere else.
--
-- * Otherwise returns @Just ps@, where @ps@ lists every (named type id, parameter index)
--   position at which @n1@ appears. An empty list means @n1@ does not appear at all.
--
-- The @typeId@ argument is the position the given type itself occupies when it is a direct
-- parameter of a named type; top-level callers pass @Nothing@.
--
-- Sub-results are combined so that a single @Nothing@ makes the whole result @Nothing@, and the
-- positions found are concatenated. Combining them with 'msum' or 'mappend' on 'Maybe' would
-- instead treat @Nothing@ as an identity element. That would drop outside occurrences and keep
-- only the first parameter position.
isRecParamOnly
  :: (Num t, Enum t) =>
     TVarName -> Maybe (TypeId, t) -> Type -> Maybe [(TypeId, t)]
isRecParamOnly n1 typeId t1 =
  case unFix t1 of
   TBody (TVar n1') -> if n1' == n1 then sequence [typeId] else Just []
   TBody _ -> Just []
   TCons (TName typeId') subTs -> recurseIntoNamedType typeId' subTs
   TCons _ subTs -> allSubTerms subTs
   TFunc ts tres -> allSubTerms (tres : ts)
   TRow rlist -> isRecParamRecList rlist
  where
    -- All results must be 'Just'; their position lists are concatenated in order.
    allOf results = foldr (liftM2 (++)) (Just []) results
    -- Sub-terms that are not direct parameters of a named type.
    allSubTerms ts = allOf (map (isRecParamOnly n1 Nothing) ts)
    -- Each argument of a named type is checked knowing its (type id, index) position.
    recurseIntoNamedType typeId' subTs =
      allOf [ isRecParamOnly n1 (Just (typeId', i)) t | (t, i) <- zip subTs [0..] ]
    isRecParamRecList rlist' =
      case rlist' of
        TRowEnd _ -> Just []
        -- TODO: assumes the quantified vars in TScheme do not shadow other type variable names
        -- TODO: Can we safely ignore preds here?
        TRowProp _ (TScheme _ t') rlist'' -> liftM2 (++) (isRecParamOnly n1 Nothing $ qualType t') (isRecParamRecList rlist'')
        TRowRec typeId' subTs -> recurseIntoNamedType typeId' subTs

-- | Removes the element at the given zero-based index from a list. Out-of-range indices,
-- including negative ones, leave the list unchanged.
--
-- >>> dropAt (1 :: Int) "abc"
-- "ac"
dropAt :: Integral a => a -> [b] -> [b]
dropAt _ [] = []
dropAt n xs@(x:rest)
  | n < 0     = xs
  | n == 0    = rest
  | otherwise = x : dropAt (n - 1) rest

-- | Replaces every occurrence of the named type @typeId@ with @newTypeId@. This covers both 'TName'
-- constructors and 'TRowRec' row tails. The type argument at position @indexToDrop@ is removed
-- from each such occurrence. The traversal descends into all sub-terms, including the arguments
-- of other named types, so nested occurrences are rewritten too.
replaceRecType :: TypeId -> TypeId -> Int -> Type -> Type
replaceRecType typeId newTypeId indexToDrop t1 =
    case unFix t1 of
        TBody _ -> t1
        TCons (TName typeId') subTs -> let (tid', ts') = replaceArgs typeId' subTs
                                      in Fix $ TCons (TName tid') ts'
        TCons n subTs -> Fix $ TCons n $ mapTs' subTs
        TFunc ts tres -> Fix $ TFunc (mapTs' ts) (replace' tres)
        TRow rlist -> Fix $ TRow $ go rlist
  where
    replace' = replaceRecType typeId newTypeId indexToDrop
    mapTs' = map replace'
    -- Rewrites the (id, arguments) of a named-type occurrence. If it is the type being replaced,
    -- it is renamed and the given argument position is dropped.
    replaceArgs tid ts
      | tid == typeId = (newTypeId, mapTs' (dropAt indexToDrop ts))
      | otherwise     = (tid, mapTs' ts)
    go rlist' =
      case rlist' of
        TRowEnd _ -> rlist'
        TRowProp p (TScheme qv t') rlist'' -> TRowProp p (TScheme qv (t' { qualType = replace' $ qualType t' })) (go rlist'')
        TRowRec tid ts -> uncurry TRowRec (replaceArgs tid ts)

-- | Allocates a named type for @t@. Every occurrence of type variable @n@ in @t@ is replaced by
-- the new named type, which is parameterised by the other free type variables of @t@. If an
-- equivalent named type (per 'areEquivalentNamedTypes') is already registered, that one is
-- returned and nothing new is registered.
allocNamedType :: TVarName -> Type -> Infer Type
allocNamedType n t = do
  typeId <- TypeId <$> fresh
  let params    = Set.toList (freeTypeVars t `Set.difference` Set.singleton n)
      namedType = TCons (TName typeId) (map (Fix . TBody . TVar) params)
      target    = replaceFix (TBody (TVar n)) namedType t
  scheme <- unsafeGeneralize Map.empty (qualEmpty target)
  traceLog $ "===> Generated scheme for mu type: " ++ pretty scheme
  known <- Map.elems . namedTypes <$> get
  case filter (areEquivalentNamedTypes (Fix namedType, scheme)) known of
    (otherNT, _) : _ -> return otherNT
    [] -> do
      addNamedType typeId (Fix namedType) scheme
      return (Fix namedType)

-- | Handles the case where type variable @n@ occurs in @t@ only as the @ix@-th argument of the
-- named type @tid@. It registers a new named type equal to @tid@ but without that parameter,
-- with @n@ in its scheme replaced by the new named type. It returns @t@ with every use of @tid@
-- rewritten to the new named type.
resolveSimpleMutualRecursion :: TVarName -> Type -> TypeId -> Int -> Infer Type
resolveSimpleMutualRecursion n t tid ix =
  do entry <- Map.lookup tid . namedTypes <$> get
     -- Callers only pass ids of registered named types, so both failures are internal errors.
     (ts, scheme) <- case entry of
         Just (Fix (TCons (TName _) ts'), scheme') -> return (ts', scheme')
         Just (other, _) -> error $ "resolveSimpleMutualRecursion: named type " ++ pretty tid
                                    ++ " is not a TName type: " ++ pretty other
         Nothing -> error $ "resolveSimpleMutualRecursion: unknown named type id " ++ pretty tid
     newTypeId <- TypeId <$> fresh
     let qVars' = dropAt ix $ schemeVars scheme
         replaceOldNamedType = replaceRecType tid newTypeId ix
         sType' = (schemeType scheme) { qualType = replaceOldNamedType $ qualType $ schemeType scheme }
         newNamedType = Fix (TCons (TName newTypeId) (dropAt ix ts))
         updatedScheme = applySubst (singletonSubst n newNamedType) $ TScheme qVars' sType'
     addNamedType newTypeId newNamedType updatedScheme
     -- TODO: we could alternatively update the existing named type, but that would break its
     -- scheme (it would then take fewer params).
     return $ replaceOldNamedType t
     
     
-- | Produces a named type for @t@, where type variable @n@ stands for the recursive reference.
-- If 'isRecParamOnly' reports exactly one named-type parameter position for @n@, the existing
-- named type is specialised via 'resolveSimpleMutualRecursion'. Otherwise (more positions, or an
-- occurrence outside any such position) a new named type is allocated.
getNamedType :: TVarName -> Type -> Infer Type
getNamedType n t = do
  traceLog ("isRecParamOnly: " ++ pretty n ++ " in " ++ pretty t ++ ": " ++ show (fmap pretty positions))
  case positions of
    Just [(tid, ix)] -> resolveSimpleMutualRecursion n t tid ix
    _                -> allocNamedType n t
  where positions = isRecParamOnly n Nothing t


-- | Substitutes the given types for the given (quantified) type variables, pairing them by
-- position. Surplus elements of either list are ignored.
unrollNameByScheme :: Substable a => [Type] -> [TVarName] -> a -> a
unrollNameByScheme ts qvars = applySubst subst
  where subst = foldr addBinding nullSubst (zip qvars ts)
        addBinding (tvar, destType) acc = singletonSubst tvar destType `composeSubst` acc

-- | Unrolls (expands) a TName recursive type by plugging in the holes from the given list of types.
-- Similar to instantiation, but uses a pre-defined set of type instances instead of using fresh
-- type variables.
unrollName :: Source -> TypeId -> [Type] -> Infer QualType
unrollName a tid ts = do
    -- TODO: Is it safe to ignore the scheme preds here?
    let lookupScheme = fmap snd . Map.lookup tid . namedTypes <$> get
    TScheme quantified body <- lookupScheme `failWithM` throwError a "Unknown type id"
    return (unrollNameByScheme ts quantified body)
    
-- | Applies a substitution to the inference state (most importantly, to the variable -> scheme map).
--
-- >>> :{
-- runInfer $ do
--     let t = TScheme [0] (TQual [] (Fix $ TFunc [Fix $ TBody (TVar 0)] (Fix $ TBody (TVar 1))))
--     let tenv = Map.empty
--     tenv' <- addVarScheme tenv "x" t
--     applySubstInfer $ Map.singleton 0 (Fix $ TBody TString)
--     varSchemes <$> get
-- :}
-- Right (fromList [(VarId 3,TScheme {schemeVars = [], schemeType = TQual {qualPred = [], qualType = Fix (TFunc [Fix (TBody TString)] Fix (TBody (TVar 1)))}})])
--
applySubstInfer :: TSubst -> Infer ()
applySubstInfer s =
  -- Log first, then apply the substitution to the entire state.
  traceLog ("applying subst: " ++ pretty s) >> modify (applySubst s)

-- | Instantiate a type scheme by giving fresh names to all quantified type variables.
--
-- For example:
--
-- >>> runInferWith (emptyInferState { nameSource = NameSource 2 }) . instantiate $ TScheme [0] (TQual { qualPred = [], qualType = Fix $ TFunc [Fix $ TBody (TVar 0)] (Fix $ TBody (TVar 1)) }) 
-- Right (TQual {qualPred = [], qualType = Fix (TFunc [Fix (TBody (TVar 3))] Fix (TBody (TVar 1)))})
--
-- In the above example, type variable 0 has been replaced with a fresh one (3), while the unquantified free type variable 1 has been left as-is.
--
-- >>> :{
-- runInfer $ do
--     let t = TScheme [0] (TQual [] (Fix $ TFunc [Fix $ TBody (TVar 0)] (Fix $ TBody (TVar 1))))
--     let tenv = Map.empty
--     tenv' <- addVarScheme tenv "x" t
--     instantiateVar emptySource "x" tenv'
-- :}
-- Right (TQual {qualPred = [], qualType = Fix (TFunc [Fix (TBody (TVar 4))] Fix (TBody (TVar 1)))})
--
instantiateScheme :: Bool -> TypeScheme -> Infer QualType
instantiateScheme shouldAddVarInstances (TScheme tvarNames t) = do
  -- Pair each quantified variable, in order, with a freshly allocated name.
  renaming <- forM tvarNames $ \tv -> (tv,) <$> fresh
  -- Optionally remember which fresh name each quantified variable was instantiated to.
  when shouldAddVarInstances $ forM_ renaming (uncurry addVarInstance)
  -- Variables absent from the renaming (free in the scheme) keep their names.
  return $ mapVarNames (\v -> fromMaybe v (lookup v renaming)) t

-- | Instantiates a scheme with fresh names, recording the variable instances.
instantiate :: TypeScheme -> Infer QualType
instantiate scheme = instantiateScheme True scheme

-- | Instantiates the scheme of a variable bound in the type environment.
-- Fails with an "Unbound variable" error if the name is not bound. Fails with an assertion error
-- if no scheme is stored for its variable id.
instantiateVar :: Source -> EVarName -> TypeEnv -> Infer QualType
instantiateVar a n env = do
  varId <- failWith (getVarId n env) $ throwError a ("Unbound variable: '" ++ show n ++ "'")
  scheme <- failWithM (getVarSchemeByVarId varId) $ throwError a ("Assertion failed: missing var scheme for: '" ++ show n ++ "'")
  instantiated <- instantiate scheme
  return $ tracePretty ("Instantiated var '" ++ pretty n ++ "' with scheme: " ++ pretty scheme ++ " to") instantiated

----------------------------------------------------------------------
-- | Generalizes a type to a type scheme, i.e. wraps it in a "forall" that quantifies over all
--   type variables that are free in the given type, but are not free in the type environment.
--
-- Example:
--
-- >>> runInfer $ generalize (ELit "bla" LitUndefined) Map.empty $ qualEmpty $ Fix $ TFunc [Fix $ TBody (TVar 0)] (Fix $ TBody (TVar 1))
-- Right (TScheme {schemeVars = [0,1], schemeType = TQual {qualPred = [], qualType = Fix (TFunc [Fix (TBody (TVar 0))] Fix (TBody (TVar 1)))}})
--
-- >>> :{
-- runInfer $ do
--     let t = TScheme [1] (TQual [] (Fix $ TFunc [Fix $ TBody (TVar 0)] (Fix $ TBody (TVar 1))))
--     tenv <- addVarScheme Map.empty "x" t
--     generalize (ELit "bla" LitUndefined) tenv (qualEmpty $ Fix $ TFunc [Fix $ TBody (TVar 0)] (Fix $ TBody (TVar 2)))
-- :}
-- Right (TScheme {schemeVars = [2], schemeType = TQual {qualPred = [], qualType = Fix (TFunc [Fix (TBody (TVar 0))] Fix (TBody (TVar 2)))}})
--
-- In this example the steps were:
--
-- 1. Environment: { x :: forall 1. 0 -> 1 } (so 0 is free in the environment)
--
-- 2. generalize (0 -> 2)
--
-- 3. result: forall 2. 0 -> 2 (0 is not quantified because it is free in the environment)
--
-- >>> let expr = ELit "foo" LitUndefined
--
-- >>> runInfer $ generalize expr Map.empty (qualEmpty $ Fix $ TFunc [Fix $ TBody (TVar 0)] (Fix $ TBody (TVar 0)))
-- Right (TScheme {schemeVars = [0], schemeType = TQual {qualPred = [], qualType = Fix (TFunc [Fix (TBody (TVar 0))] Fix (TBody (TVar 0)))}})
--
-- >>> runInfer $ generalize expr Map.empty (TQual [TPredIsIn (ClassName "Bla") (Fix $ TBody (TVar 1))] (Fix $ TBody (TVar 0)))
-- Right (TScheme {schemeVars = [0,1], schemeType = TQual {qualPred = [TPredIsIn {predClass = ClassName "Bla", predType = Fix (TBody (TVar 1))}], qualType = Fix (TBody (TVar 0))}})
--
-- TODO add tests for monotypes
unsafeGeneralize :: TypeEnv -> QualType -> Infer TypeScheme
unsafeGeneralize tenv t = do
    traceLog $ "Generalizing: " ++ pretty t
    -- Bring the type up to date with the main substitution before looking at its variables.
    t' <- applyMainSubst t
    envTVars <- getFreeTVars tenv
    -- Quantify over the variables that are free in the type but not in the environment.
    let unboundVars = freeTypeVars t' `Set.difference` envTVars
    traceLog $ "Generalization result: unbound vars = " ++ pretty unboundVars ++ ", type = " ++ pretty t'
    return $ TScheme (Set.toList unboundVars) t'

-- | Whether an expression is expansive. 'generalize' does not quantify over the type of an
-- expansive expression.
-- Applications, assignments, @let@, @new@, and array / tuple / string-map / row constructions are
-- expansive. Variables, abstractions and literals are not. Case, property and index expressions
-- are expansive exactly when one of their sub-expressions is.
isExpansive :: Exp a -> Bool
isExpansive expr =
  case expr of
    EVar {}         -> False
    EAbs {}         -> False
    ELit {}         -> False
    EApp {}         -> True
    EAssign {}      -> True
    EPropAssign {}  -> True
    EIndexAssign {} -> True
    ELet {}         -> True
    EArray {}       -> True
    ETuple {}       -> True
    EStringMap {}   -> True
    ERow {}         -> True
    ENew {}         -> True
    ECase _ ep es   -> any isExpansive (ep : map snd es)
    EProp _ e _     -> isExpansive e
    EIndex _ a b    -> isExpansive a || isExpansive b


-- | Generalizes the type of an expression with 'unsafeGeneralize', unless the expression is
-- expansive ('isExpansive'). In that case the type is wrapped in a scheme with no quantified
-- variables.
generalize :: Exp a -> TypeEnv -> QualType -> Infer TypeScheme
generalize exp' env t
  | isExpansive exp' = return $ TScheme [] t
  | otherwise        = unsafeGeneralize env t

-- | Builds a renaming that maps the free type variables of @xs@, in ascending order, onto
-- 0, 1, 2, ... Names that are not free in @xs@ are returned unchanged.
--
-- The lookup table is bound outside the returned function. It is therefore built once per
-- partial application @minifyVarsFunc xs@ (as used by 'minifyVars'), not once per looked-up name.
minifyVarsFunc :: (VarNames a) => a -> TVarName -> TVarName
minifyVarsFunc xs = \n -> fromMaybe n $ Map.lookup n vars
  where vars = Map.fromList $ zip (Set.toList $ freeTypeVars xs) ([0..] :: [TVarName])

-- | Renames the free type variables of a value to 0, 1, 2, ... (see 'minifyVarsFunc').
minifyVars :: (VarNames a) => a -> a
minifyVars xs = mapVarNames renaming xs
  where renaming = minifyVarsFunc xs

-- | Returns the graph of variable instances recorded so far (see 'addVarInstance').
getVarInstances :: Infer (Graph.Gr QualType ())
getVarInstances = fmap varInstances get

-- | Returns the main substitution stored in the inference state.
getMainSubst :: Infer TSubst
getMainSubst = fmap mainSubst get

-- | Applies the current main substitution to a value.
applyMainSubst :: Substable b => b -> Infer b
applyMainSubst x = (\s -> applySubst s x) <$> getMainSubst


-- | Applies a substitution to a single type variable name. If the substitution maps the variable
-- to a type variable, that variable's name is returned. Otherwise (including when it maps to a
-- non-variable type) the original name is returned.
substVar :: TSubst -> TVarName -> TVarName
substVar subst x =
  case applySubst subst (Fix (TBody (TVar x))) of
    Fix (TBody (TVar zx)) -> zx
    _                     -> x

-- | Looks up a type class by name in the inference state.
lookupClass :: ClassName -> Infer (Maybe (Class Type))
lookupClass cs = fmap (Map.lookup cs . classes) get