{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Note: this module should NOT export externals. It is for common
--   transformations needed by the other primitive modules.
module HERMIT.Dictionary.Common
    ( -- * Utility Transformations
      applyInContextT
      -- ** Finding function calls.
    , callT
    , callPredT
    , callNameT
    , callSaturatedT
    , callNameG
    , callDataConT
    , callDataConNameT
      -- ** Collecting variable bindings
    , progConsIdsT
    , progConsRecIdsT
    , progConsNonRecIdT
    , nonRecVarT
    , recIdsT
    , lamVarT
    , letVarsT
    , letRecIdsT
    , letNonRecVarT
    , caseVarsT
    , caseBinderIdT
    , caseAltVarsT
      -- ** Finding variables bound in the Context
    , boundVarsT
    , findBoundVarT
    , findIdT
    , findVarT
    , findTyConT
    , findTypeT
    , varBindingDepthT
    , varIsOccurrenceOfT
    , exprIsOccurrenceOfT
    , withVarsInScope
      -- ** Miscellaneous
    , wrongExprForm
    )

where

import Data.List (nub)

import Control.Arrow
import Control.Monad.IO.Class

import HERMIT.Context
import HERMIT.Core
import HERMIT.GHC
import HERMIT.Kure
import HERMIT.Monad
import HERMIT.Name

import Prelude.Compat

------------------------------------------------------------------------------

-- | Run a transformation on a supplied value, using the context of the
--   current location. The value currently in focus is ignored.
applyInContextT :: Transform c m a b -> a -> Transform c m x b
applyInContextT t a = contextonlyT (flip (applyT t) a)

-- Note: this is the same as: return a >>> t

------------------------------------------------------------------------------

-- | Lift GHC's 'collectArgs': split the current expression into its head and
--   the arguments it is applied to.
--
--   A bare variable yields itself with no arguments; an application is
--   decomposed by 'collectArgs'. Any other expression form fails.
callT :: Monad m => Transform c m CoreExpr (CoreExpr, [CoreExpr])
callT = contextfreeT $ \ e -> case e of
    Var {} -> return (e, [])
    App {} -> return (collectArgs e)
    _      -> fail "not an application or variable occurrence."

-- | Succeeds if we are looking at a call (see 'callT') whose head is a
--   variable and which satisfies the given predicate on that variable and
--   its arguments. Returns the head together with the arguments.
callPredT :: Monad m => (Id -> [CoreExpr] -> Bool) -> Transform c m CoreExpr (CoreExpr, [CoreExpr])
callPredT p = do
    call <- callT
    case call of
        (Var i, args) -> do
            guardMsg (p i args) "predicate failed."
            return call
        -- 'collectArgs' may return a non-variable head (e.g. a lambda applied
        -- to arguments). Report that explicitly instead of relying on an
        -- implicit do-pattern match failure.
        _ -> fail "head of call is not a variable."

-- | Succeeds if we are looking at an application of the given function,
--   returning the head together with the zero or more arguments it is
--   applied to.
--
-- Note: comparison is performed with 'cmpHN2Var'.
callNameT :: MonadCatch m => HermitName -> Transform c m CoreExpr (CoreExpr, [CoreExpr])
callNameT nm = prefixFailMsg ("callNameT failed: not a call to '" ++ show nm ++ "'. ")
             $ callPredT (const . cmpHN2Var nm)

-- | Succeeds if we are looking at a call whose head variable is applied to
--   exactly as many arguments as its type has leading type binders plus
--   function arrows (i.e. a fully saturated call).
callSaturatedT :: Monad m => Transform c m CoreExpr (CoreExpr, [CoreExpr])
callSaturatedT = callPredT isSaturated
  where
    isSaturated i args = length args == arity (varType i)
    arity ty = let (tyVars, body) = splitForAllTys ty
                   (argTys, _res) = splitFunTys body
               in length tyVars + length argTys

-- | Succeeds, discarding the result, if we are looking at an application of
--   the given function (see 'callNameT').
callNameG :: MonadCatch m => HermitName -> Transform c m CoreExpr ()
callNameG nm = prefixFailMsg "callNameG failed: " $ do
    _ <- callNameT nm
    return ()

-- | Succeeds if we are looking at an application of a data constructor,
--   returning the constructor, its type arguments and its value arguments
--   as reported by GHC's 'exprIsConApp_maybe'.
callDataConT :: MonadCatch m => Transform c m CoreExpr (DataCon, [Type], [CoreExpr])
callDataConT = prefixFailMsg "callDataConT failed: " $ do
    -- The in-scope set handed to exprIsConApp_maybe is built from the
    -- expression's local free variables; unfoldings come from idUnfolding.
    mb <- contextfreeT $ \ e ->
        let in_scope = mkInScopeSet (mkVarEnv [ (v,v) | v <- varSetElems (localFreeVarsExpr e) ])
        in return $ exprIsConApp_maybe (in_scope, idUnfolding) e
    maybe (fail "not a datacon application.") return mb

-- | Like 'callDataConT', but additionally requires the constructor's name to
--   match the given string (compared with 'cmpString2Name').
callDataConNameT :: MonadCatch m => String -> Transform c m CoreExpr (DataCon, [Type], [CoreExpr])
callDataConNameT nm = callDataConT >>= \ found@(dc, _, _) ->
    guardMsg (cmpString2Name nm (dataConName dc)) "wrong datacon." >> return found

------------------------------------------------------------------------------

-- | List the identifiers bound by the top-level binding group at the head of
--   the program (using 'bindVars'); the rest of the program is ignored.
progConsIdsT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, MonadCatch m) => Transform c m CoreProg [Id]
progConsIdsT = progConsT (arr bindVars) mempty keepIds
  where
    keepIds ids () = ids

