-- | Common translations shared by the other primitive modules: collecting the
--   variables bound at a node, finding variables bound in the context, and
--   generating error messages.
--
--   Note: this module should NOT export externals; it only provides helpers
--   for the other primitive modules.
module Language.HERMIT.Primitive.Common
    ( -- * Utility Transformations
      -- ** Collecting variables bound at a Node
      progVarsT
    , bindVarsT
    , nonRecVarT
    , recVarsT
    , defVarT
    , lamVarT
    , letVarsT
    , letRecVarsT
    , letNonRecVarT
    , caseVarsT
    , caseWildVarT
    , caseAltVarsT
    , altVarsT
      -- ** Finding variables bound in the Context
    , boundVarsT
    , findBoundVarT
    , findIdT
      -- ** Error Message Generators
    , wrongExprForm
    )

where

import GhcPlugins

import Data.List
import Data.Monoid

import Language.HERMIT.Kure
import Language.HERMIT.Core
import Language.HERMIT.Context
import Language.HERMIT.GHC

import qualified Language.Haskell.TH as TH

------------------------------------------------------------------------------

-- | Collect every identifier bound at the top level of a program by
--   concatenating the identifiers of each binding group.
progVarsT :: TranslateH CoreProg [Id]
progVarsT = atNil <+ atCons
  where
    atNil  = progNilT []
    atCons = progConsT bindVarsT progVarsT (++)

-- | Collect the identifiers bound by a binding group.
--   A non-recursive binding yields a singleton list; otherwise every
--   identifier of the recursive group is returned.
bindVarsT :: TranslateH CoreBind [Var]
bindVarsT = nonRecCase <+ recVarsT
  where
    nonRecCase = fmap (: []) nonRecVarT

-- | Return the variable bound by a non-recursive binding group.
--   The right-hand side is visited with 'mempty', so its result is discarded.
--   Presumably fails on a recursive group; 'bindVarsT' relies on that to fall
--   back to 'recVarsT' (TODO confirm against 'nonRecT').
nonRecVarT :: TranslateH CoreBind Var
nonRecVarT = nonRecT mempty (\ v () -> v)

-- | Collect the identifiers bound by a recursive binding group,
--   one per definition, in definition order.
recVarsT :: TranslateH CoreBind [Id]
recVarsT = recT (const defVarT) id

-- | Return the identifier introduced by one definition of a recursive group,
--   discarding the (mempty) result for its right-hand side.
defVarT :: TranslateH CoreDef Id
defVarT = defT mempty binder
  where
    binder i () = i

-- | Return the binder of a lambda expression; the body is visited with
--   'mempty' and its result discarded.
lamVarT :: TranslateH CoreExpr Var
lamVarT = lamT mempty binder
  where
    binder v () = v

-- | Collect the variables bound by a let expression, whether its binding
--   group is recursive or not; the body's result is discarded.
letVarsT :: TranslateH CoreExpr [Var]
letVarsT = letT bindVarsT mempty boundOnly
  where
    boundOnly vs () = vs

-- | Collect the variables bound by a let expression whose binding group is
--   recursive (via 'recVarsT'); the body's result is discarded.
letRecVarsT :: TranslateH CoreExpr [Var]
letRecVarsT = letT recVarsT mempty boundOnly
  where
    boundOnly vs () = vs

-- | Return the single variable bound by a let expression whose binding is
--   non-recursive (via 'nonRecVarT'); the body's result is discarded.
letNonRecVarT :: TranslateH CoreExpr Var
letNonRecVarT = letT nonRecVarT mempty boundOnly
  where
    boundOnly v () = v

-- | Collect every variable bound by a case expression: the wildcard binder
--   first, followed by the de-duplicated variables bound by the alternatives.
caseVarsT :: TranslateH CoreExpr [Var]
caseVarsT = caseT mempty (const altVarsT) combine
  where
    combine () wild _ perAlt = wild : nub (concat perAlt)

-- | Return the wildcard binder of a case expression; the scrutinee and the
--   alternatives are visited but their results are ignored.
caseWildVarT :: TranslateH CoreExpr Var
caseWildVarT = caseT mempty (const (return ())) wildOnly
  where
    wildOnly () wild _ _ = wild

-- | Collect, per alternative, the variables bound by each alternative of a
--   case expression (one inner list per alternative).
caseAltVarsT :: TranslateH CoreExpr [[Var]]
caseAltVarsT = caseT mempty (const altVarsT) altsOnly
  where
    altsOnly () _ _ perAlt = perAlt

-- | Return the binders of a single case alternative; the right-hand side is
--   visited with 'mempty' and its result discarded.
altVarsT :: TranslateH CoreAlt [Var]
altVarsT = altT mempty binders
  where
    binders _ vs () = vs

------------------------------------------------------------------------------

-- Name resolution against the context.  'findBoundVarT' and 'findIdT' share
-- 'resolveBoundNameT' and differ only in what happens when nothing in the
-- context matches.  A richer error type than 'String' messages would allow
-- further sharing of the failure handling.

-- | Lifted version of 'boundVars'.
boundVarsT :: TranslateH a [Var]
boundVarsT = contextonlyT (return . boundVars)

-- | Find the unique variable bound in the context that matches the given name, failing if it is not unique.
findBoundVarT :: TH.Name -> TranslateH a Var
findBoundVarT nm = resolveBoundNameT nm (fail "no matching variables in scope.")

-- | Look up the name in the 'HermitC' first, then, failing that, in GHC's global reader environment.
findIdT :: TH.Name -> TranslateH a Id
findIdT nm = resolveBoundNameT nm (findIdMG nm)

-- | Shared core of 'findBoundVarT' and 'findIdT'.
--   Looks the name up among the variables bound in the context:
--   a single match is returned, several matches fail, and no match runs the
--   supplied fallback.  Every failure, including one raised by the fallback,
--   is prefixed with "Cannot resolve name <name>, ".
resolveBoundNameT :: TH.Name -> TranslateH a Var -> TranslateH a Var
resolveBoundNameT nm noneBound =
    prefixFailMsg ("Cannot resolve name " ++ TH.nameBase nm ++ ", ") $
      do c <- contextT
         case findBoundVars nm c of
           []        -> noneBound
           [v]       -> return v
           _ : _ : _ -> fail "multiple matching variables in scope."

-- | Look the name up in GHC's global reader environment ('mg_rdr_env' of the
--   current module guts), keeping only value-level names ('isValName').
--   Fails if there is no match, or if there are several (listing them).
findIdMG :: TH.Name -> TranslateH a Id
findIdMG nm = contextonlyT $ \ c ->
    case filter isValName $ findNameFromTH (mg_rdr_env $ hermitModGuts c) nm of
      []  -> fail "variable not in scope."
      [n] -> lookupId n
      ns  -> do dynFlags <- getDynFlags
                fail $ "multiple matches found:\n" ++ intercalate ", " (map (showPpr dynFlags) ns)

------------------------------------------------------------------------------

-- | Build the standard error message for an expression of the wrong shape.
--   The argument describes the form the expression was expected to have.
wrongExprForm :: String -> String
wrongExprForm = ("Expression does not have the form: " ++)

------------------------------------------------------------------------------