{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | Core types of the HERMIT plugin monad: the 'PluginT' transformer, its
-- read-only environment ('PluginReader'), its state ('PluginState'), its
-- exceptions ('PException'), and a 'KernelEnv' built from the state.
module HERMIT.Plugin.Types where

import Control.Concurrent.STM
import Control.Monad.Error.Class (MonadError(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Reader (MonadReader(..), ReaderT(..))
import Control.Monad.State (MonadState(..), StateT(..))
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Except (ExceptT, runExceptT)

import Data.Dynamic
import qualified Data.Map as M

import HERMIT.Kure
import HERMIT.External
import HERMIT.Kernel
import HERMIT.Monad
import HERMIT.Plugin.Builder
import HERMIT.PrettyPrinter.Common
import HERMIT.Dictionary.Reasoning

import Prelude.Compat

import System.IO

-- | 'PluginT' specialised to 'IO'.
type PluginM = PluginT IO

-- | The plugin monad transformer.
--
-- Layers, outermost first: 'PException' errors, then a read-only
-- 'PluginReader' environment, then a 'PluginState' state, all on top of @m@.
newtype PluginT m a = PluginT
    { unPluginT :: ExceptT PException (ReaderT PluginReader (StateT PluginState m)) a }
    deriving (Functor, Applicative, MonadIO, MonadError PException, MonadState PluginState, MonadReader PluginReader)

-- | Run a 'PluginT' computation with the given environment and initial state.
--
-- Returns either the 'PException' that ended the computation or its result,
-- paired with the final state. Because 'ExceptT' sits above 'StateT', the
-- state reached at the point of an exception is returned as well.
runPluginT :: PluginReader -> PluginState -> PluginT m a -> m (Either PException a, PluginState)
runPluginT pr ps = flip runStateT ps . flip runReaderT pr . runExceptT . unPluginT

-- | Hand-written 'Monad' instance. Its 'fail' raises a 'PError' through
-- 'MonadError' rather than delegating to the underlying monad.
--
-- NOTE(review): from GHC 8.8 onward 'fail' is no longer a 'Monad' method
-- (MonadFail proposal); building there would need a separate 'MonadFail'
-- instance.
instance Monad m => Monad (PluginT m) where
    return = PluginT . return
    PluginT m >>= k = PluginT (m >>= unPluginT . k)
    fail = PluginT . throwError . PError

-- | Lift through all three transformer layers ('ExceptT', 'ReaderT', 'StateT').
instance MonadTrans PluginT where
    lift = PluginT . lift . lift . lift

-- | Catch 'PError' failures.
--
-- Law: @fail msg \`catchM\` f == f msg@
--
-- The protected computation runs on a snapshot of the current state:
--
-- * on success, its final state is committed and its result returned;
-- * on 'PError', the handler runs from the /original/ state (the failed
--   computation's state changes are discarded);
-- * any other 'PException' is rethrown, also without committing the failed
--   computation's state.
instance Monad m => MonadCatch (PluginT m) where
    -- catchM :: m a -> (String -> m a) -> m a
    catchM m f = do
        st <- get
        r <- ask
        (er, st') <- lift $ runPluginT r st m
        case er of
            Left (PError msg) -> f msg
            Left other        -> throwError other -- rethrow abort/resume
            Right v           -> put st' >> return v

-- | Plugin state: the current AST, a modifiable pretty-printer, and a flag to
-- auto-run Core Lint.
data PluginState = PluginState
    { ps_cursor   :: AST                     -- ^ the current AST
    , ps_pretty   :: PrettyPrinter           -- ^ which pretty printer to use
    , ps_render   :: Handle -> PrettyOptions -> Either String DocH -> IO () -- ^ how to output to the screen
    , ps_tick     :: TVar (M.Map String Int) -- ^ per-message tick counters (see 'tick')
    , ps_corelint :: Bool                    -- ^ if true, run Core Lint on module after each rewrite
    } deriving (Typeable)

-- | Read-only environment of the plugin monad.
data PluginReader = PluginReader
    { pr_kernel :: Kernel
    , pr_pass   :: PassInfo
    } deriving (Typeable)

-- | Exceptions raised inside 'PluginT'. Only 'PError' is intercepted by
-- 'catchM'; 'PAbort' and 'PResume' are rethrown by it.
data PException
    = PAbort
    | PResume AST
    | PError String -- ^ failure message; produced by 'fail', caught by 'catchM'

-- | Box used by the 'Extern' instance for 'PluginState'.
newtype PSBox = PSBox PluginState deriving Typeable

instance Extern PluginState where
    type Box PluginState = PSBox
    unbox (PSBox st) = st
    box = PSBox

-- | Atomically increment the counter for @msg@ and return its new value.
--
-- A message not seen before starts at 1.
tick :: TVar (M.Map String Int) -> String -> IO Int
tick var msg = atomically $ do
    m <- readTVar var
    let c  = maybe 1 (+ 1) (M.lookup msg m)
        m' = M.insert msg c m
    -- Force the count and the new map before storing them. Without this, a
    -- lazy 'Data.Map' would hold an unevaluated count that still references
    -- the entire previous map, and the TVar would hold an unevaluated insert.
    c `seq` m' `seq` writeTVar var m'
    return c

-- | Build the 'KernelEnv' through which the kernel reports messages back to
-- the plugin.
--
-- * 'DebugTick' messages bump the tick counter and print the count.
-- * 'DebugCore' messages print a label, then pretty-print the Core with the
--   state's pretty-printer. Both kinds of output go to 'stdout' via
--   'ps_render'.
-- * Obligations are recorded with 'insertLemma'.
mkKernelEnv :: PluginState -> KernelEnv
mkKernelEnv st = KernelEnv $ \case
    DebugTick msg' -> do
        c <- liftIO $ tick (ps_tick st) msg'
        out $ "<" ++ show c ++ "> " ++ msg'
    DebugCore msg' cxt qc -> do
        out $ "[" ++ msg' ++ "]"
        doc :: DocH <- applyT (ppLCoreTCT pp) (liftPrettyC (pOptions pp) cxt) qc
        liftIO $ ps_render st stdout (pOptions pp) (Right doc)
    AddObligation _ nm l -> insertLemma nm l
  where
    pp = ps_pretty st
    -- Render a plain-text line using the current pretty-printer's options.
    out str = liftIO $ ps_render st stdout (pOptions pp) (Left $ str ++ "\n")