-- | List the identifiers bound by a recursive top-level binding group at the
--   head of the program (via 'recIdsT'); the rest of the program is ignored.
progConsRecIdsT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreProg [Id]
progConsRecIdsT = progConsT recIdsT mempty keepIds
  where
    keepIds ids () = ids

-- | Return the identifier bound by a non-recursive top-level binding at the
--   head of the program (via 'nonRecVarT'); the rest of the program is ignored.
progConsNonRecIdT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreProg Id
progConsNonRecIdT = progConsT nonRecVarT mempty keepId
  where
    keepId i () = i

-- | Return the variable bound by a non-recursive binding group ('CoreBind').
--   The result does not depend on the binding's right-hand side.
nonRecVarT :: (ExtendPath c Crumb, Monad m) => Transform c m CoreBind Var
nonRecVarT = nonRecT idR mempty (\ v () -> v)

-- | List all identifiers bound in a recursive binding group, one per
--   definition (each obtained with 'defId').
recIdsT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreBind [Id]
recIdsT = recT (const (arr defId)) id

-- | Return the variable bound by a lambda expression; the result does not
--   depend on the lambda's body.
lamVarT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreExpr Var
lamVarT = lamT idR mempty binderOnly
  where
    binderOnly v () = v

-- | List the variables bound by a let expression (recursive or not), using
--   'bindVars' on its binding group; the let body is ignored.
letVarsT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, MonadCatch m) => Transform c m CoreExpr [Var]
letVarsT = letT (arr bindVars) mempty bindersOnly
  where
    bindersOnly vs () = vs

-- | List the identifiers bound by a recursive let expression (via
--   'recIdsT'); the let body is ignored.
letRecIdsT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreExpr [Id]
letRecIdsT = letT recIdsT mempty bindersOnly
  where
    bindersOnly ids () = ids

-- | Return the variable bound by a non-recursive let expression (via
--   'nonRecVarT'); the let body is ignored.
letNonRecVarT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreExpr Var
letNonRecVarT = letT nonRecVarT mempty binderOnly
  where
    binderOnly v () = v

-- | List all variables bound by a case expression: the case binder first,
--   followed by the variables bound in the alternatives (duplicates removed).
caseVarsT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreExpr [Var]
caseVarsT = caseT mempty idR mempty (const (arr altVars)) collect
  where
    collect () binder () altVss = binder : nub (concat altVss)

-- | Return the binder of a case expression.
caseBinderIdT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreExpr Id
caseBinderIdT = caseT mempty idR mempty (const idR) pickBinder
  where
    pickBinder () binder () _alts = binder

-- | List, per alternative, the variables bound by each alternative of a case
--   expression (via 'altVars').
caseAltVarsT :: (ExtendPath c Crumb, ReadPath c Crumb, AddBindings c, Monad m) => Transform c m CoreExpr [[Var]]
caseAltVarsT = caseT mempty mempty mempty (const (arr altVars)) onlyAlts
  where
    onlyAlts () () () altVss = altVss

