{-# LANGUAGE MultiWayIf #-}
module Data.Singletons.TH.Deriving.Util where
import Control.Monad
import Data.Singletons.TH.Names
import Data.Singletons.TH.Syntax
import Data.Singletons.TH.Util
import Language.Haskell.TH.Desugar
import qualified Language.Haskell.TH.Desugar.OSet as OSet
import Language.Haskell.TH.Syntax
-- | The common shape of the functions in this module family that derive a
-- class instance from a data type declaration.
type DerivDesc q
  =  Maybe DCxt  -- An optional instance context. NOTE(review): presumably
                 -- 'Just' when a context was written out explicitly (e.g.
                 -- standalone deriving) and 'Nothing' otherwise; confirm
                 -- against the callers.
  -> DType       -- The type argument of the class (assumed to be the
                 -- instance head's argument; confirm against callers).
  -> DataDecl    -- The declaration of the data type being derived for.
  -> q UInstDecl -- The resulting instance declaration.
-- | A set of callbacks describing how to fold over a 'DType' in a
-- Functor-like way. 'functorLikeTraverse' picks one of these for each
-- subterm, depending on where the type variable of interest occurs.
data FFoldType a = FT
  { ft_triv    :: a
    -- ^ Result for a type that does not mention the variable.
  , ft_var     :: a
    -- ^ Result for the variable itself.
  , ft_ty_app  :: DType -> a -> a
    -- ^ Type application whose variable occurrences are all in the last
    -- visible argument. Receives that argument and its folded result.
  , ft_bad_app :: a
    -- ^ The variable occurs somewhere it is not allowed: in the head of an
    -- application, in a non-last argument, in a kind, or inside a type family
    -- application.
  , ft_forall  :: [DTyVarBndrSpec] -> a -> a
    -- ^ Invisible @forall@ that does not bind the variable and whose body
    -- mentions it. Receives the binders and the folded body.
  }
-- | Fold over a 'DType' in a Functor-like way, looking for occurrences of the
-- type variable @var@ and dispatching to the callbacks of the given
-- 'FFoldType'.
--
-- The type is passed through 'expandType' first. Internally each subterm is
-- folded to a pair of (callback result, does @var@ occur in this subterm?).
-- A visible @forall@ causes a 'fail'.
functorLikeTraverse :: forall q a.
                       DsMonad q
                    => Name        -- ^ The type variable to look for
                    -> FFoldType a -- ^ How to fold each kind of subterm
                    -> DType       -- ^ The type to process
                    -> q a
functorLikeTraverse var (FT { ft_triv    = caseTrivial
                            , ft_var     = caseVar
                            , ft_ty_app  = caseTyApp
                            , ft_bad_app = caseWrongArg
                            , ft_forall  = caseForAll })
                    ty
  = do ty' <- expandType ty
       (res, _) <- go ty'
       pure res
  where
    -- Returns the folded result together with whether 'var' occurs in the type.
    go :: DType -> q (a, Bool)
    go t@DAppT{} = do
      let (f, args) = unfoldDType t
          vis_args  = filterDTANormals args
      (_,   fc)  <- go f
      (xrs, xcs) <- mapAndUnzipM go vis_args
      -- Split off the last visible argument totally, instead of relying on
      -- partial 'init'/'last' being guarded by an earlier check.
      case unsnocList (zip3 vis_args xrs xcs) of
        -- No visible arguments, so the variable cannot occur in them.
        Nothing -> trivial
        Just (earlier, (last_arg, last_r, last_c))
            -- The variable does not occur in any argument.
          | not (last_c || earlier_c)
          -> trivial
            -- T (..var..) ty: the variable occurs in the head or in a
            -- non-last argument.
          | fc || earlier_c
          -> wrongArg
            -- T (..no var..) ty: the variable occurs only in the last argument.
          | otherwise
          -> do itf <- isInTypeFamilyApp var f vis_args
                -- Type family applications cannot be decomposed, so an
                -- occurrence as a type family argument is rejected.
                if itf
                   then wrongArg
                   else pure (caseTyApp last_arg last_r, True)
          where
            earlier_c = or [c | (_, _, c) <- earlier]
    go (DAppKindT t k) = do
      (_, kc) <- go k
      -- An occurrence inside a kind argument is not allowed.
      if kc then wrongArg else go t
    go (DSigT t k) = do
      (_, kc) <- go k
      -- An occurrence inside a kind signature is not allowed.
      if kc then wrongArg else go t
    go (DVarT v)
      | v == var  = pure (caseVar, True)
      | otherwise = trivial
    go (DForallT tele t) =
      case tele of
        DForallVis{} ->
          fail "Unexpected visible forall in the type of a data constructor"
        DForallInvis tvbs -> do
          (tr, tc) <- go t
          -- If the forall rebinds 'var', the body's occurrences are of a
          -- different variable, so the whole type counts as trivial.
          if var `notElem` map extractTvbName tvbs && tc
             then pure (caseForAll tvbs tr, True)
             else trivial
    go (DConstrainedT _ t) = go t
    go DConT{}             = trivial
    go DArrowT             = trivial
    go DLitT{}             = trivial
    go DWildCardT          = trivial

    -- The variable does not occur.
    trivial :: q (a, Bool)
    trivial = pure (caseTrivial, False)

    -- The variable occurs in a disallowed position.
    wrongArg :: q (a, Bool)
    wrongArg = pure (caseWrongArg, True)

    -- Total variant of @(init xs, last xs)@: 'Nothing' for an empty list.
    unsnocList :: [b] -> Maybe ([b], b)
    unsnocList []     = Nothing
    unsnocList (x:xs) = Just $ case unsnocList xs of
      Nothing          -> ([], x)
      Just (ys, lastY) -> (x : ys, lastY)
-- | Detect whether @name@ occurs free in an argument of a type family
-- application.
--
-- Only a type headed by a 'DConT' that 'dsReify' resolves to an open or
-- closed type family is considered. Only as many leading arguments as the
-- family has binders are inspected. Oversaturated arguments beyond that
-- count are ignored. Any other head yields 'False'.
isInTypeFamilyApp :: forall q. DsMonad q => Name -> DType -> [DType] -> q Bool
isInTypeFamilyApp name tyFun tyArgs =
  case tyFun of
    DConT tcName -> go tcName
    _            -> pure False
  where
    go :: Name -> q Bool
    go tcName = do
      info <- dsReify tcName
      case info of
        Just (DTyConI dec _)
          | DOpenTypeFamilyD (DTypeFamilyHead _ bndrs _ _) <- dec
          -> withinFirstArgs bndrs
          | DClosedTypeFamilyD (DTypeFamilyHead _ bndrs _ _) _ <- dec
          -> withinFirstArgs bndrs
        -- Not a type family (or not reifiable): nothing to reject.
        _ -> pure False

    -- Does 'name' occur free in the first (length bndrs) arguments?
    withinFirstArgs :: [b] -> q Bool
    withinFirstArgs bndrs =
      let firstArgs = take (length bndrs) tyArgs
          argFVs    = foldMap fvDType firstArgs
      -- Use the set's own membership test (logarithmic) rather than the
      -- linear Foldable 'elem'.
      in pure $ name `OSet.member` argFVs
-- | Sanity checks run before deriving a Functor-like instance. Fails when:
--
-- * the data type has no type parameters;
-- * unless @allowConstrainedLastTyVar@ is 'True', a constructor's result
--   type does not end in a type variable, or ends in one that is free in the
--   constructor's context;
-- * a constructor field uses that last type variable anywhere other than the
--   last argument of a type application (see 'ft_bad_app').
functorLikeValidityChecks :: forall q. DsMonad q => Bool -> DataDecl -> q ()
functorLikeValidityChecks allowConstrainedLastTyVar (DataDecl _df n data_tvbs cons)
  | null data_tvbs
  = fail $ "Data type " ++ nameBase n ++ " must have some type parameters"
  | otherwise
  = mapM_ validateCon cons
  where
    -- Run the universality check, then the per-field checks, for one constructor.
    validateCon :: DCon -> q ()
    validateCon con = do
      checkUniversal con
      foldDataConArgs (fieldChecker (extractName con)) con >>= sequence_

    -- The last visible argument of the result type must be a type variable
    -- that does not appear free in the constructor's context.
    checkUniversal :: DCon -> q ()
    checkUniversal (DCon _ con_theta con_name _ res_ty)
      | allowConstrainedLastTyVar = pure ()
      | otherwise =
          case getDVarTName_maybe lastResArg of
            Just last_tv
              | last_tv `OSet.notMember` foldMap fvDType con_theta -> pure ()
            _ -> fail $ badCon con_name existential
      where
        lastResArg = snd $ snocView $ filterDTANormals $ snd $ unfoldDType res_ty

    -- Fold callbacks that fail on any disallowed use of the last type variable.
    fieldChecker :: Name -> FFoldType (q ())
    fieldChecker con_name = FT
      { ft_triv    = pure ()
      , ft_var     = pure ()
      , ft_ty_app  = \_ inner -> inner
      , ft_bad_app = fail $ badCon con_name wrong_arg
      , ft_forall  = \_ inner -> inner
      }

    -- Prefix an error message with the offending constructor's name.
    badCon :: Name -> String -> String
    badCon con_name msg = concat ["Constructor ", nameBase con_name, " ", msg]

    existential, wrong_arg :: String
    existential = "must be truly polymorphic in the last argument of the data type"
    wrong_arg   = "must use the type variable only as the last argument of a data type"
-- | Collect the subterms of a type that contain @tv@ in last-argument
-- position: every type passed to 'ft_ty_app' by 'functorLikeTraverse',
-- including nested ones.
--
-- Subterms that mention a variable bound by an enclosing invisible @forall@
-- are dropped. An occurrence of @tv@ in a disallowed position calls 'error'.
deepSubtypesContaining :: DsMonad q => Name -> DType -> q [DType]
deepSubtypesContaining tv = functorLikeTraverse tv collector
  where
    collector :: FFoldType [DType]
    collector = FT
      { ft_triv    = []
      , ft_var     = []
      , ft_ty_app  = (:)
      , ft_bad_app = error "in other argument in deepSubtypesContaining"
      , ft_forall  = \bound subs -> filter (\sub -> all (absentFrom sub) bound) subs
      }

    -- True when the binder's name is not free in the given type.
    absentFrom :: DType -> DTyVarBndrSpec -> Bool
    absentFrom sub tvb = extractTvbName tvb `OSet.notMember` fvDType sub
-- | Fold each field of a constructor with 'functorLikeTraverse'. Field types
-- are passed through 'expandType' first.
--
-- The variable searched for is the last visible argument of the
-- constructor's result type. If that argument is not a type variable, every
-- field folds to 'ft_triv'.
foldDataConArgs :: forall q a. DsMonad q => FFoldType a -> DCon -> q [a]
foldDataConArgs ft (DCon _ _ _ fields res_ty) =
  traverse expandType (tysOfConFields fields) >>= traverse foldArg
  where
    foldArg :: DType -> q a
    foldArg =
      case getDVarTName_maybe lastResArg of
        Just last_tv -> functorLikeTraverse last_tv ft
        Nothing      -> const (return (ft_triv ft))

    -- Last visible argument of the constructor's result type.
    lastResArg :: DType
    lastResArg = snd $ snocView $ filterDTANormals $ snd $ unfoldDType res_ty
-- | Return the name of a type variable, looking through any number of kind
-- signatures. Returns 'Nothing' for every other type.
getDVarTName_maybe :: DType -> Maybe Name
getDVarTName_maybe ty = case ty of
  DVarT n      -> Just n
  DSigT inner _ -> getDVarTName_maybe inner
  _            -> Nothing
-- | Build a one-argument lambda. A fresh name (based on @"n"@) is bound by
-- the lambda and handed to the callback as a 'DVarE'. The callback's result
-- becomes the body.
mkSimpleLam :: Quasi q => (DExp -> q DExp) -> q DExp
mkSimpleLam lam = do
  x <- newUniqueName "n"
  dLamE [DVarP x] <$> lam (DVarE x)
-- | Build a one-argument lambda that ignores its argument (a wildcard
-- pattern), with the given computation's result as its body.
mkSimpleWildLam :: Quasi q => q DExp -> q DExp
mkSimpleWildLam lam = dLamE [DWildP] <$> lam
-- | Build a two-argument lambda. Fresh names (based on @"n1"@ and @"n2"@, in
-- that order) are bound and passed to the callback as 'DVarE's. The
-- callback's result becomes the body.
mkSimpleLam2 :: Quasi q => (DExp -> DExp -> q DExp) -> q DExp
mkSimpleLam2 lam = do
  x <- newUniqueName "n1"
  y <- newUniqueName "n2"
  dLamE [DVarP x, DVarP y] <$> lam (DVarE x) (DVarE y)
-- | Build a two-argument lambda whose first argument is ignored (a wildcard
-- pattern). The second is bound to a fresh name (based on @"n"@) that is
-- passed to the callback as a 'DVarE'.
mkSimpleWildLam2 :: Quasi q => (DExp -> q DExp) -> q DExp
mkSimpleWildLam2 lam = do
  x <- newUniqueName "n"
  dLamE [DWildP, DVarP x] <$> lam (DVarE x)
-- | Build a clause that matches the given constructor, binding one fresh
-- variable (based on @"a"@) per element of @insides@.
--
-- The clause's patterns are @extra_pats@ followed by the constructor
-- pattern. Its body is @fold con_name@ applied to the list in which each
-- element of @insides@ is applied ('DAppE') to the matching fresh variable.
mkSimpleConClause :: Quasi q
                  => (Name -> [DExp] -> DExp) -- ^ Combines the constructor name with the applied subexpressions
                  -> [DPat]                   -- ^ Patterns placed before the constructor pattern
                  -> DCon                     -- ^ The constructor to match
                  -> [DExp]                   -- ^ Functions to apply to each bound field
                  -> q DClause
mkSimpleConClause fold extra_pats (DCon _ _ con_name _ _) insides = do
  fresh_vars <- replicateM (length insides) (newUniqueName "a")
  let con_pat = DConP con_name [] [DVarP v | v <- fresh_vars]
      applied = zipWith DAppE insides (map DVarE fresh_vars)
  pure $ DClause (extra_pats ++ [con_pat]) (fold con_name applied)
-- | 'True' exactly for the names 'functorName', 'foldableName' and
-- 'traversableName'.
isFunctorLikeClassName :: Name -> Bool
isFunctorLikeClassName class_name =
  any (class_name ==) [functorName, foldableName, traversableName]