{-# LANGUAGE TupleSections #-}
module Grisette.Internal.TH.DeriveWithHandlers
( deriveWithHandlers,
)
where
import Control.Monad (foldM, unless, when)
import Data.List (transpose)
import qualified Data.Map as M
import Data.Maybe (fromMaybe, mapMaybe)
import Grisette.Internal.TH.DeriveInstanceProvider
( DeriveInstanceProvider (instanceDeclaration),
)
import Grisette.Internal.TH.DeriveTypeParamHandler
( DeriveTypeParamHandler (handleBody, handleTypeParams),
SomeDeriveTypeParamHandler,
)
import Grisette.Internal.TH.Util
( allSameKind,
dropNTypeParam,
reifyDatatypeWithFreshNames,
substDataType,
)
import Language.Haskell.TH (Dec, Name, Q)
import Language.Haskell.TH.Datatype
( ConstructorInfo (constructorFields),
DatatypeInfo (datatypeCons, datatypeVars),
datatypeType,
reifyDatatype,
tvName,
)
-- | Transpose a list of rows like 'transpose', but with an explicit shape
-- for the empty case.
--
-- 'transpose' maps @[]@ to @[]@. That loses information when the input is
-- the transpose of a matrix with zero columns: transposing it back should
-- give that matrix's rows, not nothing. Here an empty input yields @n@
-- empty rows instead. Any non-empty input is handed to 'transpose'
-- unchanged, and @n@ is ignored.
transposeMatrix :: Int -> [[a]] -> [[a]]
transposeMatrix n [] = replicate n []
transposeMatrix _ rows = transpose rows
-- | Derive instance declarations for one or more datatypes by running a
-- pipeline of type-parameter handlers and passing the result to an instance
-- provider.
--
-- Arguments, in order:
--
-- * @handlers@: type-parameter handlers. They are folded over the aligned
--   type variables in list order via 'handleTypeParams'.
-- * @provider@: builds the final declarations via 'instanceDeclaration'.
-- * @ignoreBodyConstraints@: when 'True', 'handleBody' is not consulted, and
--   only the constraints attached to type parameters form the context.
-- * @numDroppedTailTypes@: number of trailing type parameters removed from
--   every datatype. They are removed both from the variable lists and, via
--   'dropNTypeParam', from the instance head types.
-- * @names@: the datatypes to derive for.
--
-- The type variables of all datatypes are grouped position by position
-- (column @i@ holds the @i@-th variable of every datatype). Substitutions
-- that the handlers attach to a variable are applied to the owning datatype
-- with 'substDataType'. Constraints they attach are collected into the
-- instance context.
--
-- Fails in 'Q' when:
--
-- * @numDroppedTailTypes@ is negative;
-- * @numDroppedTailTypes > 0@ while @ignoreBodyConstraints@ is 'False';
-- * @names@ is empty;
-- * the datatypes do not all have the same number of kept type parameters,
--   or some column fails 'allSameKind'.
deriveWithHandlers ::
  (DeriveInstanceProvider provider) =>
  [SomeDeriveTypeParamHandler] ->
  provider ->
  Bool ->
  Int ->
  [Name] ->
  Q [Dec]
deriveWithHandlers
  handlers
  provider
  ignoreBodyConstraints
  numDroppedTailTypes
  names = do
    when (numDroppedTailTypes < 0) $
      fail "deriveWithHandlers: numDroppedTailTypes must be non-negative"
    when (numDroppedTailTypes > 0 && not ignoreBodyConstraints) $
      fail $
        "deriveWithHandlers: ignoreBodyConstraints must be True if "
          <> "numDroppedTailTypes > 0"
    when (null names) $
      fail "deriveWithHandlers: no types provided"
    -- One type is reified as-is; several are reified with fresh names.
    -- NOTE(review): presumably so variables of different types cannot clash
    -- once they share an instance head; confirm against
    -- reifyDatatypeWithFreshNames.
    datatypes <- case names of
      [_] -> mapM reifyDatatype names
      _ -> mapM reifyDatatypeWithFreshNames names
    let numTypes = length datatypes
        -- Per datatype: its type variables without the dropped tail.
        keptTyVars =
          map
            (reverse . drop numDroppedTailTypes . reverse . datatypeVars)
            datatypes
        -- Per position: the variable of every datatype at that position.
        tyVars = transposeMatrix 0 keptTyVars
        -- 'transpose' skips missing entries, so unequal arities would give
        -- ragged columns that transpose back onto the wrong datatypes.
        sameArity = case map length keptTyVars of
          [] -> True
          (arity : arities) -> all (== arity) arities
    unless (sameArity && all allSameKind tyVars) $
      fail "deriveWithHandlers: all type variables must be aligned"
    -- Each column starts with no substitutions and no constraints; every
    -- handler may refine it.
    tyVarsWithConstraints <-
      foldM
        (flip $ handleTypeParams numTypes)
        (map (\column -> (map (\tv -> (tv, Nothing)) column, Nothing)) tyVars)
        handlers
    let allTyVarsConstraints =
          concatMap (fromMaybe [] . snd) tyVarsWithConstraints
        -- Back to one row per datatype. 'transposeMatrix' restores
        -- @numTypes@ empty rows when there are no type variables at all, so
        -- the 'zipWith' below still sees every datatype.
        tvWithSubst =
          transposeMatrix numTypes $ fst <$> tyVarsWithConstraints
        substMaps =
          map
            ( M.fromList
                . mapMaybe (\(tv, mty) -> fmap (\ty -> (tvName tv, ty)) mty)
            )
            tvWithSubst
        substedTypes = zipWith substDataType datatypes substMaps
    tys <-
      mapM (dropNTypeParam numDroppedTailTypes . datatypeType) substedTypes
    bodyConstraints <-
      if ignoreBodyConstraints
        then return []
        else bodyConstraintsFor substedTypes
    instanceDeclaration
      provider
      (fst <$> tyVarsWithConstraints)
      (allTyVarsConstraints ++ bodyConstraints)
      tys
  where
    -- Ask every handler for constraints on the datatypes' fields. Each
    -- datatype's constructor fields are concatenated. The resulting lists
    -- are transposed, so entry i groups the i-th field type of each
    -- datatype. Every handler's results are concatenated, in handler order.
    bodyConstraintsFor dts = do
      let fieldColumns =
            transpose $ concatMap constructorFields . datatypeCons <$> dts
      concat <$> traverse (`handleBody` fieldColumns) handlers