module TypeLet.Plugin (plugin) where

import Prelude hiding (cycle)

import Data.Traversable (forM)

import GHC.Plugins (Plugin(..), defaultPlugin, purePlugin)

import TypeLet.Plugin.Constraints
import TypeLet.Plugin.GhcTcPluginAPI
import TypeLet.Plugin.NameResolution
import TypeLet.Plugin.Substitution

{-------------------------------------------------------------------------------
  Top-level plumbing
-------------------------------------------------------------------------------}

-- | The TypeLet type checker plugin.
--
-- Command line options are ignored. The plugin is marked as pure
-- ('purePlugin') for recompilation checking. It installs a constraint solver
-- ('solve') whose state is the result of name resolution ('resolveNames').
-- It registers no type family rewriters and does nothing when it stops.
plugin :: Plugin
plugin = defaultPlugin {
      pluginRecompile = purePlugin
    , tcPlugin        = \_cmdline -> Just . mkTcPlugin $ TcPlugin {
          tcPluginInit    = resolveNames
        , tcPluginSolve   = solve
        , tcPluginRewrite = \_st -> emptyUFM
        , tcPluginStop    = \_st -> return ()
        }
    }

{-------------------------------------------------------------------------------
  Constraint resolution

  General approach: regard @Let@ constraints as defining a substitution, and
  then resolve @Equal@ constraints by /applying/ that substitution and
  simplifying each one to a new wanted (nominal) equality constraint between
  the substituted types. The evidence for the original @Equal@ constraint is
  constructed separately (by 'evidenceEqual') and does not use the evidence
  of that new equality constraint.
-------------------------------------------------------------------------------}

-- | Main interface to constraint resolution
--
-- NOTE: For now we are completely ignoring the derived constraints.
solve :: ResolvedNames -> TcPluginSolver
solve :: ResolvedNames -> TcPluginSolver
solve ResolvedNames
rn [Ct]
given [Ct]
wanted
  | [Ct] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Ct]
wanted = ResolvedNames -> [Ct] -> TcPluginM 'Solve TcPluginSolveResult
simplifyGivens ResolvedNames
rn [Ct]
given
  | Bool
otherwise   = ResolvedNames -> TcPluginSolver
simplifyWanteds ResolvedNames
rn [Ct]
given [Ct]
wanted

-- | Simplify givens
--
-- We (currently?) never simplify any givens, so we just return two empty
-- lists. This indicates that no constraints were removed and none were added.
simplifyGivens ::
     ResolvedNames  -- ^ Result of name resolution (during init)
  -> [Ct]           -- ^ Given constraints
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyGivens _st _given = return $ TcPluginOk [] []

-- | Simplify wanteds
--
-- This function provides the key functionality of the plugin.
--
-- We resolve 'Equal' constraints to /nominal/ equality constraints: we want
-- 'cast' to resolve @Let@ bindings, but not additionally work as 'coerce'.
--
-- The givens are parsed as @Let@ bindings ('parseLet') and combined into a
-- substitution ('letsToSubst'). Two cases are reported as a contradiction
-- carrying a custom error message:
--
-- * parsing reports an invalid @Let@;
-- * the @Let@ bindings form a cycle.
--
-- Otherwise every wanted recognised by 'parseEqual' is marked as solved (with
-- evidence from 'evidenceEqual'). In its place we emit a new wanted nominal
-- equality between its left- and right-hand sides, after the substitution
-- has been applied to both.
simplifyWanteds ::
     ResolvedNames  -- ^ Result of name resolution (during init)
  -> [Ct]           -- ^ Given constraints
  -> [Ct]           -- ^ Wanted constraints
  -> TcPluginM 'Solve TcPluginSolveResult
simplifyWanteds rn given wanted =
    case parseAll (parseLet rn) given of
      Left err ->
        errWith $ formatInvalidLet <$> err
      Right lets ->
        case letsToSubst lets of
          Left letCycle ->
            errWith $ formatLetCycle letCycle
          Right subst -> do
            (solved, new) <- fmap unzip $
              forM (parseAll' (withOrig (parseEqual rn)) wanted) $
                uncurry (solveEqual subst)
            return $ TcPluginOk solved new
  where
    -- Work around a bug in ghc: make sure the location of the new wanted
    -- is set correctly by also setting it in the monadic context.
    newWanted' :: CtLoc -> PredType -> TcPluginM 'Solve CtEvidence
    newWanted' l w = setCtLocM l $ newWanted l w

    -- Report a custom error at the given location. The message is turned
    -- into a type, a new wanted of that type is created, and that wanted is
    -- reported via 'TcPluginContradiction'.
    errWith ::
         GenLocated CtLoc TcPluginErrorMessage
      -> TcPluginM 'Solve TcPluginSolveResult
    errWith (L l err) = do
        errAsTyp <- mkTcPluginErrorTy err
        mkErr <$> newWanted' l errAsTyp
      where
        mkErr :: CtEvidence -> TcPluginSolveResult
        mkErr = TcPluginContradiction . (:[]) . mkNonCanonical

    -- Solve an Equal constraint by applying the substitution and turning it
    -- into a nominal equality constraint.
    --
    -- Returns the evidence for the original constraint, paired with that
    -- constraint, together with the newly emitted wanted equality.
    solveEqual ::
         TCvSubst
      -> Ct                       -- Original Equal constraint
      -> GenLocated CtLoc CEqual  -- Parsed Equal constraint
      -> TcPluginM 'Solve ((EvTerm, Ct), Ct)
    solveEqual subst orig (L l parsed) = do
        ev <- newWanted' l $
                mkPrimEqPredRole
                  Nominal
                  (substTy subst (equalLHS parsed))
                  (substTy subst (equalRHS parsed))
        return (
            (evidenceEqual rn parsed, orig)
          , mkNonCanonical ev
          )