{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}

-- | Template Haskell helpers that enumerate the monomorphic instances of a
-- type class at compile time and generate a lookup table from a run-time
-- witness of each instance type (by default a 'TypeRep') to a decoder that
-- produces a value of that type wrapped in an existential 'Some'.
module Type.InstanceMap.TH
  ( mkMap
  , mkMapWithOpts
  , defaultOptions
  , ClassName
  , InputTypeName
  , OutputWrapperName
  , Some(..)
  , Options(..)
  ) where

import Data.Typeable
import qualified Data.Set as S
import Data.List (nub, intersect)
import qualified Data.Map.Strict as M
import Control.Monad.State.Strict (StateT(..), put, get, evalStateT, modify, withStateT)
import Control.Monad.Trans (lift)
import Control.Monad.IO.Class (liftIO)
import Data.Traversable (for)
import qualified Debug.Trace as DBG
import Control.Monad (join, foldM, when)
import Language.Haskell.TH
import GHC.Exts

-- | Existential wrapper indexed by a class. 'mkMapWithOpts' generates one
-- data instance per class, whose single constructor holds a value of some
-- type that is both 'Typeable' and a member of the class.
data family Some (c :: * -> Constraint)

-- | Settings controlling instance discovery and code generation.
data Options = Options
  { -- | How deep to traverse the tree of instance contexts when attempting to find
    -- monomorphic types that are members of the class. Default is 2.
    --
    -- For instance, if we have declarations like:
    --
    -- > data NumWithS n = NumWithS String n
    -- > instance (Num a) => Num (NumWithS a)
    --
    -- with maxDepth set to 2, then NumWithS Int and NumWithS (NumWithS Int) would
    -- both be found but NumWithS (NumWithS (NumWithS Int)) would not be.
    maxDepth :: Int,
    -- | Whether to produce lots of debugging output. Default False.
    verbose :: Bool,
    -- | Expression applied to @Proxy :: Proxy t@ for every discovered type
    -- @t@; its result is used as the key of @t@ in the generated map.
    -- Default is @[|typeRep|]@.
    witnessGenerator :: ExpQ,
    -- | Name of the type used for the first argument in the signature of the
    -- generated getter. It should be the result type of 'witnessGenerator'.
    -- Default is @''TypeRep@.
    witnessTypeName :: Name
  }

-- | Depth 2, quiet, keyed by 'TypeRep'.
defaultOptions :: Options
defaultOptions = Options
  { maxDepth = 2
  , verbose = False
  , witnessGenerator = [|typeRep|]
  , witnessTypeName = ''TypeRep
  }

type ConcreteType = Type

type InstanceHead = Type

type ClassName = Name

type InputTypeName = Name

type OutputWrapperName = Name

-- | Bookkeeping for the recursive walk over instance contexts.
data InstanceTraversalState = InstanceTraversalState
  { -- | Classes currently being expanded, innermost first; its length is
    -- compared against 'maxDepth'.
    classesInProgress :: [Name],
    -- | Results already computed in the current traversal layer.
    classesDone :: M.Map Name [ConcreteType]
  }

initialTraversalState :: InstanceTraversalState
initialTraversalState = InstanceTraversalState [] M.empty

type StateQ a = StateT InstanceTraversalState Q a

-- | Get the list of instance declarations of a class.
--
-- Fails in 'Q' when the name does not reify to a class.
getInstances :: Name -> Q [InstanceDec]
getInstances typ = do
  reified <- reify typ
  case reified of
    ClassI _ instances -> return instances
    _ -> fail $ "Type.InstanceMap.TH: " ++ show typ ++ " is not a type class"

-- | Expression splice producing a string literal with the 'show'n instance
-- declarations of a class (debugging aid).
showInstances :: Name -> Q Exp
showInstances typ = LitE . stringL . show <$> getInstances typ

-- | When used as a declaration splice, this function will create three definitions:
--
-- * A data instance 'Some' /ClassName/ with a single constructor Some/ClassName/
--
-- * A value named mapOf/ClassName/: a 'M.Map' from witnesses (by default
--   'TypeRep's) to decoder functions of type
--
-- @
-- /InputType/ -> /Wrapper/ ('Some' /Class/)
-- @
--
-- * A helper function getSome/ClassName/ with signature
--
-- @
-- /WitnessType/ -> /InputType/ -> /Wrapper/ ('Some' /Class/)
-- @
--
--   which looks the witness up in the map and runs the decoder found there,
--   or calls 'fail' in the wrapper monad when the witness is unknown.
--
-- The four arguments to this function are:
--
-- 1. The name of the class whose monomorphic instances should be collected.
--
-- 2. The name of the type that is the argument to the decoder function.
-- Often ''ByteString or ''Aeson.Value.
--
-- 3. The name of a monad type with an instance of MonadFail that the functional
-- expression uses to wrap the return value.
-- Often ''Maybe.
--
-- 4. A quasiquoted expression for a function that takes an argument of the type
-- named by (2) and has a polymorphic return type wrapped in the monad (3).
--
-- For example, for JSON the call would generally be:
--
-- > $(mkMap ''MyClass ''Value ''Result [|fromJSON|])
mkMap :: ClassName -> InputTypeName -> OutputWrapperName -> ExpQ -> Q [Dec]
mkMap = mkMapWithOpts defaultOptions

