{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Data.GADT.TH.Internal where
import Control.Monad
import Control.Monad.Writer
import Data.List (foldl', drop)
import Data.Maybe
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Lazy as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Language.Haskell.TH.Datatype.TyVarBndr
-- | Split a class head such as @C a b c@ into the class name and its
-- arguments in left-to-right order, e.g. @(C, [a, b, c])@.
--
-- Parentheses and kind signatures wrapped around (partial) applications are
-- looked through; the name at the bottom of the spine is found with
-- 'headOfType'.
classHeadToParams :: Type -> (Name, [Type])
classHeadToParams = go []
  where
    -- Walk down the left spine of the application, pushing each argument
    -- onto the front of the accumulator. Arguments are met right-to-left,
    -- so the accumulator ends up in source order with no reverse needed.
    -- (The previous version recursed through 'classHeadToParams' here,
    -- which scrambled heads with three or more arguments.)
    go :: [Type] -> Type -> (Name, [Type])
    go args (AppT f x) = go (x : args) f
    go args (ParensT t) = go args t
    go args (SigT t _) = go args t
    go args t = (headOfType t, args)
-- | Poly-kinded data family used only as a marker by 'skolemize', which
-- rewrites a rigid type variable @v@ to @Skolem v@.
-- NOTE(review): presumably this keeps 'reifyInstances' from treating @v@ as
-- a variable it may instantiate; confirm intent.
data family Skolem :: k -> k

-- | Replace every free occurrence of a variable in @rigids@ by
-- @Skolem v@ (see 'Skolem').
--
-- Variables bound by a 'ForallT' shadow the rigid set in both that
-- forall's context and its body. Kind annotations on 'SigT' are left
-- untouched.
skolemize :: Set Name -> Type -> Type
skolemize rigids t = case t of
  ForallT bndrs ctx body ->
    let rigids' = Set.difference rigids (Set.fromList (map tvName bndrs))
    -- The context may mention outer rigid variables too, so it is
    -- skolemized with the same (shadowed) set as the body.
    in ForallT bndrs (map (skolemize rigids') ctx) (skolemize rigids' body)
  AppT t1 t2 -> AppT (skolemize rigids t1) (skolemize rigids t2)
  SigT inner k -> SigT (skolemize rigids inner) k
  VarT v
    | Set.member v rigids -> AppT (ConT ''Skolem) (VarT v)
    | otherwise -> t
  InfixT t1 n t2 -> InfixT (skolemize rigids t1) n (skolemize rigids t2)
  UInfixT t1 n t2 -> UInfixT (skolemize rigids t1) n (skolemize rigids t2)
  ParensT inner -> ParensT (skolemize rigids inner)
  _ -> t
-- | 'reifyInstances' for class @cls@ at the given argument types, after
-- every variable listed in @rigids@ has been rewritten by 'skolemize'.
reifyInstancesWithRigids :: Set Name -> Name -> [Type] -> Q [InstanceDec]
reifyInstancesWithRigids rigids cls =
  reifyInstances cls . map (skolemize rigids)
-- | Collect the type variables that occur free in a type.
--
-- Variables bound by a 'ForallT' are removed from what its context and
-- body contribute. Infix applications and parentheses are traversed, in
-- line with 'skolemize'. Kind annotations on 'SigT' are not inspected, so
-- kind variables are not reported.
freeTypeVariables :: Type -> Set Name
freeTypeVariables t = case t of
  ForallT bndrs ctx body ->
    Set.difference
      (Set.unions (freeTypeVariables body : map freeTypeVariables ctx))
      (Set.fromList (map tvName bndrs))
  AppT t1 t2 -> Set.union (freeTypeVariables t1) (freeTypeVariables t2)
  SigT inner _ -> freeTypeVariables inner
  VarT n -> Set.singleton n
  -- The operator in an infix application is a constructor name, not a
  -- variable, so only the operands contribute.
  InfixT t1 _ t2 -> Set.union (freeTypeVariables t1) (freeTypeVariables t2)
  UInfixT t1 _ t2 -> Set.union (freeTypeVariables t1) (freeTypeVariables t2)
  ParensT inner -> freeTypeVariables inner
  _ -> Set.empty
-- | Apply a variable-to-type substitution to a type.
--
-- Variables bound by a 'ForallT' are removed from the substitution within
-- that forall's context and body. Infix applications and parentheses are
-- traversed, in line with 'skolemize'. Kind annotations on 'SigT' are left
-- untouched.
--
-- NOTE(review): this is not capture-avoiding. If a substituted type
-- mentions a variable that an inner forall binds, that occurrence will be
-- captured.
subst :: Map Name Type -> Type -> Type
subst s = go
  where
    go :: Type -> Type
    go = \case
      ForallT bndrs ctx body ->
        let s' = Map.withoutKeys s (Set.fromList (map tvName bndrs))
        in ForallT bndrs (map (subst s') ctx) (subst s' body)
      AppT t t' -> AppT (go t) (go t')
      SigT t k -> SigT (go t) k
      VarT n -> Map.findWithDefault (VarT n) n s
      InfixT t x t' -> InfixT (go t) x (go t')
      UInfixT t x t' -> UInfixT (go t) x (go t')
      ParensT t -> ParensT (go t)
      x -> x
-- | Build an instance of class @className@ whose body is the single
-- declaration produced by @f@.
--
-- Two kinds of input declaration are accepted.
--
-- * An 'InstanceD' prototype. The class in its head must be @className@,
--   otherwise this fails. The datatype is reified from the head of the
--   prototype's first class argument. A generated head
--   @className (T x1 .. x(n-1))@ is unified with the prototype's head, and
--   the resulting substitution is applied to the prototype's head and
--   context. The prototype's overlap pragma is kept. Its own body
--   declarations and any constraints written by @f@ are discarded.
--
-- * Any other declaration is passed to 'normalizeDec'. The instance
--   context is the datatype's context followed by the constraints written
--   by @f@. No overlap pragma is attached.
--
-- In both cases the generated head drops the datatype's last type
-- argument (presumably the GADT index the class abstracts over). This
-- fails in 'Q' when the datatype has no type arguments or the prototype
-- head has no class arguments.
deriveForDec
  :: Name
  -> (DatatypeInfo -> WriterT [Type] Q Dec)
  -> Dec
  -> Q [Dec]
deriveForDec className f = \case
    InstanceD overlaps givenCxt givenHead _ -> do
      let (givenClassName, givenParams) = classHeadToParams givenHead
      when (givenClassName /= className) $
        fail $ "while deriving " ++ show className ++ ": wrong class name in prototype declaration: " ++ show givenClassName
      firstParam <- case givenParams of
        p : _ -> return p
        [] -> fail $ "while deriving " ++ show className
          ++ ": prototype declaration has no class arguments"
      dataTypeInfo <- reifyDatatype (headOfType firstParam)
      generatedHead <- instanceHeadFor dataTypeInfo
      unifiedTypes <- unifyTypes [generatedHead, givenHead]
      let newInstanceHead = applySubstitution unifiedTypes givenHead
          newContext = applySubstitution unifiedTypes givenCxt
      -- The prototype supplies the context, so constraints written by @f@
      -- are not used on this path.
      (dec, _) <- runWriterT (f dataTypeInfo)
      return [InstanceD overlaps newContext newInstanceHead [dec]]
    dataDec -> do
      dataTypeInfo <- normalizeDec dataDec
      instanceHead <- instanceHeadFor dataTypeInfo
      (dec, extraCxt) <- runWriterT (f dataTypeInfo)
      return [InstanceD Nothing (datatypeContext dataTypeInfo ++ extraCxt) instanceHead [dec]]
  where
    -- For a datatype @T a1 .. an@ this builds @className (T a1 .. a(n-1))@.
    -- The failure is raised in 'Q'. The old version used 'fail' in the
    -- list monad, which silently produced an empty list instead of an
    -- error.
    instanceHeadFor :: DatatypeInfo -> Q Type
    instanceHeadFor info = case reverse (datatypeInstTypes info) of
      [] -> fail $ "while deriving " ++ show className
        ++ ": not enough type parameters in " ++ show (datatypeName info)
      _ : revKept -> return $
        AppT (ConT className)
             (foldl' AppT (ConT (datatypeName info)) (reverse revKept))
-- | Find the name at the head of a (possibly applied) type.
--
-- Looks through foralls, applications, kind signatures and parentheses.
-- For a resolved infix application ('InfixT') the operator name is
-- returned. Tuple, arrow and list type constructors map to their
-- built-in names.
--
-- Any other type form raises a descriptive 'error' rather than an opaque
-- pattern-match failure. Unresolved infix ('UInfixT') falls into this
-- category, because its head depends on operator fixities.
headOfType :: Type -> Name
headOfType = \case
  ForallT _ _ ty -> headOfType ty
  VarT name -> name
  ConT name -> name
  PromotedT name -> name
  TupleT n -> tupleTypeName n
  ArrowT -> ''(->)
  ListT -> ''[]
  AppT t _ -> headOfType t
  SigT t _ -> headOfType t
  ParensT t -> headOfType t
  InfixT _ name _ -> name
  other -> error ("headOfType: unsupported type form: " ++ show other)