module Language.SequentCore.Lint ( lintCoreBindings, lintTerm ) where

import Language.SequentCore.Syntax
import Language.SequentCore.WiredIn

import Coercion     ( coercionKind, coercionType )
import DataCon
import Id
import Kind
import Literal
import Outputable
import Pair
import Type
import VarEnv

import Control.Monad
import Data.List    ( mapAccumL )

-- | The linting monad: a 'Left' carries the report of the first check that
-- failed, a 'Right' carries the result of a successful check.
type LintM = Either SDoc
-- | The environment threaded through the linter. It is a type substitution;
-- the linter also uses its in-scope set to look up the binders it has
-- already brought into scope.
type LintEnv = TvSubst

-- | Project out the 'Left' payload of an 'Either', discarding any 'Right'
-- value. Used to turn a lint result into an optional error report.
eitherToMaybe :: Either a b -> Maybe a
eitherToMaybe = either Just (const Nothing)

-- | Lint a list of bindings in order, starting from the empty substitution
-- and threading the environment produced by each binding into the next.
-- Returns 'Just' the error report of the first failure, or 'Nothing' if
-- every binding passes.
lintCoreBindings :: [SeqCoreBind] -> Maybe SDoc
lintCoreBindings = eitherToMaybe . foldM lintCoreBind emptyTvSubst

-- | Lint a single term in the given environment. Returns 'Just' the error
-- report on failure and 'Nothing' on success; the computed type is dropped.
lintTerm :: TvSubst -> SeqCoreTerm -> Maybe SDoc
lintTerm env = eitherToMaybe . lintCoreTerm env

