module Language.Haskell.TypeTree
(
IsDatatype(..)
, Binding(..)
, guess
, ttReify
, ttReifyOpts
, ttLit
, ttLitOpts
, ttDescribe
, ttDescribeOpts
, Key
, Arity
, ttEdges
, ttConnComp
, Leaf(..)
, ReifyOpts(..)
, defaultOpts
) where
import Control.Monad
import Control.Monad.Reader
import Data.Graph
import Data.List
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Set as S
import Data.Tree
import Language.Haskell.TH hiding (Arity)
import Language.Haskell.TH.PprLib
import Language.Haskell.TH.Syntax hiding (Arity, lift)
import qualified Language.Haskell.TH.Syntax as TH
import Language.Haskell.TypeTree.CheatingLift
import Language.Haskell.TypeTree.Datatype
import Language.Haskell.TypeTree.Leaf
import Prelude.Compat
import qualified Text.PrettyPrint as HPJ
-- | Knobs controlling how 'ttReifyOpts' walks a type.
data ReifyOpts = ReifyOpts
  { expandPrim :: Bool
    -- ^ When 'True', primitive type constructors are expanded into tree
    -- nodes. When 'False' they are left out of the tree (the function
    -- arrow is always kept).
  , terminals :: S.Set Name
    -- ^ Type names that are emitted as childless leaves instead of
    -- being expanded.
  } deriving (Eq, Show)
-- | Default options: primitive types are not expanded and no type is
-- treated as a terminal.
defaultOpts :: ReifyOpts
defaultOpts = ReifyOpts {terminals = S.empty, expandPrim = False}
-- | Like 'ttDescribeOpts', using 'defaultOpts'.
ttDescribe :: IsDatatype t => t -> ExpQ
ttDescribe t = ttDescribeOpts defaultOpts t
-- | Reify the type with 'ttReifyOpts' and splice in a string literal
-- holding the pretty-printed tree ('treeDoc'), rendered in left mode.
ttDescribeOpts :: IsDatatype t => ReifyOpts -> t -> ExpQ
ttDescribeOpts o n = do
  tree <- ttReifyOpts o n
  let layout =
        HPJ.Style
          { HPJ.ribbonsPerLine = 5
          , HPJ.lineLength = 0
          , HPJ.mode = HPJ.LeftMode
          }
      rendered = HPJ.renderStyle layout (to_HPJ_Doc (treeDoc tree))
  stringE rendered
-- | Splice the reified tree (using 'defaultOpts') as a literal
-- 'Tree' expression.
ttLit :: IsDatatype t => t -> ExpQ
ttLit t = ttReify t >>= liftTree
-- | Graph key used by 'ttEdges': a type constructor (or variable) name
-- paired with the type arguments it was applied to in the tree.
type Key = (Name, [Type])
-- | Number of type arguments attached to a tree node (see 'ttEdges').
type Arity = Int
-- | Splice an edge list of type @[((Name, Arity), Key, [Key])]@ built
-- from the reified tree (using 'defaultOpts'). Each entry holds the
-- node payload @(name, arity)@, the node's own key, and the keys of its
-- direct children. This is the shape expected by 'stronglyConnComp'
-- (see 'ttConnComp').
ttEdges :: IsDatatype t => t -> ExpQ
ttEdges name =
  ttReify name >>= \tree ->
    sigE
      (listE (map liftEdge (node tree)))
      [t|[((Name, Arity), Key, [Key])]|]
  where
    liftEdge ((nm, arity), key, deps) =
      [|(($(liftName nm), arity), $(liftKey key), $(listE (map liftKey deps)))|]
    liftKey (nm, tys) = [|($(liftName nm), $(listE (map liftType tys)))|]
