{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Refinery.Tactic.Internal
( TacticT(..)
, tactic
, proofState
, mapTacticT
, MonadRule(..)
, RuleT(..)
)
where
import GHC.Generics
import Control.Applicative
import Control.Monad.Identity
import Control.Monad.Except
import Control.Monad.Catch
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Trans ()
import Control.Monad.IO.Class ()
import Control.Monad.Morph
import Data.Coerce
import Refinery.ProofState
newtype TacticT jdg ext err s m a = TacticT { unTacticT :: StateT jdg (ProofStateT ext ext err s m) a }
deriving ( Functor
, Applicative
, Alternative
, Monad
, MonadPlus
, MonadReader env
, MonadError err
, MonadIO
, MonadThrow
, MonadCatch
, Generic
)
instance (Monoid jdg, Show a, Show jdg, Show err, Show ext, Show (m (ProofStateT ext ext err s m (a, jdg)))) => Show (TacticT jdg ext err s m a) where
show = show . flip runStateT mempty . unTacticT
tactic :: (jdg -> ProofStateT ext ext err s m (a, jdg)) -> TacticT jdg ext err s m a
tactic t = TacticT $ StateT t
proofState :: TacticT jdg ext err s m a -> jdg -> ProofStateT ext ext err s m (a, jdg)
proofState t j = runStateT (unTacticT t) j
mapTacticT :: (Monad m) => (m a -> m b) -> TacticT jdg ext err s m a -> TacticT jdg ext err s m b
mapTacticT f (TacticT m) = TacticT $ m >>= (lift . lift . f . return)
instance MFunctor (TacticT jdg ext err s) where
hoist f = TacticT . (hoist (hoist f)) . unTacticT
instance MonadTrans (TacticT jdg ext err s) where
lift m = TacticT $ lift $ lift m
instance (Monad m) => MonadState s (TacticT jdg ext err s m) where
state f = tactic $ \j -> fmap (,j) $ state f
newtype RuleT jdg ext err s m a = RuleT
{ unRuleT :: ProofStateT ext a err s m jdg
}
deriving stock Generic
instance (Show jdg, Show err, Show a, Show (m (ProofStateT ext a err s m jdg))) => Show (RuleT jdg ext err s m a) where
show = show . unRuleT
instance Functor m => Functor (RuleT jdg ext err s m) where
fmap = coerce mapExtract'
instance Monad m => Applicative (RuleT jdg ext err s m) where
pure = return
(<*>) = ap
instance Monad m => Monad (RuleT jdg ext err s m) where
return = coerce . Axiom
RuleT (Subgoal goal k) >>= f = coerce $ Subgoal goal $ fmap (bindAlaCoerce f) k
RuleT (Effect m) >>= f = coerce $ Effect $ fmap (bindAlaCoerce f) m
RuleT (Stateful s) >>= f = coerce $ Stateful $ fmap (bindAlaCoerce f) . s
RuleT (Alt p1 p2) >>= f = coerce $ Alt (bindAlaCoerce f p1) (bindAlaCoerce f p2)
RuleT (Interleave p1 p2) >>= f = coerce $ Interleave (bindAlaCoerce f p1) (bindAlaCoerce f p2)
RuleT (Commit p1 p2) >>= f = coerce $ Commit (bindAlaCoerce f p1) (bindAlaCoerce f p2)
RuleT Empty >>= _ = coerce $ Empty
RuleT (Failure err) >>= _ = coerce $ Failure err
RuleT (Axiom e) >>= f = f e
instance Monad m => MonadState s (RuleT jdg ext err s m) where
state f = RuleT $ Stateful $ \s ->
let (a, s') = f s
in (s', Axiom a)
instance MonadReader r m => MonadReader r (RuleT jdg ext err s m) where
ask = lift ask
local f (RuleT (Subgoal goal k)) = coerce $ Subgoal goal (localAlaCoerce f . k)
local f (RuleT (Effect m)) = coerce $ Effect (local f m)
local f (RuleT (Stateful s)) = coerce $ Stateful (fmap (localAlaCoerce f) . s)
local f (RuleT (Alt p1 p2)) = coerce $ Alt (localAlaCoerce f p1) (localAlaCoerce f p2)
local f (RuleT (Interleave p1 p2)) = coerce $ Interleave (localAlaCoerce f p1) (localAlaCoerce f p2)
local f (RuleT (Commit p1 p2)) = coerce $ Commit (localAlaCoerce f p1) (localAlaCoerce f p2)
local _ (RuleT Empty) = coerce $ Empty
local _ (RuleT (Failure err)) = coerce $ Failure err
local _ (RuleT (Axiom e)) = coerce $ Axiom e
bindAlaCoerce
:: (Monad m, Coercible c (m b), Coercible a1 (m a2)) =>
(a2 -> m b) -> a1 -> c
bindAlaCoerce f = coerce . (f =<<) . coerce
localAlaCoerce
:: (MonadReader r m) =>
(r -> r) -> ProofStateT ext a err s m jdg -> ProofStateT ext a err s m jdg
localAlaCoerce f = coerce . local f . RuleT
instance MonadTrans (RuleT jdg ext err s) where
lift = coerce . Effect . fmap Axiom
instance MFunctor (RuleT jdg ext err s) where
hoist nat = hoist nat . coerce
instance MonadIO m => MonadIO (RuleT jdg ext err s m) where
liftIO = lift . liftIO
instance Monad m => MonadError err (RuleT jdg ext err s m) where
throwError = coerce . Failure
catchError r h = coerce $ flip catchError h $ coerce r
class (Monad m) => MonadRule jdg ext m | m -> jdg, m -> ext where
subgoal :: jdg -> m ext
default subgoal :: (MonadTrans t, MonadRule jdg ext m1, m ~ t m1) => jdg -> m ext
subgoal = lift . subgoal
instance (Monad m) => MonadRule jdg ext (RuleT jdg ext err s m) where
subgoal j = RuleT $ Subgoal j Axiom
instance (MonadRule jdg ext m) => MonadRule jdg ext (ReaderT env m)
instance (MonadRule jdg ext m) => MonadRule jdg ext (StateT env m)
instance (MonadRule jdg ext m) => MonadRule jdg ext (ExceptT env m)