module Data.Functor.Invariant.TH.Internal where
import Data.Function (on)
import Data.List
import qualified Data.Map as Map (fromList, findWithDefault)
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Set (Set)
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax
#ifndef CURRENT_PACKAGE_KEY
import Data.Version (showVersion)
import Paths_invariant (version)
#endif
-- | Expand the type synonyms occurring in a type.
--
-- The body of a @forall@ is expanded while its binders and context are kept
-- as they are. Applications and bare constructors are handed to
-- 'expandSynApp' with an empty argument stack. A kind signature is dropped
-- and only the annotated type is expanded. Every other form of type is
-- returned unchanged.
expandSyn :: Type -> Q Type
expandSyn ty = case ty of
  ForallT tvbs ctxt body -> fmap (ForallT tvbs ctxt) (expandSyn body)
  AppT{}                 -> expandSynApp ty []
  ConT{}                 -> expandSynApp ty []
  SigT inner _           -> expandSyn inner
  _                      -> return ty
-- | Worker for 'expandSyn': expand the head of a type application.
--
-- The second argument is the stack of (already expanded) arguments that the
-- head is applied to, outermost argument last.
expandSynApp :: Type -> [Type] -> Q Type
-- Peel one application: expand the argument and push it on the stack.
expandSynApp (AppT fun arg) args = do
  arg' <- expandSyn arg
  expandSynApp fun (arg' : args)
expandSynApp con@(ConT n) args
  -- A constructor whose base name is "[]" is rebuilt with 'ListT' and is
  -- never reified.
  | nameBase n == "[]" = return (foldl' AppT ListT args)
  | otherwise = do
      info <- reify n
      case info of
        -- A synonym: substitute as many arguments as the synonym binds into
        -- its right-hand side, then keep expanding with the leftovers.
        TyConI (TySynD _ tvbs rhs) ->
          let (consumed, leftover) = splitAt (length tvbs) args
          in expandSynApp (subst (mkSubst tvbs consumed) rhs) leftover
        -- Anything else: reapply the constructor to its arguments.
        _ -> return (foldl' AppT con args)
-- Any other head is expanded on its own and reapplied to the arguments.
expandSynApp hd args = fmap (\hd' -> foldl' AppT hd' args) (expandSyn hd)
-- | A substitution: a finite map from type variable names to the types that
-- replace them.
type Subst = Map Name Type
-- | Build a substitution by pairing binders with types positionally.
--
-- Pairing is done with 'zip', so surplus binders or surplus types are
-- silently ignored.
mkSubst :: [TyVarBndr] -> [Type] -> Subst
mkSubst vs ts = Map.fromList [ (binderName v, t) | (v, t) <- zip vs ts ]
  where
    -- The variable bound by a binder, with or without a kind annotation.
    binderName (PlainTV v)    = v
    binderName (KindedTV v _) = v
-- | Apply a substitution to a type.
--
-- Variables present in the map are replaced; absent ones are left alone.
-- The substitution descends through applications, kind signatures (the kind
-- itself is untouched) and the body of a @forall@. The binders and context
-- of a @forall@ are not inspected, so no shadowing check is performed.
subst :: Subst -> Type -> Type
subst subs ty = case ty of
  ForallT tvbs ctxt body -> ForallT tvbs ctxt (subst subs body)
  VarT n                 -> Map.findWithDefault ty n subs
  AppT fun arg           -> AppT (subst subs fun) (subst subs arg)
  SigT inner k           -> SigT (subst subs inner) k
  _                      -> ty
-- | The two classes handled by this module, used to select the matching
-- class, method and helper names.
data InvariantClass
  = Invariant
  | Invariant2
  deriving (Eq, Ord)

-- | 'Invariant' is numbered 1 and 'Invariant2' is numbered 2. 'toEnum' calls
-- 'error' for any other number.
instance Enum InvariantClass where
  fromEnum cls = case cls of
    Invariant  -> 1
    Invariant2 -> 2

  toEnum n = case n of
    1 -> Invariant
    2 -> Invariant2
    _ -> error $ "No Invariant class for number " ++ show n
-- | The 'Name' of the constant-returning helper that corresponds to the
-- given class.
invmapConstName :: InvariantClass -> Name
invmapConstName cls = case cls of
  Invariant  -> invmapConstValName
  Invariant2 -> invmap2ConstValName
-- | The 'Name' of the type class that corresponds to the given
-- 'InvariantClass' value.
invariantClassName :: InvariantClass -> Name
invariantClassName cls = case cls of
  Invariant  -> invariantTypeName
  Invariant2 -> invariant2TypeName
-- | The 'Name' of the mapping method that corresponds to the given class.
invmapName :: InvariantClass -> Name
invmapName cls = case cls of
  Invariant  -> invmapValName
  Invariant2 -> invmap2ValName
-- | Has the shape of an @invmap@ with one extra leading argument, and always
-- returns that leading argument. Both functions and the final value are
-- ignored.
invmapConst :: f b -> (a -> b) -> (b -> a) -> f a -> f b
invmapConst result _ _ _ = result
-- | Has the shape of an @invmap2@ with one extra leading argument, and
-- always returns that leading argument. All four functions and the final
-- value are ignored.
invmap2Const :: f c d
             -> (a -> c) -> (c -> a)
             -> (b -> d) -> (d -> b)
             -> f a b -> f c d
invmap2Const result _ _ _ _ _ = result
-- | A wrapper around 'Name' whose 'Eq', 'Ord' and 'Show' instances look only
-- at the result of 'nameBase' on the wrapped name.
newtype NameBase = NameBase { getName :: Name }

-- | The 'nameBase' of the wrapped 'Name'.
getNameBase :: NameBase -> String
getNameBase nb = nameBase (getName nb)

-- | Equal when the base names are equal.
instance Eq NameBase where
  (==) = on (==) getNameBase

-- | Ordered by base name.
instance Ord NameBase where
  compare = on compare getNameBase

-- | Shown as the base name 'String' would be shown.
instance Show NameBase where
  showsPrec p nb = showsPrec p (getNameBase nb)
-- | A triple whose first component is a 'NameBase' key, followed by two
-- associated 'Name's.
-- NOTE(review): the meaning of the two 'Name's is not established in this
-- section; confirm against the code that constructs these triples.
type TyVarInfo = (NameBase, Name, Name)
-- | Project the first component of a triple.
fst3 :: (a, b, c) -> a
fst3 triple = case triple of
  (x, _, _) -> x
-- | Project the third component of a triple.
thd3 :: (a, b, c) -> c
thd3 triple = case triple of
  (_, _, z) -> z
-- | Like 'lookup', but over a list of triples: the first component is the
-- key and the remaining two are returned as a pair. The first matching entry
-- wins; 'Nothing' is returned when no entry matches.
lookup2 :: Eq a => a -> [(a, b, c)] -> Maybe (b, c)
lookup2 key = foldr pick Nothing
  where
    -- Take this entry on a key match, otherwise defer to the rest of the list.
    pick (k, y, z) rest
      | key == k  = Just (y, z)
      | otherwise = rest
-- | Extract the name of a data constructor.
--
-- A @forall@-quantified constructor is unwrapped to the constructor it
-- quantifies. From template-haskell 2.11 onwards 'Con' also has the GADT
-- forms 'GadtC' and 'RecGadtC', which carry a list of names; the first name
-- in that list is returned. Without these cases the function would fail
-- with a pattern-match error on any GADT-syntax constructor.
constructorName :: Con -> Name
constructorName con = case con of
  NormalC name _    -> name
  RecC    name _    -> name
  InfixC  _ name _  -> name
  ForallC _ _ inner -> constructorName inner
#if MIN_VERSION_template_haskell(2,11,0)
  GadtC    (name:_) _ _ -> name
  RecGadtC (name:_) _ _ -> name
  -- Only reachable for a GADT constructor carrying an empty name list.
  _ -> error "constructorName: GADT constructor with no names"
#endif
-- | Generate @n@ fresh names whose base strings are the prefix followed by
-- @1@, @2@, ..., @n@, in that order. Yields an empty list when @n@ is less
-- than one.
newNameList :: String -> Int -> Q [Name]
newNameList prefix n = sequence [ newName (prefix ++ show i) | i <- [1 .. n] ]
-- | Drop every 'TyVarInfo' whose key (first component) has the same base
-- name as one of the given binders.
removeForalled :: [TyVarBndr] -> [TyVarInfo] -> [TyVarInfo]
removeForalled tvbs tvis = [ tvi | tvi <- tvis, fst3 tvi `notElem` boundNames ]
  where
    -- Base names of the variables bound by the binders.
    boundNames :: [NameBase]
    boundNames = map (NameBase . tvbName) tvbs
-- | The name of the variable bound by a binder, whether or not it carries a
-- kind annotation.
tvbName :: TyVarBndr -> Name
tvbName tvb = case tvb of
  PlainTV  name   -> name
  KindedTV name _ -> name
-- | The kind of a binder: its annotation when present, otherwise 'starK'.
tvbKind :: TyVarBndr -> Kind
tvbKind tvb = case tvb of
  PlainTV  _   -> starK
  KindedTV _ k -> k
-- | Rename the variable of a binder to the variable named by a type.
--
-- Kind signatures on the type are looked through. If the type is a type
-- variable, the binder is rebuilt with that variable's name, keeping any
-- kind annotation the binder had. For any other type the binder is returned
-- unchanged.
replaceTyVarName :: TyVarBndr -> Type -> TyVarBndr
replaceTyVarName tvb ty = case ty of
  SigT inner _ -> replaceTyVarName tvb inner
  VarT n       -> case tvb of
    PlainTV  _   -> PlainTV n
    KindedTV _ k -> KindedTV n k
  _            -> tvb
-- | Build the constraint that applies a class to a single type variable.
--
-- From template-haskell 2.10 the constraint is an ordinary type
-- application; before that it is built with the 'ClassP' constructor.
applyClass :: Name -> Name -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
applyClass cls tyVar = ConT cls `AppT` VarT tyVar
#else
applyClass cls tyVar = ClassP cls [VarT tyVar]
#endif
-- | Decide whether a list of dropped types can be eta-reduced away.
--
-- This holds when all three conditions are met, checked in this order:
--
-- 1. every dropped type is a type variable,
-- 2. the dropped variables have pairwise distinct base names, and
-- 3. none of the remaining types mentions any of those base names.
canEtaReduce :: [Type] -> [Type] -> Bool
canEtaReduce remaining dropped = and
  [ all isTyVar dropped
  , allDistinct droppedNames
  , all (not . (`mentionsNameBase` droppedNames)) remaining
  ]
  where
    -- Only demanded after the first check has confirmed that every dropped
    -- type is a variable, so 'varTToNameBase' cannot hit its error case.
    droppedNames :: [NameBase]
    droppedNames = map varTToNameBase dropped
-- | The name of a type variable, looking through any kind signatures.
--
-- Calls 'error' when the type is not a (possibly kind-annotated) variable.
varTToName :: Type -> Name
varTToName ty = case ty of
  VarT n       -> n
  SigT inner _ -> varTToName inner
  _            -> error "Not a type variable!"
-- | Like 'varTToName', but wraps the resulting name in 'NameBase'.
varTToNameBase :: Type -> NameBase
varTToNameBase ty = NameBase (varTToName ty)
-- | Strip one outermost kind signature from a type, if there is one. Nested
-- signatures underneath it are left in place.
unSigT :: Type -> Type
unSigT ty = case ty of
  SigT inner _ -> inner
  _            -> ty
-- | Whether a type is a type variable, looking through any kind signatures.
isTyVar :: Type -> Bool
isTyVar ty = case ty of
  VarT _       -> True
  SigT inner _ -> isTyVar inner
  _            -> False
-- | Whether a type is a type constructor that reifies to a type family.
--
-- Only a bare 'ConT' is reified; every other type yields 'False'. The shape
-- of the reified declaration that counts as a family depends on the
-- template-haskell version, hence the CPP.
isTyFamily :: Type -> Q Bool
isTyFamily ty = case ty of
  ConT n -> fmap describesFamily (reify n)
  _      -> return False
  where
    describesFamily :: Info -> Bool
#if MIN_VERSION_template_haskell(2,11,0)
    describesFamily (FamilyI OpenTypeFamilyD{} _) = True
#elif MIN_VERSION_template_haskell(2,7,0)
    describesFamily (FamilyI (FamilyD TypeFam _ _ _) _) = True
#else
    describesFamily (TyConI (FamilyD TypeFam _ _ _)) = True
#endif
#if MIN_VERSION_template_haskell(2,9,0)
    describesFamily (FamilyI ClosedTypeFamilyD{} _) = True
#endif
    describesFamily _ = False
-- | Whether a list contains no duplicate elements. The scan stops at the
-- first repeated element.
allDistinct :: Ord a => [a] -> Bool
allDistinct = scan Set.empty
  where
    -- @seen@ holds the elements encountered so far.
    scan _    []       = True
    scan seen (y : ys) = not (y `Set.member` seen) && scan (Set.insert y seen) ys
-- | Whether a type mentions, as a free variable, any variable whose base
-- name is in the given list.
--
-- Variables bound by an enclosing @forall@ are not counted. The context of
-- a @forall@ and the kind inside a kind signature are not searched.
mentionsNameBase :: Type -> [NameBase] -> Bool
mentionsNameBase ty nbs = search Set.empty ty
  where
    -- @bound@ is the set of base names bound by the enclosing @forall@s.
    search :: Set NameBase -> Type -> Bool
    search bound t = case t of
      ForallT tvbs _ body ->
        search (foldr (Set.insert . NameBase . tvbName) bound tvbs) body
      AppT fun arg -> search bound fun || search bound arg
      SigT inner _ -> search bound inner
      VarT n       ->
        let nb = NameBase n
        in nb `elem` nbs && not (nb `Set.member` bound)
      _            -> False
-- | 'mentionsNameBase' for constraints.
--
-- From template-haskell 2.10 the constraint is searched directly as a type.
-- Before that, the arguments of a class constraint, or both sides of an
-- equality constraint, are searched.
predMentionsNameBase :: Pred -> [NameBase] -> Bool
#if MIN_VERSION_template_haskell(2,10,0)
predMentionsNameBase p nbs = mentionsNameBase p nbs
#else
predMentionsNameBase p nbs = case p of
  ClassP _ tys -> any (`mentionsNameBase` nbs) tys
  EqualP t1 t2 -> mentionsNameBase t1 nbs || mentionsNameBase t2 nbs
#endif
-- | The number of arrows in a kind.
--
-- 'uncurryKind' splits a kind into its argument kinds followed by its
-- result kind, so the arrow count is one less than the number of pieces.
-- The subtraction operator was missing, which left an 'Int' applied to an
-- argument and did not type-check.
numKindArrows :: Kind -> Int
numKindArrows k = length (uncurryKind k) - 1
-- | Apply a type to a list of argument types, left to right.
applyTy :: Type -> [Type] -> Type
applyTy headTy args = foldl' AppT headTy args
-- | Apply the type constructor with the given name to a list of argument
-- types.
applyTyCon :: Name -> [Type] -> Type
applyTyCon name = applyTy (ConT name)
-- | Split a type application into its head followed by its arguments, in
-- application order.
--
-- Kind signatures along the application spine are looked through;
-- signatures on the arguments themselves are kept.
unapplyTy :: Type -> [Type]
unapplyTy ty = collect ty []
  where
    -- Walk down the spine, consing each argument in front of those that
    -- were applied after it.
    collect :: Type -> [Type] -> [Type]
    collect (AppT fun arg) acc = collect fun (arg : acc)
    collect (SigT inner _) acc = collect inner acc
    collect hd             acc = hd : acc
-- | Split a function type into its argument types followed by its result
-- type. A type that is not a function type yields a singleton list.
--
-- Kind signatures along the result spine are looked through.
uncurryTy :: Type -> [Type]
uncurryTy ty = case ty of
  AppT (AppT ArrowT arg) res -> arg : uncurryTy res
  SigT inner _               -> uncurryTy inner
  _                          -> [ty]
-- | Split a kind into its argument kinds followed by its result kind.
--
-- From template-haskell 2.8 this simply defers to 'uncurryTy'; on older
-- versions the 'ArrowK' constructor is walked directly.
uncurryKind :: Kind -> [Kind]
#if MIN_VERSION_template_haskell(2,8,0)
uncurryKind k = uncurryTy k
#else
uncurryKind k = case k of
  ArrowK k1 k2 -> k1 : uncurryKind k2
  _            -> [k]
#endif
-- | Whether every kind in the list satisfies 'canRealizeKindStar'.
wellKinded :: [Kind] -> Bool
wellKinded kinds = all canRealizeKindStar kinds
-- | Whether every piece of a kind, as split by 'uncurryKind', satisfies
-- 'canRealizeKindStar'.
canRealizeKindStarChain :: Kind -> Bool
canRealizeKindStarChain k = all canRealizeKindStar (uncurryKind k)
-- | Whether a kind is arrow-free and is either @*@ or, from
-- template-haskell 2.8, a kind variable.
canRealizeKindStar :: Kind -> Bool
canRealizeKindStar k = case uncurryKind k of
  [single] -> isStarLike single
  _        -> False
  where
    isStarLike :: Kind -> Bool
#if MIN_VERSION_template_haskell(2,8,0)
    isStarLike StarT    = True
    isStarLike (VarT _) = True
#else
    isStarLike StarK = True
#endif
    isStarLike _ = False
-- | Build the kind @* -> * -> ... -> *@ with the given number of arrows.
--
-- @createKindChain 0@ is 'starK'; each further step prepends one @* ->@.
-- The recursive step had lost its subtraction operator, leaving the counter
-- applied as a function, which did not type-check. A non-positive count
-- also ends the recursion, so a negative argument yields 'starK' rather
-- than looping forever.
createKindChain :: Int -> Kind
createKindChain = go starK
  where
    go :: Kind -> Int -> Kind
    go k n
      -- The comparison forces the counter, so no thunk chain builds up.
      | n <= 0 = k
#if MIN_VERSION_template_haskell(2,8,0)
      | otherwise = go (AppT (AppT ArrowT StarT) k) (n - 1)
#else
      | otherwise = go (ArrowK StarK k) (n - 1)
#endif
-- | The set of kind variable names occurring in a kind.
--
-- Before template-haskell 2.8 the result is always empty.
distinctKindVars :: Kind -> Set Name
#if MIN_VERSION_template_haskell(2,8,0)
distinctKindVars kind = case kind of
  AppT k1 k2 -> Set.union (distinctKindVars k1) (distinctKindVars k2)
  SigT k _   -> distinctKindVars k
  VarT n     -> Set.singleton n
  _          -> Set.empty
#else
distinctKindVars _ = Set.empty
#endif
-- | Turn a binder into the type variable it binds, keeping the kind
-- annotation (as a kind signature) when the binder has one.
tvbToType :: TyVarBndr -> Type
tvbToType tvb = case tvb of
  PlainTV  n   -> VarT n
  KindedTV n k -> SigT (VarT n) k
-- | The package key passed to 'mkNameG_tc' and 'mkNameG_v' when building
-- the global names below.
--
-- Uses the @CURRENT_PACKAGE_KEY@ macro when it is defined; otherwise the
-- key is assembled from the literal @"invariant-"@ and the package version.
invariantPackageKey :: String
#ifdef CURRENT_PACKAGE_KEY
invariantPackageKey = CURRENT_PACKAGE_KEY
#else
invariantPackageKey = "invariant-" ++ showVersion version
#endif
-- | Build a global type-constructor/class 'Name' in this package from a
-- module name and an identifier.
mkInvariantName_tc :: String -> String -> Name
mkInvariantName_tc modName ident = mkNameG_tc invariantPackageKey modName ident
-- | Build a global value 'Name' in this package from a module name and an
-- identifier.
mkInvariantName_v :: String -> String -> Name
mkInvariantName_v modName ident = mkNameG_v invariantPackageKey modName ident
-- | Global name of the @Invariant@ class in "Data.Functor.Invariant".
invariantTypeName :: Name
invariantTypeName = mkInvariantName_tc "Data.Functor.Invariant" "Invariant"
-- | Global name of the @Invariant2@ class in "Data.Functor.Invariant".
invariant2TypeName :: Name
invariant2TypeName = mkInvariantName_tc "Data.Functor.Invariant" "Invariant2"
-- | Global name of @invmap@ in "Data.Functor.Invariant".
invmapValName :: Name
invmapValName = mkInvariantName_v "Data.Functor.Invariant" "invmap"
-- | Global name of @invmap2@ in "Data.Functor.Invariant".
invmap2ValName :: Name
invmap2ValName = mkInvariantName_v "Data.Functor.Invariant" "invmap2"
-- | Global name of 'invmapConst', defined in this module.
invmapConstValName :: Name
invmapConstValName = mkInvariantName_v "Data.Functor.Invariant.TH.Internal" "invmapConst"
-- | Global name of 'invmap2Const', defined in this module.
invmap2ConstValName :: Name
invmap2ConstValName = mkInvariantName_v "Data.Functor.Invariant.TH.Internal" "invmap2Const"
-- | Global name of @error@, addressed as @GHC.Err.error@ in the @base@
-- package.
errorValName :: Name
errorValName = mkNameG_v "base" "GHC.Err" "error"