module TypeLet.Plugin (plugin) where

import Prelude hiding (cycle)

import Data.Traversable (forM)

import GHC.Plugins (Plugin(..), defaultPlugin, purePlugin)

import TypeLet.Plugin.Constraints
import TypeLet.Plugin.GhcTcPluginAPI
import TypeLet.Plugin.NameResolution
import TypeLet.Plugin.Substitution

{-------------------------------------------------------------------------------
  Top-level plumbing
-------------------------------------------------------------------------------}

plugin :: Plugin
-- | The typelet typechecker plugin.
--
-- * Recompilation: marked 'purePlugin', so the plugin itself does not force
--   recompilation.
-- * Command-line options passed to the plugin are ignored.
-- * Solving: names are resolved once during init ('resolveNames'), and the
--   result is threaded to 'solve'.
-- * No type-family rewriters are installed, and shutdown does nothing.
plugin = defaultPlugin
    { pluginRecompile = purePlugin
    , tcPlugin        = const (Just (mkTcPlugin typeLetTcPlugin))
    }
  where
    -- The plugin description in terms of the typechecker plugin API.
    typeLetTcPlugin = TcPlugin
      { tcPluginInit    = resolveNames
      , tcPluginSolve   = solve
      , tcPluginRewrite = const emptyUFM
      , tcPluginStop    = const (pure ())
      }

{-------------------------------------------------------------------------------
  Constraint resolution

  General approach: regard @Let@ constraints as defining a substitution, and
  then resolve @Equal@ constraints by /applying/ that substitution and
  simplifying to a new wanted (nominal) equality constraint. The original
  @Equal@ constraint is solved with evidence that we construct ourselves, so
  the evidence for that new wanted constraint is never actually used.
-------------------------------------------------------------------------------}

-- | Main interface to constraint resolution
--
-- When there are no wanted constraints, GHC is asking us to simplify the
-- givens; otherwise we try to solve the wanteds (in the context of the
-- givens).
--
-- NOTE: Only givens and wanteds are inspected; derived constraints are
-- ignored entirely.
solve :: ResolvedNames -> TcPluginSolver
solve names givens wanteds =
    case wanteds of
      [] -> simplifyGivens  names givens
      _  -> simplifyWanteds names givens wanteds

-- | Simplify givens
--
-- We (currently?) never simplify any givens. We therefore always report
-- success with two empty lists: no constraints were solved (removed) and
-- no new constraints were added.
simplifyGivens ::
     ResolvedNames  -- ^ Result of name resolution (during init)
  -> [Ct]           -- ^ Given constraints
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyGivens _ _ = pure nothingChanged
  where
    nothingChanged :: TcPluginSolveResult
    nothingChanged = TcPluginOk [] []

-- | Simplify wanteds
--
-- This function provides the key functionality of the plugin.
--
-- The steps are:
--
-- 1. Parse the @Let@ constraints among the givens.
-- 2. Turn those constraints into a substitution.
-- 3. Solve every wanted 'Equal' constraint. We apply the substitution to both
--    of its sides, and emit a new wanted /nominal/ equality between the
--    results.
--
-- We use nominal (rather than representational) equality because we want
-- 'cast' to resolve @Let@ bindings, but not additionally work as 'coerce'.
--
-- If a given @Let@ is invalid, or the @Let@s form a cycle, the problem is
-- reported as a contradiction instead.
simplifyWanteds ::
     ResolvedNames  -- ^ Result of name resolution (during init)
  -> [Ct]           -- ^ Given constraints
  -> [Ct]           -- ^ Wanted constraints
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyWanteds rn given wanted =
    case parseAll (parseLet rn) given of
      Left invalid ->
        reportError (formatInvalidLet <$> invalid)
      Right lets ->
        case letsToSubst lets of
          Left letCycle ->
            reportError (formatLetCycle letCycle)
          Right subst -> do
            let equals = parseAll' (withOrig (parseEqual rn)) wanted
            (solved, new) <- unzip <$> forM equals (uncurry (solveEqual subst))
            pure (TcPluginOk solved new)
  where
    -- Create a new wanted constraint at the given location. The location is
    -- also set in the monad itself; this works around a bug in ghc, making
    -- sure the location is set correctly.
    newWantedAt :: CtLoc -> PredType -> TcPluginM 'Solve CtEvidence
    newWantedAt loc predTy = setCtLocM loc (newWanted loc predTy)

    -- Report an error as a contradiction. The message is turned into a type,
    -- and a new wanted constraint of that type (created at the error's
    -- location) is returned as the insoluble constraint.
    reportError ::
         GenLocated CtLoc TcPluginErrorMessage
      -> TcPluginM 'Solve TcPluginSolveResult
    reportError (L loc msg) = do
        errTy <- mkTcPluginErrorTy msg
        ev    <- newWantedAt loc errTy
        pure (TcPluginContradiction [mkNonCanonical ev])

    -- Solve an Equal constraint. The original constraint is paired with
    -- evidence from 'evidenceEqual'; in its place we emit a new wanted nominal
    -- equality between both sides after applying the substitution.
    solveEqual ::
         TCvSubst
      -> Ct                       -- Original Equal constraint
      -> GenLocated CtLoc CEqual  -- Parsed Equal constraint
      -> TcPluginM 'Solve ((EvTerm, Ct), Ct)
    solveEqual subst orig (L loc parsed) = do
        let lhs = substTy subst (equalLHS parsed)
            rhs = substTy subst (equalRHS parsed)
        ev <- newWantedAt loc (mkPrimEqPredRole Nominal lhs rhs)
        pure ((evidenceEqual rn parsed, orig), mkNonCanonical ev)