------------------------------------------------------------------------------

-- | Look up the binding depth of the given variable in the current context
--   (via 'lookupHermitBindingDepth').
varBindingDepthT :: (ReadBindings c, Monad m) => Var -> Transform c m g BindingDepth
varBindingDepthT v = do
    c <- contextT
    lookupHermitBindingDepth v c

-- | Determine whether the current variable is the given variable AND that
--   variable is bound at the given depth (useful for detecting shadowing).
--   The depth is only looked up when the variables are equal.
varIsOccurrenceOfT :: (ExtendPath c Crumb, ReadBindings c, Monad m) => Var -> BindingDepth -> Transform c m Var Bool
varIsOccurrenceOfT v d = readerT $ \ current ->
    if v /= current
        then return False
        else varBindingDepthT v >>^ (== d)

-- | Expression-level version of 'varIsOccurrenceOfT': the current expression
--   must be a variable (via 'varT') that is the given variable bound at the
--   given depth (useful for detecting shadowing).
exprIsOccurrenceOfT :: (ExtendPath c Crumb, ReadBindings c, Monad m) => Var -> BindingDepth -> Transform c m CoreExpr Bool
exprIsOccurrenceOfT v = varT . varIsOccurrenceOfT v

-- | Transformation version of 'boundVars': return the set of variables bound
--   in the current context, ignoring the current value.
boundVarsT :: (BoundVars c, Monad m) => Transform c m a VarSet
boundVarsT = contextonlyT $ \ c -> return (boundVars c)

-- | Find the unique variable bound in the context that satisfies the given
--   predicate (via 'findBoundVars').
--   Fails if no bound variable matches, or if more than one does.
findBoundVarT :: (BoundVars c, MonadCatch m) => (Var -> Bool) -> Transform c m a Var
findBoundVarT p = do
    c <- contextT
    case varSetElems (findBoundVars p c) of
        []  -> fail "no matching variables in scope."
        [v] -> return v
        _   -> fail "multiple matching variables in scope."

--------------------------------------------------------------------------------------------------

-- | Resolve a name to an 'Id' (via 'findId'): look in the context first,
--   then, failing that, in GHC's global reader environment.
findIdT :: (BoundVars c, HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m)
        => HermitName -> Transform c m a Id
findIdT nm = prefixFailMsg failPrefix (contextonlyT (findId nm))
  where
    failPrefix = "Cannot resolve name " ++ show nm ++ ", "

-- | Resolve a name to a 'Var' (via 'findVar'): look in the context first,
--   then, failing that, in GHC's global reader environment.
findVarT :: (BoundVars c, HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m)
         => HermitName -> Transform c m a Var
findVarT nm = prefixFailMsg failPrefix (contextonlyT (findVar nm))
  where
    failPrefix = "Cannot resolve name " ++ show nm ++ ", "

-- | Resolve a name to a 'TyCon' (via 'findTyCon'): look in the context first,
--   then, failing that, in GHC's global reader environment.
findTyConT :: (BoundVars c, HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m)
           => HermitName -> Transform c m a TyCon
findTyConT nm = prefixFailMsg failPrefix (contextonlyT (findTyCon nm))
  where
    failPrefix = "Cannot resolve name " ++ show nm ++ ", "

-- | Resolve a name to a 'Type' (via 'findType'): look in the context first,
--   then, failing that, in GHC's global reader environment.
findTypeT :: (BoundVars c, HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m)
          => HermitName -> Transform c m a Type
findTypeT nm = prefixFailMsg failPrefix (contextonlyT (findType nm))
  where
    failPrefix = "Cannot resolve name " ++ show nm ++ ", "

-- | Modify a transformation so it runs on the current value as if that value
--   were the body of lambdas binding the given variables: each variable is
--   added to the context with 'addLambdaBinding' before the transformation runs.
withVarsInScope :: (AddBindings c, ReadPath c Crumb) => [Var] -> Transform c m a b -> Transform c m a b
withVarsInScope vs t = transform $ \ c ->
    -- Careful: variables are added left-to-right (first element first).
    applyT t (foldl (\ ctx v -> addLambdaBinding v ctx) c vs)

------------------------------------------------------------------------------

-- | Build the shared "wrong expression form" error message.
--   The argument describes the form the expression was expected to have.
wrongExprForm :: String -> String
wrongExprForm = ("Expression does not have the form: " ++)

------------------------------------------------------------------------------