-- | Like 'mkMap' but with user-provided 'Options'
mkMapWithOpts :: Options -> ClassName -> InputTypeName -> OutputWrapperName -> ExpQ -> Q [Dec]
mkMapWithOpts opts className inType outWrap fExp = do
  typs <- knownConcreteInstances opts className
  decoders <- mapM mkExp typs
  witnesses <- mapM mkWitness typs
  typMap <- [| M.fromList $ zip $(return (ListE witnesses)) $(return (ListE decoders)) |]
  -- The getter's binders are fresh names spliced explicitly into the
  -- quotation, so the generated clause and its body are guaranteed to agree.
  trepName <- newName "trep"
  valName <- newName "v"
  decoderE <-
    [| case M.lookup $(varE trepName) $(varE mapName) of
         Nothing -> fail $ "No instance found for " ++ show $(varE trepName)
         Just f -> f $(varE valName) |]
  tyVar <- newName "a"
  let -- forall a. (Typeable a, Class a) => SomeClass a
      someCon =
        ForallC
          [PlainTV tyVar]
          [ AppT (ConT ''Typeable) (VarT tyVar)
          , AppT (ConT className) (VarT tyVar)
          ]
          (NormalC someName [(Bang NoSourceUnpackedness NoSourceStrictness, VarT tyVar)])
      dataInstD = DataInstD [] ''Some [ConT className] Nothing [someCon] []
      mapDefinitionD = ValD (VarP mapName) (NormalB typMap) []
      -- witnessType -> inType -> outWrap (Some className)
      getterType =
        AppT
          (AppT ArrowT (ConT (witnessTypeName opts)))
          (AppT
            (AppT ArrowT (ConT inType))
            (AppT (ConT outWrap) (AppT (ConT ''Some) (ConT className))))
      decoderSigD = SigD getterName getterType
      decoderD =
        FunD getterName [Clause [VarP trepName, VarP valName] (NormalB decoderE) []]
  return [dataInstD, mapDefinitionD, decoderSigD, decoderD]
  where
    -- The user's decoder pinned to one concrete result type, with the result
    -- wrapped in the generated Some constructor.
    mkExp :: Type -> Q Exp
    mkExp t =
      [| fmap $(conE someName) . ($fExp :: $(conT inType) -> $(appT (conT outWrap) (return t))) |]

    someName = mkName ("Some" ++ nameBase className)
    mapName = mkName ("mapOf" ++ nameBase className)
    getterName = mkName ("getSome" ++ nameBase className)

    -- Map key for one concrete type.
    mkWitness :: ConcreteType -> ExpQ
    mkWitness t = [| $(witnessGenerator opts) (Proxy :: Proxy $(return t)) |]

