{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections       #-}

module Data.Record.Anon.Internal.Plugin.TC.Constraints.AllFields (
    CAllFields(..)
  , parseAllFields
  , solveAllFields
  ) where

import Data.Bifunctor
import Data.Foldable (toList)
import Data.Void

import Data.Record.Anon.Internal.Plugin.TC.Row.KnownField (KnownField(..))
import Data.Record.Anon.Internal.Plugin.TC.Row.KnownRow (KnownRow)
import Data.Record.Anon.Internal.Plugin.TC.Row.ParsedRow (Fields)
import Data.Record.Anon.Internal.Plugin.TC.GhcTcPluginAPI
import Data.Record.Anon.Internal.Plugin.TC.NameResolution
import Data.Record.Anon.Internal.Plugin.TC.Parsing
import Data.Record.Anon.Internal.Plugin.TC.TyConSubst

import qualified Data.Record.Anon.Internal.Plugin.TC.Row.KnownRow  as KnownRow
import qualified Data.Record.Anon.Internal.Plugin.TC.Row.ParsedRow as ParsedRow

{-------------------------------------------------------------------------------
  Definition
-------------------------------------------------------------------------------}

-- | Parsed form of an @AllFields r c@ constraint
--
-- Besides the parsed row we also keep the original types, since they are
-- needed again when constructing evidence.
data CAllFields = CAllFields {
      allFieldsParsedFields :: Fields
      -- ^ Fields of the record (the parsed form of @r@)

    , allFieldsTypeFields :: Type
      -- ^ The row type itself (@r@)

    , allFieldsTypeConstraint :: Type
      -- ^ The constraint that must hold for every field (@c@)

    , allFieldsTypeKind :: Type
      -- ^ Kind of the constraint argument (the @k@ in @c :: k -> Constraint@)
    }

{-------------------------------------------------------------------------------
  Outputable
-------------------------------------------------------------------------------}

-- | Debug rendering of a parsed @AllFields@ constraint.
--
-- Fields are bound by name (via @RecordWildCards@) rather than by position,
-- so that every label is guaranteed to be printed next to the field it names.
-- (A positional pattern is easy to get out of sync with the declaration
-- order, which would silently mislabel the output.)
instance Outputable CAllFields where
  ppr CAllFields{..} = parens $
      text "CAllFields" <+> braces (vcat [
          text "allFieldsParsedFields"   <+> text "=" <+> ppr allFieldsParsedFields
        , text "allFieldsTypeFields    " <+> text "=" <+> ppr allFieldsTypeFields
        , text "allFieldsTypeConstraint" <+> text "=" <+> ppr allFieldsTypeConstraint
        , text "allFieldsTypeKind"       <+> text "=" <+> ppr allFieldsTypeKind
        ])

{-------------------------------------------------------------------------------
  Parser
-------------------------------------------------------------------------------}

-- | Parse an @AllFields r c@ constraint
--
-- The class is applied to three type arguments: the kind @k@, the row @r@
-- and the constraint @c@. The argument parser yields 'Nothing' unless there
-- are exactly three arguments and the row @r@ can be parsed.
parseAllFields ::
     TyConSubst
  -> ResolvedNames
  -> Ct
  -> ParseResult Void (GenLocated CtLoc CAllFields)
parseAllFields tcs rn =
    parseConstraint' (clsAllFields rn) fromArgs
  where
    fromArgs :: [Type] -> Maybe CAllFields
    fromArgs [k, r, c] = toParsed <$> ParsedRow.parseFields tcs rn r
      where
        toParsed :: Fields -> CAllFields
        toParsed parsed = CAllFields {
              allFieldsParsedFields   = parsed
            , allFieldsTypeFields     = r
            , allFieldsTypeConstraint = c
            , allFieldsTypeKind       = k
            }
    fromArgs _invalidNumArgs = Nothing

{-------------------------------------------------------------------------------
  Evidence
-------------------------------------------------------------------------------}

-- | Construct evidence for @AllFields r c@
--
-- For every field we are given its type together with an evidence variable
-- proving that the field satisfies the constraint @c@. Each of those
-- dictionaries is coerced to @c Any@ and wrapped in @DictAny@. The resulting
-- list, together with the type arguments @k@, @r@ and @c@, is passed to the
-- @AllFields@ evidence function, and the result is wrapped in the class
-- dictionary constructor.
evidenceAllFields ::
     ResolvedNames
  -> CAllFields
  -> KnownRow (Type, EvVar)
  -> TcPluginM 'Solve EvTerm
evidenceAllFields ResolvedNames{..} CAllFields{..} fields =
    pure $
      evDataConApp (classDataCon clsAllFields) evidenceTyArgs [
          mkCoreApps (Var idEvidenceAllFields) $
               map Type evidenceTyArgs
            ++ [mkListExpr dictAnyTy (map wrapField (KnownRow.toList fields))]
        ]
  where
    -- Type arguments to @AllFields@ and its evidence function: @k, r, c@
    evidenceTyArgs :: [Type]
    evidenceTyArgs = [
          allFieldsTypeKind
        , allFieldsTypeFields
        , allFieldsTypeConstraint
        ]

    -- Type arguments to @DictAny@: @k, c@
    dictTyArgs :: [Type]
    dictTyArgs = [
          allFieldsTypeKind
        , allFieldsTypeConstraint
        ]

    -- Element type of the list of per-field dictionaries
    dictAnyTy :: Type
    dictAnyTy = mkTyConApp tyConDictAny dictTyArgs

    -- Wrap the dictionary for a single field in @DictAny@
    wrapField :: KnownField (Type, EvVar) -> EvExpr
    wrapField KnownField{ knownFieldInfo = (fieldTy, dictVar) } =
        mkCoreConApps dataConDictAny $
             map Type dictTyArgs
          ++ [ -- ghc gives us a dictionary of type @c a@; we cast it to
               -- @c Any@ so that it can serve as the argument to @DictAny@.
               mkCoreApps (Var idUnsafeCoerce) [
                   Type (mkAppTy allFieldsTypeConstraint fieldTy)
                 , Type (mkAppTy allFieldsTypeConstraint anyAtKind)
                 , Var dictVar
                 ]
             ]

    -- @Any@ at kind @k@
    anyAtKind :: Type
    anyAtKind = mkTyConApp anyTyCon [allFieldsTypeKind]

{-------------------------------------------------------------------------------
  Solver
-------------------------------------------------------------------------------}

-- | Solve an @AllFields r c@ constraint
--
-- We only make progress when every field of the row is known. In that case
-- a new wanted constraint @c a@ is emitted for every field type @a@, and the
-- original constraint is solved using the evidence variables of those new
-- wanteds. When some field is not known, nothing is solved and no new
-- constraints are emitted.
solveAllFields ::
     ResolvedNames
  -> Ct
  -> GenLocated CtLoc CAllFields
  -> TcPluginM 'Solve (Maybe (EvTerm, Ct), [Ct])
solveAllFields rn orig (L loc cr) =
    maybe (return (Nothing, [])) solveKnown $
      ParsedRow.allKnown (allFieldsParsedFields cr)
  where
    solveKnown :: KnownRow Type -> TcPluginM 'Solve (Maybe (EvTerm, Ct), [Ct])
    solveKnown knownFields = do
        wanteds :: KnownRow (Type, CtEvidence) <-
          KnownRow.traverse knownFields $ \_nm fieldTy ->
            (fieldTy,) <$>
              newWanted loc (mkAppTy (allFieldsTypeConstraint cr) fieldTy)
        ev <- evidenceAllFields rn cr (fmap (second evVarOf) wanteds)
        let newCts = fmap (mkNonCanonical . snd) (toList wanteds)
        return (Just (ev, orig), newCts)

    -- The evidence variable of a freshly created (non-equality) wanted
    evVarOf :: CtEvidence -> EvVar
    evVarOf ctEv =
        case ctev_dest ctEv of
          EvVarDest var -> var
          HoleDest  _   -> error "impossible (we don't ask for primitive equality)"