-- | Lint one binding group and return the environment in which whatever
-- follows the binding should be checked (the incoming environment with the
-- group's binders, at their substituted types, added to the in-scope set).
--
-- * 'NonRec': the right-hand side is checked in the /incoming/ environment,
--   because a non-recursive binder is not in scope in its own right-hand
--   side. A continuation right-hand side ('Cont') is checked against the
--   type extracted from the binder by 'contIdTyOrError'; any other
--   right-hand side must have exactly the binder's (substituted) type.
-- * 'Rec': every binder of the group is in scope in every right-hand side.
--   Right-hand sides are linted with 'lintCoreTerm', so a 'Cont' right-hand
--   side in a recursive group is reported as an error.
lintCoreBind :: LintEnv -> SeqCoreBind -> LintM LintEnv
lintCoreBind env (NonRec bndr rhs)
  = do
    let bndrTy = substTy env (idType bndr)
        bndr'  = bndr `setIdType` bndrTy
    case rhs of
      Cont cont -> do
                   contTy <- contIdTyOrError env bndr
                   lintCoreCont (text "in RHS for cont id" <+> ppr bndr)
                                env contTy cont
      _         -> do
                   rhsTy <- lintCoreTerm env rhs
                   checkRhsType bndr bndrTy rhsTy
    -- Only the code after the binding sees the new binder.
    return $ extendTvInScope env bndr'
lintCoreBind env (Rec pairs)
  = do
    let bndrs   = map fst pairs
        bndrTys = map (substTy env . idType) bndrs
        bndrs'  = zipWith setIdType bndrs bndrTys
        -- Recursive binders scope over all right-hand sides of the group.
        env'    = extendTvInScopeList env bndrs'
    rhsTys <- mapM (lintCoreTerm env' . snd) pairs
    forM_ (zip3 bndrs bndrTys rhsTys) $ \(bndr, bndrTy, rhsTy) ->
      checkRhsType bndr bndrTy rhsTy
    return env'

-- | Lint a term and return its type, with the environment's type
-- substitution applied to types taken from local binders.
--
-- Fails with a 'Left' report on the first problem found. A bare continuation
-- ('Cont') is never accepted as a term here; continuation right-hand sides
-- are handled separately by 'lintCoreBind'.
lintCoreTerm :: LintEnv -> SeqCoreTerm -> LintM Type
lintCoreTerm env (Var x)
  -- Non-local identifiers are not looked up in the environment; their own
  -- type is returned as is.
  | not (isLocalId x)
  = return (idType x)
  -- A local occurrence must agree with the type recorded for the binder that
  -- is currently in scope.
  | Just x' <- lookupInScope (getTvInScope env) x
  = if substTy env (idType x) `eqType` idType x'
      then return $ idType x'
      else Left $ text "variable" <+> pprBndr LetBind x <+> text "bound as"
                                  <+> pprBndr LetBind x'
  | otherwise
  = Left $ text "not found in context:" <+> pprBndr LetBind x

lintCoreTerm env (Lam xs k comm)
  = do
    -- Bring the value/type binders into scope left to right, then the
    -- continuation binder; the body sees all of them.
    let (env', xs') = mapAccumL lintBind env xs
        env''       = fst (lintBind env' k)
    lintCoreCommand env'' comm
    -- 'contIdTyOrError' applies the substitution itself, so it is given the
    -- original binder; passing the already-substituted copy would apply the
    -- substitution to the continuation type twice.
    retTy <- contIdTyOrError env' k
    return $ mkPiTypes xs' retTy
  where
    -- Substitute into one binder and add it to the environment.
    lintBind acc x
      | isTyVar x
      = substTyVarBndr acc x
      | otherwise
      = (acc', x')
      where
        x'   = substTyInId acc x
        acc' = extendTvInScope acc x'

lintCoreTerm env (Cons dc args)
  = do
    let (tyVars, monoTy)  = splitForAllTys $ dataConRepType dc
        (argTys, resTy)   = splitFunTys monoTy
        (tyArgs, valArgs) = partitionTypes args
    unless (length valArgs == dataConRepArity dc) $
      Left (text "wrong number of args for" <+> ppr dc $$ ppr args)
    unless (length tyVars == length tyArgs) $
      Left (text "wrong number of type args for" <+> ppr dc $$ ppr args)
    -- Check one type argument against the kind of the constructor's type
    -- variable, then map that variable to the argument in the substitution.
    let augment env' (tyVar, ty)
          = do
            let tyVarTy = substTy env' (idType tyVar)
                kind    = substTy env' (typeKind ty)
            unless (tyVarTy `eqType` kind) $
              mkError (text "kind of arg" <+> ppr ty <+> text "for" <+> ppr tyVar)
                (ppr tyVarTy) (ppr kind)
            let tyVar' = tyVar `setIdType` tyVarTy
                ty'    = substTy env' ty
            return $ extendTvSubst env' tyVar' ty' `extendTvInScope` tyVar'
    env' <- foldM augment env (zip tyVars tyArgs)
    -- Each value argument must have the instantiated field type.
    let doArg argTy arg
          = do
            let argTy' = substTy env' argTy
            checkingType (ppr arg) argTy' $ lintCoreTerm env' arg
    zipWithM_ doArg argTys valArgs
    return $ substTy env' resTy

lintCoreTerm env (Compute bndr comm)
  = do
    -- The term's type is the type extracted from the continuation binder.
    ty <- contIdTyOrError env bndr
    lintCoreCommand env' comm
    return ty
  where
    env' = extendTvInScopeSubsted env bndr

lintCoreTerm _env (Lit lit)
  = return $ literalType lit

-- A type used as a term is assigned the kind of the substituted type.
lintCoreTerm env (Type ty)
  = return $ typeKind (substTy env ty)

lintCoreTerm env (Coercion co)
  = return $ substTy env (coercionType co)

lintCoreTerm _env (Cont cont)
  = Left $ text "unexpected continuation as term:" <+> ppr cont

-- | Lint a command: check its let-bindings in order, threading the
-- environment, and then check the cut of its term against its continuation
-- in the resulting environment.
lintCoreCommand :: LintEnv -> SeqCoreCommand -> LintM ()
lintCoreCommand env (Command { cmdLet = binds, cmdTerm = term, cmdCont = cont })
  = foldM lintCoreBind env binds >>= \bodyEnv ->
      lintCoreCut bodyEnv term cont

-- | Lint a term paired with a continuation: compute the term's type, then
-- check that the continuation accepts a value of that type.
lintCoreCut :: LintEnv -> SeqCoreTerm -> SeqCoreCont -> LintM ()
lintCoreCut env term cont
  = lintCoreTerm env term >>= \termTy ->
      lintCoreCont context env termTy cont
  where
    -- Prefix attached to errors raised while checking the continuation.
    context = text "in continuation of" <+> ppr term

-- | Lint a continuation against the type of the value flowing into it.
--
-- Arguments: a description used as a prefix for error reports, the current
-- environment, the type of the incoming value, and the continuation itself.
-- Succeeds with @()@ or fails with a 'Left' report on the first problem.
lintCoreCont :: SDoc -> LintEnv -> Type -> SeqCoreCont -> LintM ()
lintCoreCont desc env ty (Return k)
  -- The continuation variable must be in scope at a consistent type, and the
  -- type extracted from it by 'contIdTyOrError' must equal the incoming type.
  | Just k' <- lookupInScope (getTvInScope env) k
  = if substTy env (idType k) `eqType` idType k'
      then void $ checkingType (desc <> colon <+> text "cont variable" <+> ppr k) ty $ contIdTyOrError env k
      else Left $ desc <> colon <+> text "cont variable" <+> pprBndr LetBind k <+> text "bound as"
                                                         <+> pprBndr LetBind k'
  | otherwise
  = Left $ text "not found in context:" <+> pprBndr LetBind k
lintCoreCont desc env ty (App (Type tyArg) cont)
  -- NOTE(review): the callers visible in this file pass a 'ty' that has
  -- already been substituted, so 'substTy env ty' here (and in the value
  -- application case below) may re-apply the substitution -- confirm that
  -- this is intended.
  | Just (tyVar, resTy) <- splitForAllTy_maybe (substTy env ty)
  = do
    let tyArg' = substTy env tyArg
    if typeKind tyArg' `isSubKind` idType tyVar
      then do
           let env' = extendTvSubst env tyVar tyArg'
               -- Don't reapply the rest of the substitution; just apply the new thing
               resTy' = substTy (extendTvSubst emptyTvSubst tyVar tyArg') resTy
           lintCoreCont desc env' resTy' cont
      -- The quantified variable dictates the expected kind; the kind of the
      -- supplied type argument is what was actually found.
      else mkError (desc <> colon <+> text "type argument" <+> ppr tyArg)
             (ppr (idType tyVar)) (ppr (typeKind tyArg'))
  | otherwise
  = Left $ desc <> colon <+> text "not a forall type:" <+> ppr ty
lintCoreCont desc env ty (App arg cont)
  -- Value application: the incoming value must be a function whose argument
  -- type matches the argument; the rest is checked at the result type.
  | Just (argTy, resTy) <- splitFunTy_maybe (substTy env ty)
  = do
    void $ checkingType (desc <> colon <+> ppr arg) argTy $ lintCoreTerm env arg
    lintCoreCont desc env resTy cont
  | otherwise
  = Left $ desc <> colon <+> text "not a function type:" <+> ppr ty
lintCoreCont desc env ty (Cast co cont)
  = do
    -- The incoming type must equal the coercion's source type; the rest of
    -- the continuation is checked at the coercion's target type.
    let Pair fromTy toTy = coercionKind co
        fromTy' = substTy env fromTy
        toTy'   = substTy env toTy
    void $ checkingType (desc <> colon <+> text "incoming type of" <+> ppr co) ty $ return fromTy'
    lintCoreCont desc env toTy' cont
lintCoreCont desc env ty (Tick _ cont)
  -- Ticks do not affect typing.
  = lintCoreCont desc env ty cont
lintCoreCont desc env ty (Case bndr alts)
  = do
    -- Each alternative body sees the case binder and the alternative's own
    -- binders; the alternatives are linted before the case binder's type is
    -- compared with the incoming type.
    let env' = extendTvInScopeSubsted env bndr
    forM_ alts $ \(Alt _ bndrs rhs) ->
      lintCoreCommand (extendTvInScopeListSubsted env' bndrs) rhs
    void $ checkingType (desc <> colon <+> text "type of case binder") ty $
      return $ substTy env (idType bndr)

-- | Apply the substitution to a variable's type and add the resulting
-- variable to the substitution's in-scope set.
extendTvInScopeSubsted :: TvSubst -> Var -> TvSubst
extendTvInScopeSubsted tvs = extendTvInScope tvs . substTyInId tvs

-- | Return the variable with the substitution applied to its type.
substTyInId :: TvSubst -> Var -> Var
substTyInId tvs var = setIdType var substedTy
  where
    substedTy = substTy tvs (idType var)

-- | Add every variable of the list to the in-scope set, applying the
-- substitution to each variable's type first. The list is folded from the
-- right, so the last variable is added first.
extendTvInScopeListSubsted :: TvSubst -> [Var] -> TvSubst
extendTvInScopeListSubsted tvs vars
  = foldr (\var acc -> extendTvInScopeSubsted acc var) tvs vars

-- | Fail with a three-part report: the description, then an @expected:@
-- line, then an @actual:@ line. Always returns 'Left'.
mkError :: SDoc -> SDoc -> SDoc -> LintM ()
mkError desc ex act = Left report
  where
    expectedLine = text "expected:" <+> ex
    actualLine   = text "actual:" <+> act
    report       = desc $$ expectedLine $$ actualLine
  
-- | Check that a right-hand side has exactly the type of its binder;
-- otherwise fail, reporting the binder's type as expected and the
-- right-hand side's type as actual.
checkRhsType :: Var -> Type -> Type -> LintM ()
checkRhsType bndr bndrTy rhsTy
  | bndrTy `eqType` rhsTy = return ()
  | otherwise
  = mkError (text "type of RHS of" <+> ppr bndr) (ppr bndrTy) (ppr rhsTy)

-- | Run a type-producing lint action and require its result to equal the
-- expected type. On success the computed type is returned; on mismatch the
-- given description is reported together with both types.
checkingType :: SDoc -> Type -> LintM Type -> LintM Type
checkingType desc ex go
  = go >>= \act ->
      if ex `eqType` act
        then return act
        -- 'mkError' always fails, so the trailing 'return' is never reached.
        else mkError desc (ppr ex) (ppr act) >> return act

-- | Apply the substitution to a continuation binder's type and return
-- whatever 'isContTy_maybe' extracts from it (presumably the type of value
-- the continuation accepts -- confirm against "Language.SequentCore.WiredIn").
-- Fails with a @bad cont type@ report if the binder's type is not
-- recognised as a continuation type.
contIdTyOrError :: LintEnv -> ContId -> LintM Type
contIdTyOrError env k
  = maybe badContTy return (isContTy_maybe kTy)
  where
    kTy       = substTy env (idType k)
    badContTy = Left (text "bad cont type:" <+> pprBndr LetBind k)