{-# LANGUAGE DataKinds       #-}
{-# LANGUAGE RecordWildCards #-}

module Data.Record.Anon.Internal.Plugin.TC.Solver (
    solve
  ) where

import Data.Bifunctor
import Data.Maybe (catMaybes)
import Data.Traversable (forM)

import Data.Record.Anon.Internal.Plugin.TC.Constraints.AllFields
import Data.Record.Anon.Internal.Plugin.TC.Constraints.KnownFields
import Data.Record.Anon.Internal.Plugin.TC.Constraints.KnownHash
import Data.Record.Anon.Internal.Plugin.TC.Constraints.RowHasField
import Data.Record.Anon.Internal.Plugin.TC.Constraints.SubRow
import Data.Record.Anon.Internal.Plugin.TC.GhcTcPluginAPI
import Data.Record.Anon.Internal.Plugin.TC.NameResolution
import Data.Record.Anon.Internal.Plugin.TC.Parsing
import Data.Record.Anon.Internal.Plugin.TC.TyConSubst

{-------------------------------------------------------------------------------
  Top-level solver
-------------------------------------------------------------------------------}

-- | Top-level entry point of the type checker plugin's constraint solver.
--
-- Every wanted constraint is offered to each of the five constraint parsers
-- ('parseAllFields', 'parseKnownFields', 'parseKnownHash',
-- 'parseRowHasField' and 'parseSubRow'). The parsers consult a 'TyConSubst'
-- built from the given constraints. Each parsed constraint is handed, together
-- with the original 'Ct', to the matching solver.
--
-- Every solver returns a pair of
--
-- * optional evidence for the constraint it was given (@Nothing@ presumably
--   meaning "could not solve this one" -- such entries are simply dropped);
-- * a list of new constraints it wants GHC to solve.
--
-- All of these are combined into a single 'TcPluginOk'. This solver never
-- reports a contradiction.
solve :: ResolvedNames -> TcPluginSolver
solve rn given wanted =
    -- To dump the solver input and the parse results on every invocation,
    -- uncomment the two lines below (requires importing "Debug.Trace").
--  trace _debugInput  $
--  trace _debugParsed $
    do (solved, new) <- fmap (bimap catMaybes concat . unzip) $ concatM [
           forM parsedAllFields   $ uncurry (solveAllFields   rn)
         , forM parsedKnownFields $ uncurry (solveKnownFields rn)
         , forM parsedKnownHash   $ uncurry (solveKnownHash   rn)
         , forM parsedRowHasField $ uncurry (solveRowHasField rn)
         , forM parsedSubRow      $ uncurry (solveSubRow      rn)
         ]
       return $ TcPluginOk solved new
  where
    -- Type constructor substitution derived from the given constraints; every
    -- parser below receives it when inspecting the wanted constraints.
    tcs :: TyConSubst
    tcs = mkTyConSubst given

    -- The wanted constraints recognised by each parser, each paired (via
    -- 'withOrig') with the original 'Ct' it was parsed from. The parse error
    -- type is 'Void', so 'parseAll'' presumably just skips constraints that
    -- are not of the parser's class -- TODO confirm.
    parsedAllFields   :: [(Ct, GenLocated CtLoc CAllFields)]
    parsedKnownFields :: [(Ct, GenLocated CtLoc CKnownFields)]
    parsedKnownHash   :: [(Ct, GenLocated CtLoc CKnownHash)]
    parsedRowHasField :: [(Ct, GenLocated CtLoc CRowHasField)]
    parsedSubRow      :: [(Ct, GenLocated CtLoc CSubRow)]

    parsedAllFields   = parseAll' (withOrig (parseAllFields   tcs rn)) wanted
    parsedKnownFields = parseAll' (withOrig (parseKnownFields tcs rn)) wanted
    parsedKnownHash   = parseAll' (withOrig (parseKnownHash   tcs rn)) wanted
    parsedRowHasField = parseAll' (withOrig (parseRowHasField tcs rn)) wanted
    parsedSubRow      = parseAll' (withOrig (parseSubRow      tcs rn)) wanted

    -- Human-readable dump of the raw solver input (debugging only).
    _debugInput :: String
    _debugInput = unlines [
          "*** input"
        , concat [
              "given:"
            , showSDocUnsafe (ppr given)
            ]
        , concat [
              "wanted: "
            , showSDocUnsafe (ppr wanted)
            ]
        ]

    -- Human-readable dump of the parse results (debugging only).
    _debugParsed :: String
    _debugParsed = unlines [
          "*** parsed"
        , concat ["parsedAllFields:   ", showSDocUnsafe $ ppr parsedAllFields]
        , concat ["parsedKnownFields: ", showSDocUnsafe $ ppr parsedKnownFields]
          -- Previously printed 'parsedKnownFields' under this label by mistake.
        , concat ["parsedKnownHash:   ", showSDocUnsafe $ ppr parsedKnownHash]
        , concat ["parsedRowHasField: ", showSDocUnsafe $ ppr parsedRowHasField]
        , concat ["parsedSubRow:      ", showSDocUnsafe $ ppr parsedSubRow]
        , concat ["tcs (TyConSubst):  ", showSDocUnsafe $ ppr tcs]
        ]

{-------------------------------------------------------------------------------
  Auxiliary
-------------------------------------------------------------------------------}

-- | Run every list-producing action in order, from left to right, and join
-- all of their results into one list.
--
-- The effects are sequenced exactly as 'sequenceA' sequences them, so a
-- failing action (e.g. 'Nothing' or 'Left') short-circuits in the usual way.
concatM :: Applicative m => [m [a]] -> m [a]
concatM actions = concat <$> sequenceA actions