module Language.Haskell.TH.TypeGraph.TypeTraversal
( HasVisitedMap(..)
, doType
, HasTypeParameters(..)
, withBindings
, HasTypeTraversal(..)
, doApply
, FieldInfo(..)
, expandType, pprint1, pprint1', pprintW, pprintW'
, ToName(toName)
, HasMessageInfo(..), message, indent
) where
import Control.Lens (Lens', view)
import Control.Monad.RWS as Monad hiding (lift)
import Data.Generics (Data, everywhere, mkT)
import Data.List (intercalate)
import Data.Map.Strict as Map (Map, lookup)
import qualified Data.Map.Strict as Map (fromList)
import Instances.TH.Lift ()
import Language.Haskell.TH hiding (prim)
import Language.Haskell.TH.Desugar as DS (DsMonad, typeToTH, dsType, expand)
import Language.Haskell.TH.PprLib (to_HPJ_Doc)
import Language.Haskell.TH.Syntax as TH
import Language.Haskell.TH.TypeGraph.Orphans ()
import qualified Text.PrettyPrint as HPJ
class Monad m => HasTypeParameters m where
pushParam :: Type -> m a -> m a
withParams :: ([Type] -> m ()) -> m ()
data FieldInfo
= FieldInfo
{ _typeName :: Name
, _constrCount :: Int
, _constrIndex :: Int
, _constrName :: Name
, _fieldCount :: Int
, _fieldIndex :: Int
, _fieldName :: Maybe Name
, _fieldType :: Type
} deriving Show
class (DsMonad m, HasVisitedMap m) => HasTypeTraversal m where
prepType :: Type -> m Type
doTypeInternal :: Type -> m ()
doListT :: Type -> Type -> m ()
doTupleT :: Type -> Type -> Int -> m ()
doField :: Type -> (Type -> Type) -> FieldInfo -> m ()
doVarT :: Type -> Name -> m ()
doType :: HasTypeTraversal m => Type -> m ()
doType typ = prepType typ >>= doTypeOnce doTypeInternal
class DsMonad m => HasVisitedMap m where
unvisited :: Type -> m () -> m ()
doTypeOnce :: HasVisitedMap m => (Type -> m ()) -> Type -> m ()
doTypeOnce go typ = unvisited typ (go typ)
doApply :: (HasTypeTraversal m, HasTypeParameters m, DsMonad m) => Type -> Type -> m ()
doApply typ0 (ForallT _tvs _cxt typ) = doApply typ0 typ
doApply typ0 (VarT name) = doVarT typ0 name
doApply typ0 (AppT a b) = pushParam b (doApply typ0 a)
doApply typ0 (ConT tname) = qReify tname >>= doInfo typ0
doApply typ0 ListT = do
withParams $ \ps -> do
case ps of
[a] -> doListT typ0 a
_ -> error $ "ListT Arity error: "
doApply typ0 typ@(TupleT n) = do
withParams $ \ps -> do
when (length ps /= n) (error $ "Tuple Arity error: " ++
pprint1 typ ++
" [" ++ intercalate ", " (fmap pprint1 ps) ++ "]")
mapM_ (uncurry (doTupleT typ0)) (zip ps [1..])
doApply typ0 _ = do
error $ "doApply - unexpected type: " ++ pprint1 typ0
doInfo :: (DsMonad m, HasTypeParameters m, HasTypeTraversal m) => Type -> Info -> m ()
doInfo _typ0 (PrimTyConI _name _arity _unl) = pure ()
doInfo _typ0 (TyConI (TySynD _tname binds typ)) =
runQ (expandType typ) >>= \typ' ->
withBindings (\subst -> doType (subst typ')) binds
#if MIN_VERSION_template_haskell(2,11,0)
doInfo typ0 (TyConI (NewtypeD cx tname binds mk con supers)) =
doInfo typ0 (TyConI (DataD cx tname binds mk [con] supers))
doInfo typ0 (TyConI (DataD _ tname binds _mk cons _supers)) =
withBindings (\subst -> do mapM_ (uncurry (doCon typ0 tname subst (length cons))) (zip [1..] cons)) binds
#else
doInfo typ0 (TyConI (NewtypeD cx tname binds con supers)) =
doInfo typ0 (TyConI (DataD cx tname binds [con] supers))
doInfo typ0 (TyConI (DataD _ tname binds cons _supers)) =
withBindings (\subst -> do mapM_ (uncurry (doCon typ0 tname subst (length cons))) (zip [1..] cons)) binds
#endif
doInfo _ info = error $ "Unexpected info: " ++ pprint1 info ++ "\n\t" ++ show info
doCon :: (HasTypeParameters m, HasTypeTraversal m) => Type -> Name -> (Type -> Type) -> Int -> Int -> Con -> m ()
doCon typ0 tname subst cct cpos (ForallC _ _ con) = doCon typ0 tname subst cct cpos con
doCon typ0 tname subst cct cpos (RecC cname vsts) =
mapM_ (\(i, (fname, _, ftype)) ->
expandType ftype >>= \ftype' ->
let fld = FieldInfo { _typeName = tname
, _constrCount = cct
, _constrIndex = cpos
, _constrName = cname
, _fieldCount = length vsts
, _fieldIndex = i
, _fieldName = Just fname
, _fieldType = subst ftype' } in
doField typ0 subst fld) (zip [1..] vsts)
doCon typ0 tname subst cct cpos (NormalC cname sts) =
mapM_ (\(i, (_, ftype)) ->
expandType ftype >>= \ftype' ->
let fld = FieldInfo { _typeName = tname
, _constrCount = cct
, _constrIndex = cpos
, _constrName = cname
, _fieldCount = length sts
, _fieldIndex = i
, _fieldName = Nothing
, _fieldType = subst ftype' } in
doField typ0 subst fld) (zip [1..] sts)
doCon typ0 tname subst cct cpos (InfixC lhs cname rhs) =
mapM_ (\(i, (_, ftype)) ->
expandType ftype >>= \ftype' ->
let fld = FieldInfo { _typeName = tname
, _constrCount = cct
, _constrIndex = cpos
, _constrName = cname
, _fieldCount = 2
, _fieldIndex = i
, _fieldName = Nothing
, _fieldType = subst ftype' } in
doField typ0 subst fld) [(1, lhs), (2, rhs)]
withBindings :: (HasTypeParameters m, Data a) => ((a -> a) -> m ()) -> [TyVarBndr] -> m ()
withBindings action vars = do
withParams $ \vals -> do
when (length vals < length vars)
(error $ "doInfo - arity mismatch:\n\tvars=" ++ show vars ++
"\n\tparams=" ++ show vals)
let subst :: forall a. Data a => a -> a
subst = substG bindings
bindings = Map.fromList (zip (fmap toName vars) ( vals :: [Type]))
action subst
where
substG :: forall a. Data a => Map Name Type -> a -> a
substG bindings typ = everywhere (mkT (subst1 bindings)) typ
subst1 :: Map Name Type -> Type -> Type
subst1 bindings t@(VarT name) = maybe t id (Map.lookup name bindings)
subst1 _ t = t
pprint1 :: (Ppr a, Data a) => a -> [Char]
pprint1 = pprint1' . friendlyNames
pprint1' :: Ppr a => a -> [Char]
pprint1' = pprintStyle (HPJ.style {HPJ.mode = HPJ.OneLineMode})
pprintW :: (Ppr a, Data a) => Int -> a -> [Char]
pprintW w = pprintW' w . friendlyNames
pprintW' :: Ppr a => Int -> a -> [Char]
pprintW' w = pprintStyle (HPJ.style {HPJ.lineLength = w})
pprintStyle :: Ppr a => HPJ.Style -> a -> String
pprintStyle style = HPJ.renderStyle style . to_HPJ_Doc . ppr
friendlyNames :: Data a => a -> a
friendlyNames =
everywhere (mkT friendlyName)
where
friendlyName (Name x _) = Name x NameS
expandType :: DsMonad m => Type -> m Type
expandType typ = DS.typeToTH <$> (DS.dsType typ >>= DS.expand)
class ToName a where toName :: a -> Name
instance ToName TyVarBndr where
toName (PlainTV n) = n
toName (KindedTV n _) = n
instance ToName Con where
toName (ForallC _ _ con) = toName con
toName (NormalC cname _) = cname
toName (RecC cname _) = cname
toName (InfixC _ cname _) = cname
instance ToName VarStrictType where
toName (n, _, _) = n
class HasMessageInfo a where
verbosity' :: Lens' a Int
prefix' :: Lens' a String
message :: (Quasi m, MonadReader s m, HasMessageInfo s) =>
Int -> String -> m ()
message minv s = do
v <- view verbosity'
p <- view prefix'
when (v >= minv) $ (runQ . runIO . putStr . indent p) s
indent :: String -> String -> String
indent p s = unlines $ fmap (p ++) (lines s)