{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.GADT.TH.Internal where
import Control.Monad
import Control.Monad.Writer
import Data.List (foldl', drop)
import Data.Maybe
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Lazy as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Language.Haskell.TH.Datatype.TyVarBndr
-- | Split a class head such as @C t1 t2 ... tn@ into the class name and its
-- arguments in source order, @[t1, t2, ..., tn]@.
--
-- Once no more 'AppT' nodes remain, the head name is taken with 'headOfType'.
classHeadToParams :: Type -> (Name, [Type])
classHeadToParams t = (h, reverse reversedParams)
  where
    (h, reversedParams) = go t

    -- Peel arguments off the right of the left-nested 'AppT' spine,
    -- collecting them last-first. This must recurse into 'go' itself:
    -- recursing through 'classHeadToParams' would re-reverse the partial
    -- list at every level and scramble heads with three or more arguments.
    go :: Type -> (Name, [Type])
    go = \case
      -- 'fmap' on a pair maps its second component (the argument list).
      AppT f x -> fmap (x :) (go f)
      other -> (headOfType other, [])
-- | A data family whose applications stand in for rigid type variables.
--
-- 'skolemize' rewrites each rigid variable @v@ to @Skolem v@, and
-- 'reifyInstancesWithRigids' applies that rewrite before calling
-- 'reifyInstances'. NOTE(review): presumably the intent is that @Skolem v@
-- only matches instance heads that are fully general in that position,
-- rather than unifying like a bare type variable would — confirm.
data family Skolem :: k -> k
-- | Replace every occurrence of a rigid type variable @v@ (a member of
-- @rigids@) with @Skolem v@.
--
-- Below a 'ForallT', the variables it binds are no longer treated as rigid.
-- The traversal descends through applications, kind signatures (into the
-- type only, not the kind), infix types and parentheses. It does not rewrite
-- a 'ForallT' context; every other constructor is returned unchanged.
skolemize :: Set Name -> Type -> Type
skolemize rigids ty = case ty of
  ForallT binders context body ->
    let stillRigid = rigids `Set.difference` Set.fromList (tvName <$> binders)
    in ForallT binders context (skolemize stillRigid body)
  AppT fun arg -> AppT (recur fun) (recur arg)
  SigT inner kind -> SigT (recur inner) kind
  VarT var
    | var `Set.member` rigids -> ConT ''Skolem `AppT` VarT var
    | otherwise -> ty
  InfixT lhs op rhs -> InfixT (recur lhs) op (recur rhs)
  UInfixT lhs op rhs -> UInfixT (recur lhs) op (recur rhs)
  ParensT inner -> ParensT (recur inner)
  _ -> ty
  where
    -- Recurse with the same rigid set (no new binders in scope).
    recur :: Type -> Type
    recur = skolemize rigids
-- | Query the instances of class @cls@ at argument types @tys@, after
-- rewriting every variable in @rigids@ to @Skolem v@ via 'skolemize'.
--
-- NOTE(review): presumably this keeps rigid variables from matching
-- arbitrary instance heads the way a bare type variable would in
-- 'reifyInstances' — confirm against callers.
reifyInstancesWithRigids :: Set Name -> Name -> [Type] -> Q [InstanceDec]
reifyInstancesWithRigids rigids cls tys =
  reifyInstances cls $ skolemize rigids <$> tys
-- | Collect the type variables that occur free in a type.
--
-- Variables bound by a 'ForallT' are excluded, both from that forall's body
-- and from its context (the context may mention variables from further
-- out, which are reported). Infix and parenthesised types are traversed like
-- applications. Kinds in 'SigT' are not inspected, and any other
-- constructor is treated as having no free variables.
freeTypeVariables :: Type -> Set Name
freeTypeVariables t = case t of
  ForallT bndrs preds t' ->
    Set.difference
      (Set.unions (map freeTypeVariables (t' : preds)))
      (Set.fromList (map tvName bndrs))
  AppT t1 t2 -> Set.union (freeTypeVariables t1) (freeTypeVariables t2)
  SigT t' _ -> freeTypeVariables t'
  VarT n -> Set.singleton n
  InfixT t1 _ t2 -> Set.union (freeTypeVariables t1) (freeTypeVariables t2)
  UInfixT t1 _ t2 -> Set.union (freeTypeVariables t1) (freeTypeVariables t2)
  ParensT t' -> freeTypeVariables t'
  _ -> Set.empty
-- | Apply a substitution of type variables to a type.
--
-- Under a 'ForallT', the variables it binds are removed from the
-- substitution, so bound occurrences are left alone. The reduced
-- substitution is applied to the forall's context as well as its body.
-- Applications, kind-annotated types (the type only, not the kind), infix
-- types and parentheses are traversed. Every other constructor is returned
-- unchanged.
--
-- NOTE(review): the substitution is not capture-avoiding. If a replacement
-- type mentions a name bound by an enclosing 'ForallT', that name is
-- captured.
subst :: Map Name Type -> Type -> Type
subst s = f
  where
    f :: Type -> Type
    f = \case
      ForallT bndrs preds t ->
        let s' :: Map Name Type
            s' = Map.withoutKeys s (Set.fromList (map tvName bndrs))
        in ForallT bndrs (map (subst s') preds) (subst s' t)
      AppT t t' -> AppT (f t) (f t')
      SigT t k -> SigT (f t) k
      VarT n -> fromMaybe (VarT n) (Map.lookup n s)
      InfixT t x t' -> InfixT (f t) x (f t')
      UInfixT t x t' -> UInfixT (f t) x (f t')
      ParensT t -> ParensT (f t)
      x -> x
-- | Derive an instance of class @className@ whose body is the single
-- declaration produced by @f@.
--
-- The input declaration is handled in one of two ways:
--
-- * A prototype 'InstanceD'. Its class must be @className@ and it must have
--   at least one class argument; otherwise this fails in 'Q'. The head of
--   that first argument names the datatype, which is reified with
--   'reifyDatatype'. The prototype's head and context are rewritten with the
--   substitution from unifying the prototype head with the generated head.
--   The prototype's overlap pragma is kept, but its own body declarations
--   are discarded. The context written by @f@ is also ignored here, because
--   the (substituted) prototype context is used instead.
--
-- * Any other declaration is passed to 'normalizeDec'. The generated
--   instance has no overlap pragma. Its context is the datatype's own
--   context followed by whatever @f@ wrote.
--
-- In both cases the generated head is @className (T a1 .. a(n-1))@: the
-- datatype applied to all but its last type argument. A datatype with no
-- type arguments is rejected with 'fail' in 'Q'.
deriveForDec
  :: Name
  -> (DatatypeInfo -> WriterT [Type] Q Dec)
  -> Dec
  -> Q [Dec]
deriveForDec className f = \case
    InstanceD overlaps protoContext protoHead _protoDecs -> do
      let (givenClassName, classArgs) = classHeadToParams protoHead
      when (givenClassName /= className) $
        failDeriving ("wrong class name in prototype declaration: " ++ show givenClassName)
      firstParam <- case classArgs of
        p : _ -> return p
        [] -> failDeriving "prototype declaration has no class arguments"
      dataTypeInfo <- reifyDatatype (headOfType firstParam)
      generatedInstanceHead <- instanceHeadFor dataTypeInfo
      unifiedTypes <- unifyTypes [generatedInstanceHead, protoHead]
      -- Only the generated declaration is used. Its collected context is
      -- dropped in favour of the substituted prototype context.
      (dec, _) <- runWriterT (f dataTypeInfo)
      return
        [ InstanceD overlaps
            (applySubstitution unifiedTypes protoContext)
            (applySubstitution unifiedTypes protoHead)
            [dec]
        ]
    dataDec -> do
      dataTypeInfo <- normalizeDec dataDec
      instanceHead <- instanceHeadFor dataTypeInfo
      (dec, writtenContext) <- runWriterT (f dataTypeInfo)
      return
        [ InstanceD Nothing
            (datatypeContext dataTypeInfo ++ writtenContext)
            instanceHead
            [dec]
        ]
  where
    -- Abort the splice in 'Q', prefixing the message with the class being
    -- derived.
    failDeriving :: String -> Q a
    failDeriving msg = fail ("while deriving " ++ show className ++ ": " ++ msg)

    -- Build @className (T a1 .. a(n-1))@ from the datatype's type arguments,
    -- dropping the last one. This must run in 'Q'. A @fail@ in a pure
    -- @[Type]@-typed binding would pick the list 'MonadFail' instance and
    -- silently yield @[]@ instead of reporting the error.
    instanceHeadFor :: DatatypeInfo -> Q Type
    instanceHeadFor info = case reverse (datatypeInstTypes info) of
      [] -> failDeriving "not enough type parameters"
      _ : revInit ->
        return $
          AppT (ConT className)
            (foldl' AppT (ConT (datatypeName info)) (reverse revInit))
-- | The name at the head of a type's application spine.
--
-- Foralls, kind signatures and parentheses are looked through. The operator
-- of a fixity-resolved infix type ('InfixT') is its head. Built-in tuple,
-- list and function type constructors map to their names. Unresolved infix
-- chains ('UInfixT') and any other form are rejected with an 'error' that
-- shows the offending type. The head of a 'UInfixT' chain depends on fixity
-- resolution, so no head can be chosen for it here.
headOfType :: Type -> Name
headOfType = \case
  ForallT _ _ ty -> headOfType ty
  VarT name -> name
  ConT name -> name
  TupleT n -> tupleTypeName n
  ArrowT -> ''(->)
  ListT -> ''[]
  AppT t _ -> headOfType t
  SigT t _ -> headOfType t
  ParensT t -> headOfType t
  InfixT _ name _ -> name
  other -> error ("headOfType: unsupported type: " ++ show other)