{-# LANGUAGE TemplateHaskell, CPP #-} {-# OPTIONS_GHC -W -Wall #-} {-| This is the code for \"Template Your Boilerplate\" under review at the Haskell Symposium 2012. A draft copy of that paper is available at and provides more thorough documentation. -} #define BENCHMARK 0 module Data.Generics.TH ( -- * Primitives thcase, thcase', thfoldl -- * Single-layer traversals , thmapT, thmapM, thmapQ, thmapQl, thmapQr -- * Memoization , memoizeDec, memoizeDec2 -- TODO: 3, 4, 5, etc. , memoizeExp, memoizeExp2 -- TOOD: 3, 4, 5, etc. -- * Traversals -- ** Transformations , everywhere, everywhere', everywhereBut , everywhereM, everywhereM', everywhereButM' , everywhereFor, everywhereForM , somewhere, somewhereM -- ** Queries , everything, everythingBut , everythingAccL, everythingAccL', everythingButAccL, everythingButAccL' , everythingAccR, everythingButAccR , everythingForR, everythingForL, everythingForL' -- * Extentions and adaptors , extN, extE, extE' , mkT, mkTs, mkQ, mkQs, mkM, mkMs -- * Type manipulation functions , eqType, eqTypes , inType, inTypes , constructorsOf, typeOfName #if BENCHMARK -- * Benchmarking implementations , everything_slow , everywhereM_slow #endif ) where -- Imports for the 'seen' table in 'inType' import Control.Monad.State -- Imports for memoization tables import Data.IORef import qualified Data.Map as Map import qualified Data.Set as Set -- Imports for Template Haskell import Language.Haskell.TH hiding (cxt{-avoid warnings when cxt is a variable-}) import Language.Haskell.TH.Syntax hiding (lift{-conflicts with monadic 'lift'-}) -- Import for listing the 'primitive types'. import Data.Int (Int8, Int16, Int32, Int64) import Data.Word (Word, Word8, Word16, Word32, Word64) import Data.Array (Array) import Foreign.Ptr (Ptr) import Foreign.ForeignPtr (ForeignPtr) -- Imports of other TYB modules import Data.Generics.TH.Instances () import Data.Generics.TH.VarSet (varSet) -------------------------------------------------------------------------------- -- thcase', thcase and thfoldl -------------------------------------------------------------------------------- -- |Primitive case expression generation. Most users will want to use -- 'thcase' instead. thcase' :: (Quasi m) => (Either Name {- prim :: t -} (Name {- ctor :: a -> b -> t -}, [(Type, Name)] {- args :: [a, b] -}) -> m Exp {- c t -}) -- ^ Case handling function. If the 'Type' being inspected is -- a primitive type, argument is @'Left' var@ where @var@ is a -- variable bound to the case discriminant. Otherwise, -- argument is @'Right' (ctor, args)@ where @ctor@ is the -- constructor name and @args@ is a list of the argument's -- types and the variable bound to the argument. -> m Type -- ^ The type to inspect. -> m Exp -- ^ The expression containing the @case@ -- expression. If the type to inspect is @t@ and -- the type of the 'Exp' returned by the case -- handling function is @r@, the 'Exp' returned by -- @thcase'@ is of type @t -> r@. thcase' g t0 = do ctors <- constructorsOf =<< t0 x <- qNewName "_x" body <- case ctors of Nothing -> g (Left x) Just ctors' -> return (CaseE (VarE x)) `ap` mapM doClause ctors' return $ LamE [VarP x] body where doClause (cons, ts) = do xs <- mapM (const $ qNewName "arg") ts --sequence [do x <- qNewName "arg"; return (t, x) | t <- ts] body <- g (Right (cons, zip ts xs)) return $ Match (ConP cons (map VarP xs)) (NormalB body) [] -- |Case expression generation. This is the core function of the -- Template Your Boilerplate library. -- -- This function is similar to @thcase'@, except that since most users -- will note care about the distinction between types and primitive -- types, this function smooths over the differences by treating primitive -- types as types with nullary constructors. thcase :: (Quasi m) => (m Exp -> [(Type, m Exp)] -> m Exp) -- ^ Case handling function. The first argument is the constructor. -- The second argument is the list of arguments and their types. -> m Type -- ^ The type to inspect. -> m Exp -- ^ The expression containing the @case@ expression. If -- the type to inspect is @t@ and the type of the 'Exp' -- returned by the case handling function is @r@, the -- 'Exp' returned by @thcase@ is of type @t -> r@. thcase g t0 = thcase' g' t0 where g' (Left var) = g (return (VarE var)) [] g' (Right (name, args)) = g (return (ConE name)) [(t, return (VarE n)) | (t, n) <- args] -- |Scrap Your Boilerplate style case expression generation. The -- 'thcase' function is generally simpler to use instead of this and -- is more powerful. thfoldl :: Quasi m => (m Exp -> Type -> m Exp -> m Exp) -- ^ Constructor argument application. If the first 'Exp' is of -- type @c (a -> b)@, the 'Type' is @a@, and the second 'Exp' is of -- type @a@, should return an 'Exp' of type @c b@. -> (m Exp -> m Exp) -- ^ Constructor injection. The argument 'Exp' -- will be one of the constructors from the type -- to be inspected. If the argument 'Exp' is of -- type @a@, should return an 'Exp' of type @c -- a@. -> m Type -- ^ The type to inspect. -> m Exp -- ^ The expression containing the @case@ expression. If -- the type to inspect is @t@ and the type of the 'Exp' -- returned by the case handling function is @c t@, the -- 'Exp' returned by @thcase@ is of type @t -> c t@. thfoldl k z t = thcase g t where g ctor args = foldl (uncurry . k) (z ctor) args -------------------------------------------------------------------------------- -- Single-layer traversals -------------------------------------------------------------------------------- -- |Generic single-layer transformation thmapT :: (Quasi m) => (Type -> m Exp) -- ^ The transformation. If the 'Type' is @t@, must -- return an 'Exp' of type @t -> t@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- transformation to each immediate child. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> t0@. thmapT f t0 = thcase g t0 where g ctor [] = ctor g ctor ((t, x) : xs) = do ctor' <- ctor f' <- f t x' <- x g (return (AppE ctor' (AppE f' x'))) xs -- g [|$(ctor) ($(f t) $(x))|] xs -- |Generic single-layer query. thmapQ :: (Quasi m) => (Type -> m Exp) -- ^ The query. Extracts data from the given -- type. If the 'Type' is @t@, must return an 'Exp' -- of type @t -> a@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the query -- to each immediate child. If the 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> [a]@. thmapQ query topTy = thcase g topTy where g _ctor [] = return (ConE '[]) g ctor ((t, x) : xs) = do f' <- query t rest <- g ctor xs x' <- x return (AppE (AppE (ConE '(:)) (AppE f' x')) rest) -- [| $(f t) $(x) : $(g ctor xs) |] -- |Generic single-layer query (right associative). thmapQr :: (Quasi m) => m Exp -- ^ Combining function. 'Exp' must have type @r' -> r -> r@ -> m Exp -- ^ Starting value. 'Exp' must have type @r@. -> (Type -> m Exp) -- ^ The query. Extract data from the given -- type. If the 'Type' is @t@, must return an 'Exp' -- of type @t -> r'@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the query -- to each immediate child and uses the -- starting value and combining functions to -- fold the query results. If the 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r@. thmapQr combine zero query t0 = thcase g t0 where g _ctor [] = zero g ctor ((t, x) : xs) = do combine' <- combine left <- query t x' <- x right <- g ctor xs return (AppE (AppE combine' (AppE left x')) right) -- [|$(combine) ($(query t) $x) $(g ctor xs)|] -- |Generic single-layer query (left associative). thmapQl :: (Quasi m) => (m Exp) -- ^ Combining function. 'Exp' must have type @r -> r' -> r@ -> (m Exp) -- ^ Starting value. 'Exp' must have type @r@. -> (Type -> m Exp) -- ^ The query. Extract data from the given -- type. If the 'Type' is @t@, must return an 'Exp' -- of type @t -> r'@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the query -- to each immediate child and uses the -- starting value and combining functions to -- fold the query results. If the 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r@. thmapQl combine zero query t0 = thcase g' t0 where g' _ctor xs = zero >>= \acc -> g acc xs g acc [] = return acc g acc ((t, x) : xs) = do combine' <- combine left <- query t x' <- x g (AppE (AppE combine' acc) (AppE left x')) xs --g [|$combine $acc ($(query t) $x)|] xs #if BENCHMARK thmapM_slow :: (Type -> Q Exp) -> Q Type -> Q Exp thmapM_slow f t0 = thfoldl k z t0 where z ctor = [| return $(ctor) |] k ctor t arg = [| $ctor >>= \c -> $(f t) $(arg) >>= \a -> return (c a) |] #endif -- | Generic single-layer monadic transformation. thmapM :: (Quasi m) => (Type -> m Exp) -- ^ The monadic transformation. If the 'Type' is -- @t@, must return an 'Exp' of type @t -> m t@ -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- monadic transformation to each immediate -- child. If the 'Type' is @t0@, returns an 'Exp' -- of type @t0 -> m t0@. thmapM f t0 = thcase g t0 where g ctor [] = liftM (AppE (VarE 'return)) ctor -- [| return $(ctor) |] g ctor ((t, x) : xs) = do f' <- f t x' <- x xprime <- qNewName "x'" body <- g (do ctor' <- ctor; return (AppE ctor' (VarE xprime))) xs return $ InfixE (Just $ AppE f' x') (VarE '(>>=)) (Just $ LamE [VarP xprime] body) -- [| $(f t) $(x) >>= \x' -> $(g [| $(ctor) x' |] xs) |] -------------------------------------------------------------------------------- -- Memoization -------------------------------------------------------------------------------- -- internal helper function memoizeDec' :: (Quasi m, Ord a) => IORef (Map.Map a Name) -- ^ Name table. -> IORef [(Name, Exp)] -- ^ Body table. -> (a -> m Exp) -- ^ Function to be memoized. -> (a -> m Exp) -- ^ Memoized version of the function. memoizeDec' nameRef bodyRef f1 a = do names <- qRunIO (readIORef nameRef) case Map.lookup a names of Just name -> return $ VarE name Nothing -> do name <- qNewName $ "memoized" -- ++ map (\x -> if isAlpha x then x else '_') (pprint a) qRunIO $ modifyIORef nameRef (Map.insert a name) body <- f1 a qRunIO $ modifyIORef bodyRef ((name, body):) return $ VarE name {-| Memoizes a code generation function. Most users will want to use 'memoizeExp' instead as it provides a simplified interface, but all the notes about this function also apply to 'memoizeExp'. We memoize a function returning an 'Exp' by creating a 'Dec' with a body that is the 'Exp' returned by that function. The return value of the function is replaced with a 'VarE' that refers to the 'Dec'. This allows functions like 'everywhere' to avoid infinite recursions when they traverse recursive types like lists. The memoization functions come in two flavors: 'memoizeDec' and 'memoizeExp'. With 'memoizeDec' it is the responsibility of the caller to place the 'Dec' in an appropriate place. The 'memoizeExp' function automatically handles the 'Dec' by wrapping them in a local 'LetE' form. Every memoized function is passed a memoized version of itself. This is the function that should be used in recursive calls. Failing to do so will prevent those calls from being memoized. Mutually recursive functions are possible using 'memoizeDec2', etc. and 'memoizeExp2', etc. If the function being memoized needs to accept multiple arguments, then they must be packed into a tuple and passed as a single argument. Effects in the @m@ monad are only performed the first time the memoized function is called with a particular argument. Subsequent times the monad is simply the result of a 'return'. Thus while it is tempting to store extra return values in the monad, this should be avoided due to the high likelihood of unexpected behavior. Implementation Notes: * Note that @m@ should not store a copy of the function, otherwise a memory leak is introduced. It wouldn't even make sense to do it anyway since the results refer to expressions that might not be in scope. * The memoized function stores a reference to the memoization table, Thus if a reference to the memoized function gets tucked inside @m@, then a memory leak can be introduced. We could eliminate this leak by clearing and invalidating the table when 'memoizeDec' returns. To fully do this properly the table would have to be invalidated in such a way that the memoized version of the function would not continue to try populating the table if the user called it after 'memoizeDec' return. * Conceptually we should use a State monad instead of an IORef but we choose IORef since we can embed IO operations in a Quasi without imposing extra restrictions on @m@. * Other designs are possible. This design was choosen for its simplicity of use. The choice of memoization interface is largely orthogonal to the rest of this library. * Type synonyms and kind annotations may lead to duplicate versions of the code (e.g. versions for both 'String' and @['Char']@) Usually this isn't a problem, but if it is, then the type synonyms should be expanded before each call to the memoized function. * GADTs and data/type families haven't been considered in this code. It is unknown whether they work. Note that polymorphically recursive types (e.g. @data F a = N a | F (F (Int, a))@) have an infinite number of types in them and thus despite memoization this function will not terminate on those types. -} memoizeDec :: (Quasi m, Ord a) => ((a -> m Exp) -> a -> m Exp) -- ^ The function to memoize. Takes -- a memoized version of the -- function as argument. -> a -- ^ The initial argument to the function. -> m ([Dec], Exp) -- ^ The result of applying the -- function to the initial argument. -- The |Exp| is the result, but -- expects the @['Dec']@ to be in -- scope. memoizeDec f1 a = do nameRef1 <- qRunIO $ newIORef Map.empty decsRef1 <- qRunIO $ newIORef [] let f1' = memoizeDec' nameRef1 decsRef1 (f1 f1') expr <- f1' a decs1 <- qRunIO $ readIORef decsRef1 return $ filterDecs decs1 expr -- Unreferenced variables may cause ambiguous types (e.g. "foo = return") -- that fun afoul of the monomorphism restriction. -- TODO: filter out recursive definitions by using SCC filterDecs :: [(Name, Exp)] -> Exp -> ([Dec], Exp) filterDecs decs expr = (decs', expr) where decs' = map (\(name, body) -> ValD (VarP name) (NormalB body) []) $ --decs filter (\x -> Set.member (fst x) vars) $ decs vars = Set.unions (varSet expr : map (varSet . snd) decs) -- | Simultaneously memoizes two code generation functions. All of -- the notes about 'memoizeDec' also apply to this function. Most -- users will want to use 'memoizeExp2' instead of this function as it -- provides a simplified interface. memoizeDec2 :: (Quasi m, Ord a, Ord b) => ((a -> m Exp) -> (b -> m Exp) -> a -> m Exp) -- ^ The first function to memoize. Takes memoized versions of the -- two functions as arguments. -> ((a -> m Exp) -> (b -> m Exp) -> b -> m Exp) -- ^ The second function to memoize. Takes memoized versions of the -- two functions as arguments. -> a -- ^ The initial argument. -> m ([Dec], Exp) -- ^ The result of applying the function to the -- initial argument. The |Exp| is the result, but -- expects the @['Dec']@ to be in scope. memoizeDec2 f1 f2 a = do nameRef1 <- qRunIO $ newIORef Map.empty nameRef2 <- qRunIO $ newIORef Map.empty decsRef1 <- qRunIO $ newIORef [] decsRef2 <- qRunIO $ newIORef [] let f1' = memoizeDec' nameRef1 decsRef1 (f1 f1' f2') f2' = memoizeDec' nameRef2 decsRef2 (f2 f1' f2') expr <- f1' a decs1 <- qRunIO $ readIORef decsRef1 decs2 <- qRunIO $ readIORef decsRef2 return $ filterDecs (decs1 ++ decs2) expr -- |Memoizes a code generation function. Behaves identically to -- 'memoizeDec' except that it returns a 'LetE' that binds the 'Dec' -- resulting from 'memoizeDec' for the 'Exp' resulting from -- 'memoizeDec'. memoizeExp :: (Quasi m, Ord a) => ((a -> m Exp) -> a -> m Exp) -> a -> m Exp memoizeExp f1 a = liftM (uncurry LetE) (memoizeDec f1 a) -- |Simultaneously memoizes two code generation functions. Behaves -- identically to 'memoizeDec2' except that it returns a 'LetE' that -- binds the 'Dec' resulting from 'memoizeDec2' for the 'Exp' -- resulting from 'memoizeDec2'. memoizeExp2 :: (Quasi m, Ord a, Ord b) => ((a -> m Exp) -> (b -> m Exp) -> a -> m Exp) -> ((a -> m Exp) -> (b -> m Exp) -> b -> m Exp) -> a -> m Exp memoizeExp2 f1 f2 a = liftM (uncurry LetE) (memoizeDec2 f1 f2 a) -------------------------------------------------------------------------------- -- Transform traversals -------------------------------------------------------------------------------- -- |Generic recursive transformation (bottom-up) everywhere :: (Quasi m) => (Type -> m Exp) -- ^ The transformation. If the 'Type' is @t@, must -- return an 'Exp' of type @t -> t@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- transformation to each descendant. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> t0@. everywhere f t0 = t0 >>= memoizeExp rec where rec r t = composeE (f t) (thmapT r (return t)) -- |Generic recursive transformation (top-down) everywhere' :: (Quasi m) => (Type -> m Exp) -- ^ The transformation. If the 'Type' is @t@, must -- return an 'Exp' of type @t -> t@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- transformation to each descendant. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> t0@. everywhere' f t0 = t0 >>= memoizeExp rec where rec r t = composeE (thmapT r (return t)) (f t) -- |Generic recursive transformation (bottom-up) with selective traversal. -- Skips traversal when a given query returns 'True'. everywhereBut :: (Quasi m) => (Type -> m Bool) -- ^ The query. Should return 'True' when a -- given type should not be traversed. -> (Type -> m Exp) -- ^ The transformation. If the 'Type' is @t@, -- must return an 'Exp' of type @t -> t@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- transformation to each descendant (except -- for parts skipped due to the query returning -- 'True'). If the 'Type' is @t0@, returns an -- 'Exp' of type @t0 -> t0@. everywhereBut q f t0 = t0 >>= memoizeExp rec where rec k t = do q' <- q t if q' then return (VarE 'id) else composeE (f t) (thmapT k (return t)) -- |Generic recursive monadic transformation (bottom-up) everywhereM :: (Quasi m) => (Type -> m Exp) -- ^ The monadic transformation. If the 'Type' is @t@, must -- return an 'Exp' of type @t -> m t@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- transformation to each descendant. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> m t0@. everywhereM f t0 = t0 >>= memoizeExp rec where rec r t = composeM (f t) (thmapM r (return t)) #if BENCHMARK everywhereM_slow :: (Type {- forall a. -} -> Q Exp {- a -> a -}) -> Q Type {- forall a. -} -> Q Exp {- a -> a -} everywhereM_slow f t0 = t0 >>= memoizeExp rec where rec k t = composeM (f t) (thmapM_slow k (return t)) #endif -- | Generic recursive monadic transformation (top-down) everywhereM' :: (Quasi m) => (Type -> m Exp) -- ^ The monadic transformation. If the 'Type' is @t@, must -- return an 'Exp' of type @t -> m t@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- transformation to each descendant. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> m t0@. everywhereM' f t0 = t0 >>= memoizeExp rec where rec k t = composeM (thmapM k (return t)) (f t) -- |Generic recursive monadic transformation (top-down) with selective traversal. -- Skips traversal when a given query returns 'True'. everywhereButM' :: (Quasi m) => (Type -> m Bool) -- ^ The query. Should return 'True' when a -- given type should not be traversed. -> (Type -> m Exp) -- ^ The monadic transformation. If the 'Type' is @t@, -- must return an 'Exp' of type @t -> m t@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- monadic transformation to each descendant (except -- for parts skipped due to the query returning -- 'True'). If the 'Type' is @t0@, returns an -- 'Exp' of type @t0 -> m t0@. everywhereButM' q f t0 = t0 >>= memoizeExp rec where rec k t = do q' <- q t if q' then return (VarE 'return) else composeM (thmapM k (return t)) (f t) -- |Generic recursive transformation (bottom-up) with selective -- traversal. Recurs on only types that can contain a type with type -- specific behavior. everywhereFor :: (Quasi m) => Name -- ^ Name of a function of type @t -> t@ -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- transformation to each descendant that is of -- type @t@. If the 'Type' is @t0@, returns an -- 'Exp' of type @t0 -> t0@. everywhereFor func topTy = do everywhereBut (liftM not . inType (argTypeOfName func)) (const (return (VarE 'id)) `extE` (eqType (argTypeOfName func), return (VarE func))) topTy -- |Generic recursive monadic transformation (bottom-up) with -- selective traversal. Recurs on only types that can contain a type -- with type specific behavior. everywhereForM :: (Quasi m) => Name -- ^ Name of a function of type @t -> m t@ -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the -- monadic transformation to each descendant -- that is of type @t@. If the 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> m t0@. everywhereForM func topTy = do everywhereButM' (liftM not . inType (argTypeOfName func)) (const (return (VarE 'return)) `extE` (eqType (argTypeOfName func), return (VarE func))) topTy {- This doesn't work because of the following: Dec <---+ / \ | ... ... | / \ | [Needed] Exp-+ It is impossible to know when processing Exp whether [Needed] will be hit because answering that requires processing all of the children of Dec which then requires Exp, etc. Memoization does not help this problem. -- skip useless recursions -- Names: everywhereNeeded, everywhereCan, ??? -- effective, useful, isJust, May, Can -- We use a state monad instead of a writer monad -- because the memoization doesn't replay old effects everywhere_ :: (Quasi m) => (Type {- forall a. -} -> m (Maybe Exp) {- a -> a -}) -> m Type {- forall a. -} -> m Exp {- a -> a -} everywhere_ f t = do t' <- t; (e, _) <- runStateT (memoizeExp rec t') (Set.empty, Set.empty); return e where --rec :: (Type -> WriterT Any Q Exp) -> Type -> WriterT Any Q Exp rec k t = do r <- thmapT k' (return t) f' <- lift $ f t when (isJust f') (setSelf t) needsChildren <- needChildren t case (needsChildren, f') of (False, Nothing) -> return (VarE 'id) (True, Nothing) -> return r (False, Just f'') -> return f'' (True, Just f'') -> lift $ composeE (return f'') (return r) where k' t' = do e <- k t' needSelf t' >>= \x -> when x (setChildren t) return e needSelf t = gets (Set.member t . fst) needChildren t = gets (Set.member t . snd) setSelf t = modify (\(s1, s2) -> (Set.insert t s1, s2)) setChildren t = modify (\(s1, s2) -> (s1, Set.insert t s1)) everywhereM'_ :: (Quasi m) => (Type {- forall a. -} -> m (Maybe Exp) {- a -> a -}) -> m Type {- forall a. -} -> m Exp {- a -> a -} everywhereM'_ f t = do t' <- t; (e, s) <- runStateT (memoizeExp rec t') (Set.empty, Set.empty); trace (show (Set.toList (fst s), Set.toList (snd s))) $ return e where --rec :: (Type -> WriterT Any Q Exp) -> Type -> WriterT Any Q Exp rec1 k t = do f' <- lift $ f t when (isJust f') (setSelf t f') thmapM k (return t) rec2 k t = do r <- thmapM k' (return t) f' <- lift $ f t needsChildren <- needChildren t when needsChildren (setSelf t) -- TODO: requires "Just"ness of "f" to be idempotent where k' t' = needSelf t' >>= \x -> when x (setChildren t) >> k t' {- rec k t = do r <- thmapM k' (return t) f' <- lift $ f t needsChildren <- needChildren t trace (" ****** " ++ show (isJust f') ++ show needsChildren ++ show t) $ return () when (isJust f' || needsChildren) (setSelf t) case (needsChildren, f') of (False, Nothing) -> return (VarE 'return) (True, Nothing) -> return r (False, Just f'') -> return f'' (True, Just f'') -> lift $ composeM (return r) (return f'') where k' t' = do e <- k t' needSelf t' >>= \x -> when x (setChildren t) return e needSelf t = gets (Set.member t . fst) needChildren t = gets (Set.member t . snd) setSelf t = modify (\(s1, s2) -> (Set.insert t s1, s2)) setChildren t = modify (\(s1, s2) -> (s1, Set.insert t s1)) -} -} {- -- alternative interface to everywhereNeeded everywhereBut' :: (Quasi m) => (Type -> m Bool) -> (Type {- forall a. -} -> m Exp {- a -> a -}) -> Type {- forall a. -} -> m Exp {- a -> a -} everywhereBut' q f t0 = do (e, _) <- runWriterT (memoizeExp rec t0); return e where -- rec :: (Quasi m) => (Type -> WriterT Any m Exp) -> Type -> WriterT Any m Exp rec k t = do (r, Any childRec) <- listen (thmapT k (return t)) q' <- lift $ q t tell (Any (not q')) case (childRec, q') of (False, True) -> return (VarE 'id) (True, True) -> return r (False, False) -> lift $ f t (True, False) -> lift $ composeE (return r) (f t) -} -- |Generic recursive transformation (bottom-up) with selective traversal. somewhere :: (Quasi m) => ((Type -> m Exp) -> (Type -> m (Maybe Exp))) -- ^ The transformation. The first argument is the memoized -- recursion. If 'Nothing' is returned, then the standard, -- automatic recursion is done. If 'Just' is returned, then no -- automatic recursion is done and the resulting 'Exp' is used at -- that type. In that case, if further recursion is desired, then -- the expression should include a call to the memoized recursion. -- If the 'Type' is @t@, then the returned 'Exp' must be of type @t -- -> t@. -- -- We use 'Maybe' instead of 'MonadPlus' to avoid the user having to -- play games with 'runMaybeT' and so forth. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the transformation to each -- descendant. If 'Type' is @t0@, returns an 'Exp' of type @t0 -> -- t0@. somewhere f t0 = t0 >>= memoizeExp rec where rec k t = do f' <- f k t case f' of Nothing -> thmapT k (return t) -- `mplus` return (VarE 'id) Just e -> return e -- |Generic recursive monadic transformation (bottom-up) with selective traversal. somewhereM :: (Quasi m) => ((Type -> m Exp) -> (Type -> m (Maybe Exp))) -- ^ The monadic transformation. The first argument is the memoized -- recursion. If 'Nothing' is returned, then the standard, -- automatic recursion is done. If 'Just' is returned, then no -- automatic recursion is done and the resulting 'Exp' is used at -- that type. In that case, if further recursion is desired, then -- the expression should include a call to the memoized recursion. -- If the 'Type' is @t@, then the returned 'Exp' must be of type @t -- -> m t@. -- -- We use 'Maybe' instead of 'MonadPlus' to avoid the user having to -- play games with 'runMaybeT' and so forth. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the transformation to each -- descendant. If 'Type' is @t0@, returns an 'Exp' of type @t0 -> -- m t0@. somewhereM f t0 = t0 >>= memoizeExp rec where rec k t = do f' <- f k t case f' of Nothing -> thmapM k (return t) Just e -> return e -------------------------------------------------------------------------------- -- Queries -------------------------------------------------------------------------------- -- |Generic recursive query (bottom-up). everything :: (Quasi m) => m Exp -- ^ Combining function. 'Exp' must have type @r -> r -> r@. -> (Type -> m Exp) -- ^ The query. Extract data from the given -- type. If the 'Type' is @t@, must return an 'Exp' -- of type @t -> r'@. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the query -- to each descendant and uses the combining -- function to combine the query results. If -- the 'Type' is @t0@, returns an 'Exp' of type -- @t0 -> r@. everything collect query t0 = t0 >>= memoizeExp rec where rec r t = do --[|\x -> $(thmapQl collect [|$(query t) x|] r (return t)) x|] x <- qNewName "x" f <- query t mapQ <- thmapQl collect (return (AppE f (VarE x))) r (return t) return (LamE [VarP x] (AppE mapQ (VarE x))) #if BENCHMARK everything_slow :: (Quasi m) => (m Exp) -- Combine collections, expr :: c a -> c a -> c a -> (Type -> m Exp) -- Extract values, expr :: d -> c a -> m Type -- The top type, the first 'd'. -> m Exp -- expr :: d -> c a everything_slow collect query topTy = collect >>= \cExp -> topTy >>= memoizeExp (ething cExp) where ething cExp memoF currTy = do f_curr <- query currTy f_rest <- thmapQ memoF (return currTy) x <- qNewName "x" return (LamE [VarP x] (AppE (AppE (AppE (VarE 'foldl) cExp) (AppE f_curr (VarE x))) (AppE f_rest (VarE x)))) #endif -- |Generic recursive query with left-associative accumulation. everythingAccL :: (Type -> Q Exp) -- ^ The query and combining function. If the -- 'Type' is @t@, must return an 'Exp' of type @t -- -> r -> r@. -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant. If 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r -> r@ everythingAccL query t0 = t0 >>= memoizeExp rec where rec memoF currTy = do e <- query currTy [|\x acc -> $(return e) x ($(thcase (g [|acc|]) (return currTy)) x)|] where -- Left accumulator g acc _ctor [] = acc g acc ctor ((t, v):vs) = [|$(memoF t) $v $(g acc ctor vs)|] -- |Generic recursive query with strict left-associative accumulation everythingAccL' :: (Type -> Q Exp) -- ^ The query and combining function. If the -- 'Type' is @t@, must return an 'Exp' of type @t -- -> r -> r@. -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant. If 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r -> r@ everythingAccL' query t0 = t0 >>= memoizeExp rec where rec memoF currTy = do e <- query currTy [|\x acc -> $(return e) x ($(thcase (g [|acc|]) (return currTy)) x)|] where -- Left accumulator g acc _ctor [] = acc g acc ctor ((t, v):vs) = [|$acc `seq` $(memoF t) $v $(g acc ctor vs)|] -- |Generic recursive query with left-associative accumulation and selective traversal everythingButAccL :: (Type -> Q (Exp, Bool)) -- ^ The query, combining, selectivity function. If the 'Type' is -- @t@, must return an 'Exp' of type @t -> r -> r@. If the 'Bool' -- is 'True' the traversal does not proceed further into the -- recursion. -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant. If 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r -> r@ everythingButAccL query t0 = t0 >>= memoizeExp rec where rec memoF currTy = do (e, stop) <- query currTy if stop then return e else [|\x acc -> $(return e) x ($(thcase (g [|acc|]) (return currTy)) x)|] where -- Left accumulator g acc _ctor [] = acc g acc ctor ((t, v):vs) = [|$(memoF t) $v $(g acc ctor vs)|] -- |Generic recursive query with strict left-associative accumulation and selective traversal everythingButAccL' :: (Type -> Q (Exp, Bool)) -- ^ The query, combining, selectivity function. If the 'Type' is -- @t@, must return an 'Exp' of type @t -> r -> r@. If the 'Bool' -- is 'True' the traversal does not proceed further into the -- recursion. -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant. If 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r -> r@ everythingButAccL' query t0 = t0 >>= memoizeExp rec where rec memoF currTy = do (e, stop) <- query currTy if stop then return e else [|\x acc -> $(return e) x ($(thcase (g [|acc|]) (return currTy)) x)|] where -- Left accumulator g acc _ctor [] = acc g acc ctor ((t, v):vs) = [|$acc `seq` $(memoF t) $v $(g acc ctor vs)|] -- |Generic recursive query with right-associative accumulation everythingAccR :: (Type -> Q Exp) -- ^ The query and combining function. If the -- 'Type' is @t@, must return an 'Exp' of type @t -- -> r -> r@. -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant. If 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r -> r@ everythingAccR query t0 = t0 >>= memoizeExp rec where rec memoF currTy = do e <- query currTy [|\x acc -> $(return e) x ($(thcase (g [|acc|]) (return currTy)) x)|] where -- Right accumulator g acc _ctor [] = acc g acc ctor ((t, v):vs) = g [|$(memoF t) $v $acc|] ctor vs -- |Generic recursive query with right-associative accumulation and selective traversal everythingButAccR :: (Type -> Q (Exp, Bool)) -- ^ The query, combining, selectivity function. If the 'Type' is -- @t@, must return an 'Exp' of type @t -> r -> r@. If the 'Bool' -- is 'True' the traversal does not proceed further into the -- recursion. -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant. If 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r -> r@ everythingButAccR query t0 = t0 >>= memoizeExp rec where rec memoF currTy = do (e, stop) <- query currTy if stop then return e else [|\x acc -> $(return e) x ($(thcase (g [|acc|]) (return currTy)) x)|] where -- Right accumulator g acc _ctor [] = acc g acc ctor ((t, v):vs) = g [|$(memoF t) $v $acc|] ctor vs -- |Generic recursive query with selective traversal everythingBut :: (Quasi m) => (m Exp) -- ^ Combining function. 'Exp' must have type @r -> r -> r@. -> (Type -> m (Exp,Bool)) -- ^ The query, combining, selectivity function. If the 'Type' is -- @t@, must return an 'Exp' of type @t -> r -> r@. If the 'Bool' -- is 'True' the traversal does not proceed further into the -- recursion. -> (m Type -> m Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant. If 'Type' is @t0@, -- returns an 'Exp' of type @t0 -> r@ everythingBut collect query topTy = topTy >>= memoizeExp rec where rec memoF currTy = do x <- qNewName "x" (f_curr,b) <- query currTy if b then return f_curr else do mapQ <- thmapQl collect (return (AppE f_curr (VarE x))) memoF (return currTy) return (LamE [VarP x] (AppE mapQ (VarE x))) -- |Generic recursive traversal using left-associative accumulation everythingForL :: Name -- ^ Name of a function of type @t -> r -> r@ -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant that is of type @t@. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> r -> r@ everythingForL func topTy = do a1 <- argTypeOfName func let op = VarE func g t = do b <- eqType (return a1) t c <- inType (return a1) t if b then return (op, not c) else return (AppE (VarE 'const) (VarE 'id), not c) everythingButAccL g topTy -- |Generic recursive traversal using strict left-associative accumulation everythingForL' :: Name -- ^ Name of a function of type @t -> r -> r@ -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant that is of type @t@. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> r -> r@ everythingForL' func topTy = do a1 <- argTypeOfName func let op = VarE func g t = do b <- eqType (return a1) t c <- inType (return a1) t if b then return (op, not c) else return (AppE (VarE 'const) (VarE 'id), not c) everythingButAccL' g topTy -- |Generic recursive traversal using right-associative accumulation everythingForR :: Name -- ^ Name of a function of type @t -> r -> r@ -> (Q Type -> Q Exp) -- ^ Generates an 'Exp' that applies the query -- to each decendant that is of type @t@. If -- 'Type' is @t0@, returns an 'Exp' of type @t0 -- -> r -> r@ everythingForR func topTy = do a1 <- argTypeOfName func let op = VarE func g t = do b <- eqType (return a1) t c <- inType (return a1) t if b then return (op, not c) else do h <- [|const id|] return (h, not c) everythingButAccR g topTy -------------------------------------------------------------------------------- -- implementing "mkT" and friends -------------------------------------------------------------------------------- {- In general "mkT" is simply a partially applied "extT" as in the following thus we focus the description on "extT". mkT :: Type -> ExpQ -> (Type -> ExpQ) mkT = extT (const [|id|]) The naive implementation of "extT" simply tests for type equality as in the following. However, this would fail to treat "String" and "[Char]" as equal. extT :: (Type -> ExpQ) -> Type -> ExpQ -> Type -> ExpQ extT def t ext t' | t' == t = ext | otherwise = def t' -} {- Since there are multiple notions of type equality (e.g. by name, exact equality, instantiability, etc.), we do not write an "extT" version for each. Instead we parameterize by the equality predicate. We pair the predicate with the function to use when that predicate is true rather than passing them as separate arguments so that "extT" is easier to use infix. An alternative would be to play with the precedence. -} -- |Returns the type of a variable, method or constructor name. typeOfName :: (Quasi m) => Name -> m Type typeOfName n = do info <- qReify n case info of DataConI _ ty _ _ -> return ty VarI _ ty _ _ -> return ty ClassOpI _ ty _ _ -> return ty _ -> fail $ "'typeOfName' only applies to classes, variables, " ++ "and constructors that are in the current environment, but " ++ "was applied to :" ++ show n -- |Returns the type of the first argument of a variable, method or constructor name. argTypeOfName :: (Quasi m) => Name -> m Type argTypeOfName n = do t <- typeOfName n let getArg (AppT (AppT ArrowT t') _) = return t' getArg (ForallT _ _ t') = getArg t' getArg _ = fail $ "'argTypeOfName' applied to `" ++ show n ++ " which has non-function type `" ++ show t ++ "'." getArg t -- |Extends a generic operation with type specific behavior based on the type of the given name. extN :: (Quasi m) => (Type -> m Exp) -- ^ The operation to be extended. -> Name -- ^ Name of the function implementing the type specific behavior. -> (Type -> m Exp) -- ^ The result of extending the operation. If the 'Name' has type -- @t -> s@, then the extended operation has type specific behavior -- at @t@. At other types it behaves as the original operation. extN def name t0 = extE def (eqType t', return (VarE name)) t0 where t' = do ty <- typeOfName name ty' <- argTypeOfName name when (hasFreeVars ty') $ fail $ "Underspecified type when using extN (or something based on it).\n" ++ " A type variable or forall occurs in the type of the first argument to `" ++ show name ++ "'.\n" ++ " Namely, `" ++ pprint ty' ++ "'\n" ++ " in the type `"++ pprint ty++"'." return ty' hasFreeVars (VarT _) = True hasFreeVars (ForallT _tyVarBndrs _cxt _ty) = True hasFreeVars (ConT _) = False hasFreeVars (TupleT _) = False hasFreeVars (ArrowT) = False hasFreeVars (ListT) = False hasFreeVars (AppT t1 t2) = hasFreeVars t1 || hasFreeVars t2 hasFreeVars (SigT t _kind) = hasFreeVars t -- |Extends a generic operation with type specific behavior. extE :: (Quasi m) => (Type -> m exp) -- ^ The operation to be extended. -> (Type -> m Bool, m exp) -- ^ The 'fst' of the pair should return 'True' on types for which -- the operation should be extended. The 'snd' of the pair is the -- expression to use on those types. -> (Type -> m exp) -- ^ The result of extending the operation. extE def (typePred, ext) t = extE' def (typePred, const ext) t -- |Extends a generic operation with type specific behavior. extE' :: (Quasi m) => (Type -> m exp) -- ^ The operation to be extended. -> (Type -> m Bool, Type -> m exp) -- ^ The 'fst' of the pair should return 'True' on types for which -- the operation should be extended. The 'snd' of the pair when given one of these 'Type's -- should return the expression to use on that type. -> (Type -> m exp) -- ^ The result of extending the operation. extE' def (typePred, ext) t = do test <- typePred t if test then ext t else def t -- |Makes a transformation from a named function. mkT :: (Quasi m) => Name -- ^ Name of a function of type @t -> t@ -> (Type -> m Exp) -- ^ The generic transformation. If the 'Type' -- is @t0@, then the 'Exp' has type @t0 -> t0@. -- The 'Exp' is the named function at @t@, and -- 'id' elsewhere. mkT f = mkTs [f] -- |Makes a transformation from several named functions. mkTs :: (Quasi m) => [Name] -- ^ Names of functions of type @t -> t@ for various @t@. -> (Type -> m Exp) -- ^ The generic transformation. If the 'Type' -- is @t0@, then the 'Exp' has type @t0 -> t0@. -- If one of the named functions matches @t0@, -- then 'Exp' is that function. Otherwise, it is -- 'id'. mkTs = mkXs (return (VarE 'id)) -- |Makes a monadic transformation from a named function. mkM :: (Quasi m) => Name -- ^ Name of a function of type @t -> m t@ -> (Type -> m Exp) -- ^ The generic monadic transformation. If the 'Type' -- is @t0@, then the 'Exp' has type @t0 -> m t0@. -- The 'Exp' is the named function at @t@, and -- 'return' elsewhere. mkM f = mkMs [f] -- |Makes a monadic transformation from several named functions. mkMs :: (Quasi m) => [Name] -- ^ Names of functions of type @t -> m t@ for various @t@. -> (Type -> m Exp) -- ^ The generic monadic transformation. If the 'Type' -- is @t0@, then the 'Exp' has type @t0 -> m t0@. -- If one of the named functions matches @t0@, -- then 'Exp' is that function. Otherwise, it is -- 'return'. mkMs = mkXs (return (VarE 'return)) -- |Makes a query from a named function. mkQ :: (Quasi m) => m Exp -- ^ Default value to return on types other than @t@. -- The 'Exp' must be of type @r@. -> Name -- ^ Name of a function of type @t -> r@ -> (Type -> m Exp) -- ^ The generic transformation. If the 'Type' -- is @t0@, then the 'Exp' has type @t0 -> r@. -- The 'Exp' is the named function at @t@, and -- the provided default value elsewhere. mkQ z f t = mkQs z [f] t -- |Makes a query from several named functions. mkQs :: (Quasi m) => m Exp -- ^ Default value to return on types that do not -- match the @t@ from any named function. The -- 'Exp' must be of type @r@. -> [Name] -- ^ Names of functions of type @t -> r@ for various @t@. -> (Type -> m Exp) -- ^ The generic query. If the 'Type' -- is @t0@, then the 'Exp' has type @t0 -> r@. -- If one of the named functions matches @t0@, -- then 'Exp' is that function. Otherwise, it is -- 'const' of the provided default value. mkQs zM = mkXs (zM >>= \z -> return (AppE (VarE 'const) z)) -- |Makes operations from several named functions. mkXs :: (Quasi m) => m Exp -- ^ Default function for types that do not -- match the @t@ from any named function. The -- 'Exp' must be of type @c t@. -> [Name] -- ^ Names of functions of type @c t@ for various @t@. -> (Type -> m Exp) -- ^ The generic query. If the 'Type' -- is @t0@, then the 'Exp' has type @c t0@. -- If one of the named functions matches @t0@, -- then 'Exp' is that function. Otherwise, it is -- the provided default function. mkXs zM [] = const zM mkXs zM (f:fs) = extN (mkXs zM fs) f {- TODO: -- name vs expr -- for expr: type predicate vs not -- single vs plural -- extF appears pointless though it I think it is actually useful (thus omit from paper?) extF :: ( Type{-|c|-}-> Q Exp{-|c -> h c|-}) -> (Q Type{-|b|-}, Type -> Q Exp{-|b -> h b|-}) -> Type{-|a|-} -> Q Exp{-|a -> h a|-} extNs :: (Type{-|c|-}-> Q Exp{-|c -> h c|-}) -> [Name]{-|b -> h b|-} -> Type{-|a|-}-> Q Exp{-|a -> h a|-} extEs :: (Type -> Q Exp) -> ([Q Type], Q Exp) -> Type -> Q Exp extFs :: (Type -> Q Exp) -> ([Q Type], Type -> Q Exp) -> Type -> Q Exp ext? :: (Quasi m) => (Type -> m exp) -> (Q Type, m exp) -> Type -> m exp ext?s :: (Quasi m) => (Type -> m exp) -> ([Q Type], m exp) -> Type -> m exp ext? :: (Quasi m) => (Type -> m exp) -> (Q Type, Type -> m exp) -> Type -> m exp --TODO: extT' that doesn't take a type in the "ext" extT' :: (Quasi m) => (Type -> m exp) -> (Type -> m Bool, m exp) -> Type -> m exp extT' = extE -} -------------------------------------------------------------------------------- -- Type equality and containment -------------------------------------------------------------------------------- -- |Tests if two types are equal modulo type synonyms and kind -- annotations. Naive equality would fail to equate "String" and -- "[Char]". eqType :: (Quasi m) => m Type -> Type -> m Bool eqType t1 t2 = do t1' <- expandType =<< t1 t2' <- expandType t2 return $ t1' == t2' -- |Test if any of a list of types is equal to a particular type -- modulo type synonyms and kind annotations. Useful when multiple -- types share the same type-specific behavior. eqTypes :: (Quasi m) => [m Type] -> Type -> m Bool eqTypes ts t = liftM or $ mapM (flip eqType t) ts -- |@inType t1 t2 = True@ iff @t1@ is (even recursively) inside @t2@ inType :: (Quasi m) => m Type -> Type -> m Bool inType t1 t2 = evalStateT (flip rec t2 =<< lift (expandType =<< t1)) Set.empty where rec t1 t2 = do t2' <- expandType t2 s <- gets (Set.member t2') if s then return False -- We've already seen and checked this type else if t1 == t2' then return True -- We found a matching type else do modify (Set.insert t2') -- Remember that we were here ctors <- constructorsOf t2' -- We need to recur on the constructors case ctors of Nothing -> return False -- It's a function or primitive type Just ctors' -> check t1 (concatMap snd ctors') -- Now we recur check _t1 [] = return False check t1 (t : ts) = do t' <- rec t1 t if t' then return True else check t1 ts -- |@inTypes ts t2 = True@ iff any of @ts@ is (even recursively) inside @t2@ inTypes :: (Quasi m) => [m Type] -> Type -> m Bool inTypes t1qs t2 = fmap or (mapM (`inType` t2) t1qs) -------------------------------------------------------------------------------- -- Internal utilities -------------------------------------------------------------------------------- ---------------------------------------- -- Constructor queries ---------------------------------------- -- We treat these types as primitive. They were generated from the list of -- instances of the Data.Data class that have unboxed arguments. primTypes :: [Name] primTypes = [ ''Char, ''Double, ''Float, ''Int, ''Int8, ''Int16, ''Int32, ''Int64, ''Integer, ''Word, ''Word8, ''Word16, ''Word32, ''Word64, ''Ptr, ''ForeignPtr, ''Array] -- |Returns the constructors of a given type. -- Returns @Nothing@ if the type is primitive. constructorsOf :: (Quasi m) => Type -> m (Maybe [(Name, [Type])]) constructorsOf t = do c <- getCtors t [] case c of Left _ -> return $ Nothing Right c' -> return $ Just (map f c') where f (name, ts) = (name, map snd ts) -- getCtors is used to implement constructorsOf. -- -- getCtors returns all the information about the constructors of the -- type. In constructorsOf, we limit our selves to the useful -- information. -- -- getCtors must dive into the lhs of type applications app to get -- constructors and then substitute appropriately getCtors :: (Quasi m) => Type -> [Type] -> m (Either (Int, Bool) [(Name, [(Strict, Type)])]) getCtors (ForallT _tyVarBndrs _cxt ty) [] = getCtors ty [] -- needs in-scope variables? getCtors (VarT name) _ = fail $ "constructorsOf: type variable `"++show name++"' in head position of type application" getCtors (ConT name) _args | name `elem` primTypes = return $ Left (0, False) getCtors (ConT name) args = do -- TODO: check arg count t <- qReify name case t of TyConI (DataD _cxt _tyName tyVarBndrs cons _derive) -> return $ Right $ map (doCons tyVarBndrs) cons TyConI (NewtypeD _cxt _tyName tyVarBndrs con _derive) -> return $ Right [doCons tyVarBndrs con] TyConI (TySynD _tyName tyVarBndrs ty) -> getCtors (subst (zip tyVarBndrs args) ty) (drop (length tyVarBndrs) args) PrimTyConI _name arity isUnLifted -> return $ Left (arity, isUnLifted) -- PrimTyConI handles types that can't be expressed with a data type such as (->), Int# _ -> fail $ "Attempt to call `constructorsOf' on non-type: " ++ (show t) where doCons bndrs (NormalC name' ts) = (name', [(strict, subst (zip bndrs args) t) | (strict, t) <- ts]) doCons bndrs (InfixC t1 name' t2) = doCons bndrs (NormalC name' [t1, t2]) doCons bndrs (RecC name' ts) = doCons bndrs (NormalC name' (map (\ (_, s, t) -> (s, t)) ts)) --doCons bndrs (ForallC tyVarBndrs cxt con) = fail $ "GADT constructor found in "++show name++", but GADTs are not (yet) supported" -- TODO: gadt in ForallC -- foo (ClassP name' tys) = ... -- foo (EqP ty1 ty2) -- data Foo a b c = forall d e f. (a ~ b, c ~ Int) => Foo1 a b e f getCtors (TupleT n) args | n == length args = return $ Right [(tupleDataName n, [(NotStrict, t) | t <- args])] getCtors (ArrowT) [_t1, _t2] = return $ Left (2, False) -- constants taken from "reify ''(->)" getCtors (ListT) [t] = return $ Right [('[], []), ('(:), [(NotStrict, t), (NotStrict, AppT ListT t)])] getCtors (AppT t1 t2) args = getCtors t1 (t2 : args) getCtors (SigT t _kind) args = getCtors t args getCtors t args = fail $ "constructorsOf: wrong number of arguments passed to type constructor\n" ++ " Constructor: " ++ show t ++ "\n" ++ " Arguments: " ++ show args ++ "\n" -- Apply a substitution (a.k.a. a list of type-variable bindings) to a type. -- -- Note: Predicates might be able to be simplified as a result of the -- substitution, but this function does not perform that -- simplification. subst :: [(TyVarBndr, Type)] -> Type -> Type subst env0 t0 = subst' [(stripKind bndr, Just t) | (bndr, t) <- env0] t0 where stripKind (PlainTV name) = name stripKind (KindedTV name _kind) = name subst' :: [(Name, Maybe Type)] -> Type -> Type subst' env (ForallT tyVarBndrs cxt t) = ForallT tyVarBndrs (map (substPred env) cxt) (subst' ([(stripKind bndr, Nothing) | bndr <- tyVarBndrs] ++ env) t) subst' env t@(VarT name) | Just (Just t') <- lookup name env = t' | otherwise = t subst' env (AppT t1 t2) = AppT (subst' env t1) (subst' env t2) subst' env (SigT t kind) = SigT (subst' env t) kind subst' _ t@(ConT _) = t subst' _ t@(TupleT _) = t subst' _ t@(ArrowT) = t subst' _ t@(ListT) = t substPred env (ClassP name ts) = ClassP name (map (subst' env) ts) substPred env (EqualP t1 t2) = EqualP (subst' env t1) (subst' env t2) ---------------------------------------- -- Type expansion ---------------------------------------- -- | Expands all type synonyms in its argument. Also strips all kind -- annotations since those might be incorrect once type synonyms are -- applied to their arguments. expandType :: (Quasi m) => Type -> m Type expandType t = expandType' t [] -- most of these clauses can be written cleaner using [t| ... |] quotes -- but then we couldn't use "Quasi m" and would only work on "Q". -- -- This code does not check arity and may produce unexpected results -- in the presence of partially applied type synonyms. expandType' :: (Quasi m) => Type -> [Type] -> m Type expandType' (ForallT tyVarBndrs cxt ty) args = do cxt' <- mapM doCxt cxt ty' <- expandType' ty [] let t' = ForallT tyVarBndrs cxt' ty' return $ foldl AppT t' args where doCxt (ClassP name tys) = liftM (ClassP name) (mapM (flip expandType' []) tys) doCxt (EqualP t1 t2) = liftM2 EqualP (expandType' t1 []) (expandType' t2 []) expandType' (VarT name) args = return $ foldl AppT (VarT name) args expandType' (ConT name) args = do t <- qReify name let nonSynonym = return $ foldl AppT (ConT name) args case t of TyConI (TySynD _name tyVarBndrs ty) -> expandType' (subst (zip tyVarBndrs args) ty) (drop (length tyVarBndrs) args) TyConI (DataD {}) -> nonSynonym TyConI (NewtypeD {}) -> nonSynonym PrimTyConI {} -> nonSynonym expandType' (TupleT n) args = return $ foldl AppT (TupleT n) args expandType' (ArrowT) args = return $ foldl AppT (ArrowT) args expandType' (ListT) args = return $ foldl AppT (ListT) args expandType' (AppT t1 t2) args = do t2' <- expandType' t2 [] expandType' t1 (t2' : args) expandType' (SigT t _kind) args = expandType' t args ---------------------------------------- -- Common error messages ---------------------------------------- -- We don't yet implement tests for polymorphic and quantified types -- for two reasons: -- * First, testing the equivalence of Cxt on a "forall" requires a -- class solver which is beyond the scope of this project -- * Second, even with empty Cxt, it is hard to define exactly what -- as equal "forall"s. For example, are "forall a b. (a, b)" and -- "forall b a. (a, b)" equal? What about "forall a. a -> a" and -- "forall a b. b -> b" or "forall a. Int" and "Int"? Or "forall -- a. ([a], a -> a)" and "forall c d. ([c], d -> d)"? -- SYB also does not have strong support for polymorphic and -- quantified types, so this problem is not unique to this system. {- failForall :: (Monad m) => String -> Type -> m a failForall functionName t = fail $ functionName ++ ": Comparing polymorphic or quantified types is not (yet) implemented." ++ " Found `" ++ pprint t ++ "'." -} ---------------------------------------- -- Function and monadic composition ---------------------------------------- -- | A helper function: @composeE f g = [| $(f) . $(g) |]@. -- Though this works for any 'Quasi' monad and not just 'Q'. composeE :: (Quasi m) => m Exp -> m Exp -> m Exp composeE f1 f2 = do f1' <- f1 f2' <- f2 x <- qNewName "x" return (LamE [VarP x] (AppE f1' (AppE f2' (VarE x)))) --[|\x -> f1 (f2 x)|] -- | A helper function: |@composeM f g = [| $(f) <=< $(g) |]@. -- Though this works for any 'Quasi' monad and not just 'Q'. composeM :: (Quasi m) => m Exp -> m Exp -> m Exp composeM f1 f2 = do f1' <- f1 f2' <- f2 x <- qNewName "x" return (LamE [VarP x] (InfixE (Just (AppE f2' (VarE x))) (VarE '(>>=)) (Just f1'))) --[|\x -> f2 x >>= f1|] -------------------------------------------------------------------------------- -- TODO -------------------------------------------------------------------------------- {- - Support for GADTs - (NOTE) the M of gmapMp becomes an object level monad, but the "p" is meta-level - version of "somewhere" that doesn't recur where useless - version of "somewhere" where (MonadPlus m) => m (Exp, Bool) (True means recur) - note that recurive calls to everything may not terminate for that use "memoizeExp" directly -- TODO: memoizeDecIO :: (a -> b) -> IO (a -> IO b) -- can't be done because of qNewName -- TODO: memoize via thunks -- TODO: explain implementation of memoize instead of just interface -- memoizeDec3, memoizeDec4, etc. -- memoizeExp3, memoizeExp4, etc. -- TODO: generate 2, 3, 4, 5, 6, 7, 8, and 9 via Template Haskell -- TODO: remove "T/M/Q" suffix and replace the "'" suffix with a letter -- TODO: suffixes: by name vs by type predicate; passing type vs not (only for type pred version) -- extByName -- extC -- extWithType -}