-- Get all the variables in constraints; get the instances for each constraint; intersect
-- over the constraints; replace variables in the instance type
knownConcreteInstances :: Options -> Name -> Q [ConcreteType]
knownConcreteInstances opts className =
  evalStateT (knownConcreteInstances' opts className) initialTraversalState

-- | Worker for 'knownConcreteInstances'. Every instance of the class is
-- expanded in its own copy of the state with the class pushed onto
-- 'classesInProgress'; that copy is discarded afterwards, so results cached
-- at a deeper (possibly truncated) level never leak back into this one.
knownConcreteInstances' :: Options -> Name -> StateQ [ConcreteType]
knownConcreteInstances' opts className = do
  st <- get
  case M.lookup className (classesDone st) of
    Just types -> return types
    Nothing
      | length (classesInProgress st) > maxDepth opts -> do
          warn opts ("Cutting off recursion at " ++ show className)
          return []
      | otherwise -> do
          info opts $ "Looking for instances of " ++ show className
          trInsts <- lift $ getInstances className
          let nested = st { classesInProgress = className : classesInProgress st }
          concreteInsts <- for trInsts $ \inst -> case inst of
            InstanceD _ ctx (AppT _ instHead) _ ->
              lift $ evalStateT (deepReplaceVars opts instHead ctx) nested
            other -> do
              -- Anything that is not "instance ctx => C t" cannot be expanded.
              warn opts $ "Skipping instance declaration with unsupported shape: " ++ show other
              return []
          -- Flatten the list of lists
          let retVal = join concreteInsts
          info opts $ "Returning instances for " ++ show className ++ ": " ++ show retVal
          modify $ \s -> s { classesDone = M.insert className retVal (classesDone s) }
          return retVal

-- Ctx is a list of univariate constraints. Return a list of the types we can find that
-- satisfy those constraints
-- If it's a simple declaration like "instance Eq MyType" then we're done
-- but if there is a context with variables e.g. instance (XYZ a) => Eq a
-- then we need to recurse and find instances of XYZ and replace the a
deepReplaceVars :: Options -> InstanceHead -> Cxt -> StateQ [ConcreteType]
deepReplaceVars opts t constraints
  | monomorphic t = return [t]
  | not (all univariate constraints) = do
      warn opts $
        "Only simple univariate constraints (like 'C a') are supported (skipping "
          ++ show t
          ++ " due to "
          ++ show (filter (not . univariate) constraints)
          ++ ")"
      return []
  | otherwise = do
      -- Get type variables and their candidate types for each constraint
      constraintCandidates <- traverse getVarsAndCandidates constraints
      let -- Group by variable
          mc :: M.Map Name [[Type]]
          mc = M.fromListWith (flip mappend) [(n, [typs]) | (n, typs) <- constraintCandidates]
          -- Possible values for a variable are those that satisfy every constraint.
          -- Every list in mc is non-empty by construction, so foldl1 is safe here.
          possibleVals = foldl1 intersect <$> mc
      -- A variable of the head that no constraint mentions cannot be
      -- instantiated; drop such results instead of reporting them as concrete.
      return . filter monomorphic $ allReplacements possibleVals t
  where
    univariate (AppT (ConT _) (VarT _)) = True
    univariate _ = False

    getVarsAndCandidates :: Type -> StateQ (Name, [ConcreteType])
    getVarsAndCandidates (AppT (ConT cls) (VarT v)) = (v,) <$> knownConcreteInstances' opts cls
    -- Not reached: the guard above rejects contexts with other constraint shapes.
    getVarsAndCandidates _ = lift $ (,[]) <$> newName ""

-- | Report a compiler warning, but only in verbose mode.
warn :: Options -> String -> StateQ ()
warn opts s = when (verbose opts) $ lift (reportWarning s)

-- | Print a progress message to stdout at compile time, but only in verbose mode.
info :: Options -> String -> StateQ ()
info opts s = when (verbose opts) $ lift (runIO (putStrLn s))

-- | All types obtained from the instance head by replacing each variable of
-- the map with each of its candidates (a cartesian product across variables,
-- built in the list monad).
allReplacements :: M.Map Name [ConcreteType] -> InstanceHead -> [ConcreteType]
allReplacements var2candidates instHead =
  foldM substitute instHead (M.toList var2candidates)
  where
    substitute :: Type -> (Name, [ConcreteType]) -> [Type]
    substitute (VarT v1) (v2, substs)
      | v1 == v2 = substs
      | otherwise = [VarT v1]
    substitute (AppT t1 t2) s = AppT <$> substitute t1 s <*> substitute t2 s
    substitute (SigT t1 k) s = SigT <$> substitute t1 s <*> [k]
    substitute (InfixT t1 n t2) s = InfixT <$> substitute t1 s <*> [n] <*> substitute t2 s
    substitute (UInfixT t1 n t2) s = UInfixT <$> substitute t1 s <*> [n] <*> substitute t2 s
    substitute (ParensT t1) s = ParensT <$> substitute t1 s
    substitute t _ = [t]

-- | True when no type variable occurs in the inspected type constructors
-- (application, kind signature, infix and parenthesised forms); every other
-- form is taken to be monomorphic.
monomorphic :: Type -> Bool
monomorphic (VarT _) = False
monomorphic (AppT t1 t2) = monomorphic t1 && monomorphic t2
monomorphic (SigT t1 _) = monomorphic t1
monomorphic (InfixT t1 _ t2) = monomorphic t1 && monomorphic t2
monomorphic (UInfixT t1 _ t2) = monomorphic t1 && monomorphic t2
monomorphic (ParensT t1) = monomorphic t1
monomorphic _ = True

-- | Instance declarations common to every listed class (compared with '==').
-- An empty list of classes yields no instances.
intersectInstances :: [Name] -> Q [InstanceDec]
intersectInstances [] = return []
intersectInstances classNames = foldl1 intersect <$> traverse getInstances classNames

-- indexMapRegistry = fromList [
--   ("Key", fromList [
--     ("Security", fromJSON :: Result (Key Security)),
--     ("Snapshot", fromJSON :: Result (Key Snapshot)),
--     ...