{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
module Refinery.ProofState
where
import Control.Applicative
import Control.Monad
import Control.Monad.Catch hiding (handle)
import Control.Monad.Except
import qualified Control.Monad.Writer.Lazy as LW
import qualified Control.Monad.Writer.Strict as SW
import Control.Monad.State
import Control.Monad.Logic
import Control.Monad.Morph
import Control.Monad.Reader
import Data.Either
import GHC.Generics
data ProofStateT ext' ext err s m goal
= Subgoal goal (ext' -> ProofStateT ext' ext err s m goal)
| Effect (m (ProofStateT ext' ext err s m goal))
| Stateful (s -> (s, ProofStateT ext' ext err s m goal))
| Alt (ProofStateT ext' ext err s m goal) (ProofStateT ext' ext err s m goal)
| Interleave (ProofStateT ext' ext err s m goal) (ProofStateT ext' ext err s m goal)
| Commit (ProofStateT ext' ext err s m goal) (ProofStateT ext' ext err s m goal)
| Empty
| Failure err
| Axiom ext
deriving stock (Generic)
instance (Show goal, Show err, Show ext, Show (m (ProofStateT ext' ext err s m goal))) => Show (ProofStateT ext' ext err s m goal) where
show (Subgoal goal _) = "(Subgoal " <> show goal <> " <k>)"
show (Effect m) = "(Effect " <> show m <> ")"
show (Stateful _) = "(Stateful <s>)"
show (Alt p1 p2) = "(Alt " <> show p1 <> " " <> show p2 <> ")"
show (Interleave p1 p2) = "(Interleave " <> show p1 <> " " <> show p2 <> ")"
show (Commit p1 p2) = "(Commit " <> show p1 <> " " <> show p2 <> ")"
show Empty = "Empty"
show (Failure err) = "(Failure " <> show err <> ")"
show (Axiom ext) = "(Axiom " <> show ext <> ")"
instance Functor m => Functor (ProofStateT ext' ext err s m) where
fmap f (Subgoal goal k) = Subgoal (f goal) (fmap f . k)
fmap f (Effect m) = Effect (fmap (fmap f) m)
fmap f (Stateful s) = Stateful $ fmap (fmap f) . s
fmap f (Alt p1 p2) = Alt (fmap f p1) (fmap f p2)
fmap f (Interleave p1 p2) = Interleave (fmap f p1) (fmap f p2)
fmap f (Commit p1 p2) = Commit (fmap f p1) (fmap f p2)
fmap _ Empty = Empty
fmap _ (Failure err) = Failure err
fmap _ (Axiom ext) = Axiom ext
instance Functor m => Applicative (ProofStateT ext ext err s m) where
pure = return
(<*>) = ap
instance MFunctor (ProofStateT ext' ext err s) where
hoist nat (Subgoal a k) = Subgoal a $ fmap (hoist nat) k
hoist nat (Effect m) = Effect $ nat $ fmap (hoist nat) m
hoist nat (Stateful f) = Stateful $ fmap (hoist nat) . f
hoist nat (Alt p1 p2) = Alt (hoist nat p1) (hoist nat p2)
hoist nat (Interleave p1 p2) = Interleave (hoist nat p1) (hoist nat p2)
hoist nat (Commit p1 p2) = Commit (hoist nat p1) (hoist nat p2)
hoist _ (Failure err) = Failure err
hoist _ Empty = Empty
hoist _ (Axiom ext) = Axiom ext
applyCont
:: (Functor m)
=> (ext -> ProofStateT ext' ext err s m a)
-> ProofStateT ext' ext err s m a
-> ProofStateT ext' ext err s m a
applyCont k (Subgoal goal k') = Subgoal goal (applyCont k . k')
applyCont k (Effect m) = Effect (fmap (applyCont k) m)
applyCont k (Stateful s) = Stateful $ fmap (applyCont k) . s
applyCont k (Alt p1 p2) = Alt (applyCont k p1) (applyCont k p2)
applyCont k (Interleave p1 p2) = Interleave (applyCont k p1) (applyCont k p2)
applyCont k (Commit p1 p2) = Commit (applyCont k p1) (applyCont k p2)
applyCont _ Empty = Empty
applyCont _ (Failure err) = (Failure err)
applyCont k (Axiom ext) = k ext
instance Functor m => Monad (ProofStateT ext ext err s m) where
return goal = Subgoal goal Axiom
(Subgoal a k) >>= f = applyCont ((>>= f) . k) (f a)
(Effect m) >>= f = Effect (fmap (>>= f) m)
(Stateful s) >>= f = Stateful $ fmap (>>= f) . s
(Alt p1 p2) >>= f = Alt (p1 >>= f) (p2 >>= f)
(Interleave p1 p2) >>= f = Interleave (p1 >>= f) (p2 >>= f)
(Commit p1 p2) >>= f = Commit (p1 >>= f) (p2 >>= f)
(Failure err) >>= _ = Failure err
Empty >>= _ = Empty
(Axiom ext) >>= _ = Axiom ext
instance MonadTrans (ProofStateT ext ext err s) where
lift m = Effect (fmap pure m)
instance (Monad m) => Alternative (ProofStateT ext ext err s m) where
empty = Empty
(<|>) = Alt
instance (Monad m) => MonadPlus (ProofStateT ext ext err s m) where
mzero = empty
mplus = (<|>)
class (Monad m) => MonadExtract ext m | m -> ext where
hole :: m ext
default hole :: (MonadTrans t, MonadExtract ext m1, m ~ t m1) => m ext
hole = lift hole
instance (MonadExtract ext m) => MonadExtract ext (ReaderT r m)
instance (MonadExtract ext m) => MonadExtract ext (StateT s m)
instance (MonadExtract ext m, Monoid w) => MonadExtract ext (LW.WriterT w m)
instance (MonadExtract ext m, Monoid w) => MonadExtract ext (SW.WriterT w m)
instance (MonadExtract ext m) => MonadExtract ext (ExceptT err m)
proofs :: forall ext err s m goal. (MonadExtract ext m) => s -> ProofStateT ext ext err s m goal -> m [Either err (ext, s, [goal])]
proofs s p = go s [] p
where
go s goals (Subgoal goal k) = do
h <- hole
(go s (goals ++ [goal]) $ k h)
go s goals (Effect m) = go s goals =<< m
go s goals (Stateful f) =
let (s', p) = f s
in go s' goals p
go s goals (Alt p1 p2) = liftA2 (<>) (go s goals p1) (go s goals p2)
go s goals (Interleave p1 p2) = liftA2 (interleave) (go s goals p1) (go s goals p2)
go s goals (Commit p1 p2) = go s goals p1 >>= \case
(rights -> []) -> go s goals p2
solns -> pure solns
go _ _ Empty = pure []
go _ _ (Failure err) = pure [throwError err]
go s goals (Axiom ext) = pure [Right (ext, s, goals)]
accumEither :: (Semigroup a, Semigroup b) => Either a b -> Either a b -> Either a b
accumEither (Left a1) (Left a2) = Left (a1 <> a2)
accumEither (Right b1) (Right b2) = Right (b1 <> b2)
accumEither Left{} x = x
accumEither x Left{} = x
instance (MonadIO m) => MonadIO (ProofStateT ext ext err s m) where
liftIO = lift . liftIO
instance (MonadThrow m) => MonadThrow (ProofStateT ext ext err s m) where
throwM = lift . throwM
instance (MonadCatch m) => MonadCatch (ProofStateT ext ext err s m) where
catch (Subgoal goal k) handle = Subgoal goal (flip catch handle . k)
catch (Effect m) handle = Effect . catch m $ pure . handle
catch (Stateful s) handle = Stateful (fmap (flip catch handle) . s)
catch (Alt p1 p2) handle = Alt (catch p1 handle) (catch p2 handle)
catch (Interleave p1 p2) handle = Interleave (catch p1 handle) (catch p2 handle)
catch (Commit p1 p2) handle = Commit (catch p1 handle) (catch p2 handle)
catch Empty _ = Empty
catch (Failure err) _ = Failure err
catch (Axiom e) _ = (Axiom e)
instance (Monad m) => MonadError err (ProofStateT ext ext err s m) where
throwError = Failure
catchError (Subgoal goal k) handle = Subgoal goal (flip catchError handle . k)
catchError (Effect m) handle = Effect (fmap (flip catchError handle) m)
catchError (Stateful s) handle = Stateful $ fmap (flip catchError handle) . s
catchError (Alt p1 p2) handle = catchError p1 handle <|> catchError p2 handle
catchError (Interleave p1 p2) handle = Interleave (catchError p1 handle) (catchError p2 handle)
catchError (Commit p1 p2) handle = catchError p1 handle <|> catchError p2 handle
catchError Empty _ = Empty
catchError (Failure err) handle = handle err
catchError (Axiom e) _ = (Axiom e)
instance (MonadReader r m) => MonadReader r (ProofStateT ext ext err s m) where
ask = lift ask
local f (Subgoal goal k) = Subgoal goal (local f . k)
local f (Effect m) = Effect (local f m)
local f (Stateful s) = Stateful (fmap (local f) . s)
local f (Alt p1 p2) = Alt (local f p1) (local f p2)
local f (Interleave p1 p2) = Interleave (local f p1) (local f p2)
local f (Commit p1 p2) = Commit (local f p1) (local f p2)
local _ Empty = Empty
local _ (Failure err) = (Failure err)
local _ (Axiom e) = (Axiom e)
instance (Monad m) => MonadState s (ProofStateT ext ext err s m) where
state f = Stateful $ \s ->
let (a, s') = f s
in (s', pure a)
axiom :: ext -> ProofStateT ext' ext err s m jdg
axiom = Axiom
subgoals :: (Functor m) => [jdg -> ProofStateT ext ext err s m jdg] -> ProofStateT ext ext err s m jdg -> ProofStateT ext ext err s m jdg
subgoals [] (Subgoal goal k) = applyCont k (pure goal)
subgoals (f:fs) (Subgoal goal k) = applyCont (subgoals fs . k) (f goal)
subgoals fs (Effect m) = Effect (fmap (subgoals fs) m)
subgoals fs (Stateful s) = Stateful (fmap (subgoals fs) . s)
subgoals fs (Alt p1 p2) = Alt (subgoals fs p1) (subgoals fs p2)
subgoals fs (Interleave p1 p2) = Interleave (subgoals fs p1) (subgoals fs p2)
subgoals fs (Commit p1 p2) = Commit (subgoals fs p1) (subgoals fs p2)
subgoals _ (Failure err) = Failure err
subgoals _ Empty = Empty
subgoals _ (Axiom ext) = Axiom ext
mapExtract :: (Functor m) => (ext -> ext') -> (ext' -> ext) -> ProofStateT ext ext err s m jdg -> ProofStateT ext' ext' err s m jdg
mapExtract into out = \case
Subgoal goal k -> Subgoal goal $ mapExtract into out . k . out
Effect m -> Effect (fmap (mapExtract into out) m)
Stateful s -> Stateful (fmap (mapExtract into out) . s)
Alt t1 t2 -> Alt (mapExtract into out t1) (mapExtract into out t2)
Interleave t1 t2 -> Interleave (mapExtract into out t1) (mapExtract into out t2)
Commit t1 t2 -> Commit (mapExtract into out t1) (mapExtract into out t2)
Empty -> Empty
Failure err -> Failure err
Axiom ext -> Axiom $ into ext
mapExtract' :: Functor m => (a -> b) -> ProofStateT ext' a err s m jdg -> ProofStateT ext' b err s m jdg
mapExtract' into = \case
Subgoal goal k -> Subgoal goal $ mapExtract' into . k
Effect m -> Effect (fmap (mapExtract' into) m)
Stateful s -> Stateful (fmap (mapExtract' into) . s)
Alt t1 t2 -> Alt (mapExtract' into t1) (mapExtract' into t2)
Interleave t1 t2 -> Interleave (mapExtract' into t1) (mapExtract' into t2)
Commit t1 t2 -> Commit (mapExtract' into t1) (mapExtract' into t2)
Empty -> Empty
Failure err -> Failure err
Axiom ext -> Axiom $ into ext