-- | Splice @stronglyConnComp@ applied to the edge list from 'ttEdges'.
ttConnComp :: IsDatatype t => t -> ExpQ
ttConnComp name = appE (varE 'stronglyConnComp) (ttEdges name)
-- | Flatten a reified tree into graph edges, in pre-order.
--
-- Each tree node yields @((name, arity), key, childKeys)@, where @key@
-- is the node's name and type arguments and @arity@ is the number of
-- those arguments. Entries are de-duplicated on @(name, arity)@, and the
-- first occurrence is kept. This uses a 'S.Set' of seen payloads
-- instead of 'nubBy', which was quadratic in the number of nodes.
node :: Tree Leaf -> [((Name, Arity), Key, [Key])]
node = dedupFirst . go
  where
    go (Node ty xs) =
      let key@(nm, args) = unCon ty
       in ((nm, length args), key, map (unCon . rootLabel) xs)
            : concatMap go xs
    -- Order-preserving de-duplication on the first tuple component.
    dedupFirst = loop S.empty
      where
        loop _ [] = []
        loop seen (e@(k, _, _) : es)
          | k `S.member` seen = loop seen es
          | otherwise = e : loop (S.insert k seen) es
    -- Strip a 'Recursive' wrapper and return the name and type arguments.
    unCon :: Leaf -> (Name, [Type])
    unCon (TypeL (x, y)) = (unBinding x, y)
    unCon (Recursive r) = unCon r
-- | Splice the tree reified with the given options as a literal
-- 'Tree' expression.
ttLitOpts :: IsDatatype t => ReifyOpts -> t -> ExpQ
ttLitOpts opts t = liftTree =<< ttReifyOpts opts t
-- | Build an expression that reconstructs the given tree: each node
-- becomes @Node label [children]@, with labels lifted via 'TH.lift'.
liftTree :: Lift t => Tree t -> ExpQ
liftTree (Node label children) =
  appE (appE (conE 'Node) (TH.lift label)) (listE (map liftTree children))
-- | Environment threaded (via 'ReaderT') through 'ttReifyOpts'.
data ReifyEnv = ReifyEnv
  { typeEnv :: Map Name Type
    -- ^ Substitution from type-variable names to the types they were
    -- instantiated with.
  , nodes :: S.Set (Binding, [Type])
    -- ^ Type applications currently being expanded on the path from
    -- the root; used to detect recursion.
  } deriving Show
-- | Like 'ttReifyOpts', using 'defaultOpts'.
ttReify :: IsDatatype t => t -> Q (Tree Leaf)
ttReify t = ttReifyOpts defaultOpts t
-- | Reify the type described by @t@ into a tree of 'Leaf's.
--
-- Each node is a 'TypeL' holding a type constructor (or type variable)
-- together with the type arguments it is applied to. Its children are
-- the types occurring in the constructor fields, or, for type synonyms
-- and type synonym instances, the right-hand side. A type application
-- already being expanded on the path from the root is emitted as a
-- childless 'Recursive' node instead of being expanded again.
--
-- * Names in 'terminals' become childless leaves.
-- * Primitive type constructors are omitted unless 'expandPrim' is set.
--   The function arrow is always kept.
--
-- Fails in 'Q' when:
--
-- * the name is not a type constructor;
-- * no data/newtype/type instance matches the arguments of a family;
-- * a GADT constructor's return type is not headed by the datatype itself;
-- * the requested type itself is an omitted primitive.
ttReifyOpts :: IsDatatype t => ReifyOpts -> t -> Q (Tree Leaf)
ttReifyOpts opts t = do
  (a, b) <- asDatatype t
  result <- runReaderT (go a b) (ReifyEnv mempty mempty)
  -- Only an omitted primitive type constructor produces Nothing, so a
  -- Nothing at the root means the caller asked for such a type.
  maybe
    (fail $
       "ttReify: nothing to reify; primitive type constructors are " ++
       "omitted unless expandPrim is set")
    pure
    result
  where
    -- Expand one (possibly variable) type head applied to arguments.
    go v@(Unbound n) gargs
      | n `S.member` terminals opts = pure $ Just (Node (TypeL (v, gargs)) [])
      | otherwise =
          withVisit v gargs $ \givenArgs ->
            Just . Node (TypeL (Unbound n, givenArgs)) <$>
              mapMaybeM (uncurry resolve . unwrap) givenArgs
    go v@(Bound n) gargs
      | n `S.member` terminals opts = pure $ Just (Node (TypeL (v, gargs)) [])
      | otherwise =
          withVisit v gargs $ \givenArgs -> do
            dec <- lift $ reify n
            case dec of
              PrimTyConI n' _ _
                | expandPrim opts || n' == ''(->) ->
                    Just . Node (TypeL (v, givenArgs)) <$>
                      mapMaybeM (uncurry resolve . unwrap) givenArgs
                | otherwise -> pure Nothing
              TyConI x -> processDec x n givenArgs
              FamilyI _ insts ->
                case findMatchingInstance givenArgs insts of
                  Just inst -> processDec inst n givenArgs
                  Nothing ->
                    fail $
                      "sorry, I cannot find a data/type instance " ++
                      "in scope which matches: " ++
                      show (treeDoc (Node (TypeL (v, givenArgs)) []))
              DataConI {} -> badInput "a data constructor"
              ClassOpI {} -> badInput "a class method"
              ClassI {} -> badInput "a class name"
#if MIN_VERSION_template_haskell(2,12,0)
              PatSynI {} -> badInput "a pattern synonym"
#endif
              TyVarI {} ->
                badInput "an unbound type variable (how did you get here?)"
              VarI {} -> badInput "an ordinary value"

    badInput s = fail $ "ttReify expects a type constructor, but was given " ++ s

    -- Instantiate a declaration's parameters with the given arguments and
    -- expand its body. An unsaturated application is first completed with
    -- fresh type variables.
    processDec x n givenArgs = do
      let (_, wantedArgs) = decodeHead x
      cons <- decodeBody x
      withReaderT (\env -> foldr instantiate env (zip wantedArgs givenArgs)) $
        if length givenArgs < length wantedArgs
          then do
            vars <- lift $ mapM fillVar (drop (length givenArgs) wantedArgs)
            go (Bound n) (givenArgs ++ vars)
          else
            Just . Node (TypeL (Bound n, givenArgs)) <$>
              mapMaybeM (uncurry resolve) cons

    mapMaybeM m xs = catMaybes <$> mapM m xs

    fillVar (VarT n) = VarT <$> newName (nameBase n)
    fillVar x = pure x

    -- Apply the current type-variable substitution throughout a type.
    simplify r ty =
      case ty of
        VarT n -> maybe ty (simplify r) (M.lookup n (typeEnv r))
        AppT x y -> AppT (simplify r x) (simplify r y)
        SigT x k -> SigT (simplify r x) k
        ConT {} -> ty
        TupleT {} -> ty
        UnboxedTupleT {} -> ty
        ListT -> ty
        ArrowT -> ty
        -- Atomic type-level constants (DataKinds literals, promoted
        -- constructors, kinds): there is nothing inside to substitute.
        PromotedT {} -> ty
        LitT {} -> ty
        PromotedTupleT {} -> ty
        PromotedNilT -> ty
        PromotedConsT -> ty
        StarT -> ty
        ConstraintT -> ty
        EqualityT -> ty
        _ -> error $ show ty

    -- The declaration's name and the parameter types it is written over.
    -- GADT-style data declarations report no parameters, because their
    -- constructors carry their own return types.
    decodeHead (DataInstD _ n tys _ _ _) = (n, tys)
    decodeHead (NewtypeInstD _ n tys _ _ _) = (n, tys)
    decodeHead (DataD _ n holes _ cons _)
      | any isGadtCon cons = (n, [])
      | otherwise = (n, map unTV holes)
    decodeHead (NewtypeD _ n holes _ _ _) = (n, map unTV holes)
    decodeHead (TySynD n holes _) = (n, map unTV holes)
    decodeHead (TySynInstD n (TySynEqn holes _)) = (n, holes)
    decodeHead x = error $ "decodeHead " ++ show x

    -- The child types of a declaration: its field types, or its
    -- right-hand side for type synonyms and type synonym instances.
    decodeBody (DataD _ decName _ _ cons _) =
      concat <$> mapM (getFieldTypes decName) cons
    decodeBody (DataInstD _ decName _ _ cons _) =
      concat <$> mapM (getFieldTypes decName) cons
    decodeBody (NewtypeD _ decName _ _ con _) = getFieldTypes decName con
    decodeBody (NewtypeInstD _ decName _ _ con _) = getFieldTypes decName con
    decodeBody (TySynD _ _ ty) = pure [unwrap ty]
    decodeBody (TySynInstD _ (TySynEqn _ ty)) = pure [unwrap ty]
    decodeBody x = error $ "decodeBody " ++ show x

    -- First family instance (data, newtype or type) whose left-hand side
    -- matches the given arguments.
    findMatchingInstance _ [] = Nothing
    findMatchingInstance typeArgs (d:ds) =
      case instanceLhs d of
        Just lhs
          | matchesTypeInstance typeArgs lhs -> Just d
          | otherwise -> findMatchingInstance typeArgs ds
        Nothing ->
          error "FamilyI contained a Dec of the wrong type, this shouldn't happen"

    instanceLhs (DataInstD _ _ tys _ _ _) = Just tys
    instanceLhs (NewtypeInstD _ _ tys _ _ _) = Just tys
    instanceLhs (TySynInstD _ (TySynEqn lhs _)) = Just lhs
    instanceLhs _ = Nothing

    getFieldTypes _ (NormalC _ xs) = pure $ map (\(_, y) -> unwrap y) xs
    getFieldTypes _ (RecC _ xs) = pure $ map (\(_, _, y) -> unwrap y) xs
    getFieldTypes _ (InfixC (_, a) nm (_, b))
      -- For list cons only the element type is a child.
      | nameBase nm == ":" = pure [unwrap a]
      | otherwise = pure [unwrap a, unwrap b]
    getFieldTypes decName (GadtC _ fs ret) =
      gadtFieldTypes decName [y | (_, y) <- fs] ret
    getFieldTypes decName (RecGadtC _ fs ret) =
      gadtFieldTypes decName [y | (_, _, y) <- fs] ret
    getFieldTypes decName (ForallC _ _ cn) = getFieldTypes decName cn

    -- Field types of a (record or plain) GADT constructor, followed by the
    -- arguments of its return type. The return type must be headed by the
    -- datatype being reified.
    gadtFieldTypes decName fieldTys ret =
      case unwrap ret of
        (retN, retTys)
          | retN == Bound decName ->
              pure $ map unwrap fieldTys ++ map unwrap retTys
          | otherwise ->
              fail $
                "sorry, GADT constructor return type must exactly " ++
                "match datatype (this is a limitation in type-tree)"

    isGadtCon GadtC {} = True
    isGadtCon RecGadtC {} = True
    isGadtCon (ForallC _ _ c) = isGadtCon c
    isGadtCon _ = False

    unTV (KindedTV n _) = VarT n
    unTV (PlainTV n) = VarT n

    -- Record the bindings implied by matching a parameter pattern against
    -- an argument type.
    instantiate (VarT x, y) r = r {typeEnv = M.insert x y (typeEnv r)}
    instantiate (AppT a b, AppT c d) r = instantiate (a, c) (instantiate (b, d) r)
    instantiate _ r = r

    -- Run @m@ on the simplified arguments with this application marked as
    -- visited, or emit a 'Recursive' leaf if it is already being expanded
    -- on the current path.
    withVisit a b m = do
      r <- ask
      let visited = nodes r
          b' = map (simplify r) b
          a' =
            case simplify r (bindingType a) of
              ConT n -> Bound n
              VarT n -> Unbound n
              other ->
                error $ "ttReify: unexpected type for a binding: " ++ show other
      if S.member (a', b') visited
        then pure $ Just $ Node (Recursive $ TypeL (a', b')) []
        else withReaderT (\q -> q {nodes = S.insert (a', b') (nodes q)}) $ m b'

    bindingType (Bound x) = ConT x
    bindingType (Unbound x) = VarT x

    -- Expand a child. A type variable is first chased through the
    -- substitution, and cycles of variables become 'Recursive' leaves.
    resolve (Bound x) args = go (Bound x) args
    resolve (Unbound x) args = chase x args []
      where
        chase x' args' seen = do
          env <- asks typeEnv
          case M.lookup x' env of
            Just (VarT y)
              | y `elem` seen ->
                  pure $ Just $ Node (Recursive $ TypeL (Unbound x', args')) []
              | otherwise -> chase y args' (y : seen)
            Just (unwrap -> (h, args'')) -> go h (args'' ++ args')
            Nothing -> go (Unbound x') args'

    -- Does an instance left-hand side match the given arguments? Type
    -- variables in the instance match anything.
    matchesTypeInstance [] [] = True
    matchesTypeInstance xs (VarT _:ys) = matchesTypeInstance (drop 1 xs) ys
    matchesTypeInstance (ConT x:xs) (ConT y:ys)
      | x == y = matchesTypeInstance xs ys
      | otherwise = False
    matchesTypeInstance (AppT a b:xs) (AppT c d:ys) =
      matchesTypeInstance [a] [c] &&
      matchesTypeInstance [b] [d] && matchesTypeInstance xs ys
    matchesTypeInstance (x:xs) (y:ys) = x == y && matchesTypeInstance xs ys
    matchesTypeInstance _ _ = False