{-# LANGUAGE TupleSections #-}
{-# LANGUAGE MonoLocalBinds #-}
{-|
   Monadic interface to e-graph stateful computations
 -}
module Data.Equality.Graph.Monad
  (
    -- * Threading e-graphs in a stateful computation
    --
    -- | These are the same operations over e-graphs as in 'Data.Equality.Graph',
    -- but defined in the context of a 'State' monad threading around the e-graph.
    egraph
  , represent
  , add
  , merge
  , rebuild
  , EG.canonicalize
  , EG.find
  , EG.emptyEGraph

    -- * E-graph transformations for monadic analysis
    --
    -- | The same e-graph operations in a stateful computation threading around
    -- the e-graph, but for 'Analysis' defined monadically ('AnalysisM').
  , representM, addM, mergeM, rebuildM

  -- * E-graph stateful computations
  , EGraphM
  , EGraphMT
  , runEGraphM
  , runEGraphMT

  -- * E-graph definition re-export
  , EG.EGraph

  -- * 'State' monad re-exports
  , modify, get, gets
  ) where

import Control.Monad ((>=>))
import Control.Monad.Trans.State.Strict

import Data.Equality.Utils (Fix, cata)

import Data.Equality.Analysis
import qualified Data.Equality.Analysis.Monadic as AM
import Data.Equality.Graph (EGraph, ClassId, Language, ENode(..))
import qualified Data.Equality.Graph as EG

-- | E-graph stateful computation
--
-- A 'State' computation (the strict variant, from
-- "Control.Monad.Trans.State.Strict") whose state is an @'EGraph' a l@. The
-- result type of the computation is the remaining type argument, as in
-- @EGraphM anl l ClassId@.
type EGraphM a l = State (EGraph a l)
-- | E-graph stateful computation over an arbitrary monad
--
-- A 'StateT' transformer (the strict variant, from
-- "Control.Monad.Trans.State.Strict") whose state is an @'EGraph' a l@. The
-- underlying monad and the result type are the remaining type arguments, as
-- in @EGraphMT anl l m ClassId@.
type EGraphMT a l = StateT (EGraph a l)

-- | Run an e-graph computation starting from the empty e-graph
-- ('EG.emptyEGraph'), returning the result of the computation paired with the
-- final e-graph.
--
-- === Example
-- @
-- egraph $ do
--  id1 <- represent t1
--  id2 <- represent t2
--  merge id1 id2
-- @
egraph :: Language l => EGraphM anl l a -> (a, EGraph anl l)
egraph = runEGraphM EG.emptyEGraph
{-# INLINE egraph #-}

-- | Represent an expression (@Fix l@) in an e-graph by recursively
-- representing sub expressions
--
-- The expression is folded with 'cata'. At each layer the computations for the
-- children are run with 'sequence' to obtain their e-class ids; the resulting
-- layer of ids is then wrapped in 'Node' and handed to 'add', whose 'ClassId'
-- is the result for that layer.
represent :: (Analysis anl l, Language l) => Fix l -> EGraphM anl l ClassId
-- Precedence note: this parses as @cata (sequence >=> (add . Node))@.
represent = cata $ sequence >=> add . Node
{-# INLINE represent #-}

-- | Add an e-node to the e-graph
--
-- Lifts the pure 'EG.add' into 'EGraphM': the 'ClassId' it returns is the
-- result of the computation, and the e-graph it returns becomes the new state.
add :: (Analysis anl l, Language l) => ENode l -> EGraphM anl l ClassId
-- 'fmap' is used at the function functor here, so @fmap pure . EG.add@ is
-- @\\n -> pure . EG.add n@, i.e. the state transition wrapped in 'Identity'.
add = StateT . fmap pure . EG.add
{-# INLINE add #-}

-- | Merge two e-classes by id
--
-- Lifts the pure 'EG.merge' into 'EGraphM': the 'ClassId' it returns is the
-- result of the computation, and the e-graph it returns becomes the new state.
--
-- E-graph invariants may be broken by merging, and 'rebuild' should be used
-- /eventually/ to restore them
merge :: (Analysis anl l, Language l) => ClassId -> ClassId -> EGraphM anl l ClassId
-- '<$>' is used at the function functor: 'pure' is post-composed onto the
-- state transition @EG.merge a b@ to wrap its result in 'Identity'.
merge a b = StateT (pure <$> EG.merge a b)
{-# INLINE merge #-}

-- | Rebuild: Restore e-graph invariants
--
-- E-graph invariants are traditionally maintained after every merge, but we
-- allow operations to temporarily break the invariants (specifically, until we
-- call 'rebuild')
--
-- Replaces the e-graph state with the result of applying the pure
-- 'EG.rebuild' to it, and returns @()@.
--
-- The paper describing rebuilding in detail is https://arxiv.org/abs/2004.03082
rebuild :: (Analysis anl l, Language l) => EGraphM anl l ()
-- @((),)@ is a tuple section (TupleSections): it pairs the rebuilt e-graph
-- with the unit result expected by 'StateT'.
rebuild = StateT (pure . ((),) . EG.rebuild)
{-# INLINE rebuild #-}

-- | Run 'EGraphM' computation on a given e-graph
--
-- This is 'runState' with its arguments flipped: the initial e-graph comes
-- first and the computation second. Returns the result of the computation
-- paired with the final e-graph.
runEGraphM :: EGraph anl l -> EGraphM anl l a -> (a, EGraph anl l)
runEGraphM = flip runState
{-# INLINE runEGraphM #-}

--------------------------------------------------------------------------------
-- Monadic Analysis interface

-- | Run 'EGraphMT' computation on a given e-graph over a monadic analysis
--
-- This is 'runStateT' with its arguments flipped: the initial e-graph comes
-- first and the computation second. The result of the computation, paired
-- with the final e-graph, is produced in the underlying monad @m@.
runEGraphMT :: EGraph anl l -> EGraphMT anl l m a -> m (a, EGraph anl l)
runEGraphMT = flip runStateT
{-# INLINE runEGraphMT #-}

-- | Like 'represent', but for a monadic analysis
--
-- Wraps 'EG.representM' in 'StateT': the 'ClassId' it produces is the result
-- of the computation, and the e-graph it produces becomes the new state.
representM :: (AM.AnalysisM m anl l, Language l) => Fix l -> EGraphMT anl l m ClassId
representM = StateT . EG.representM
{-# INLINE representM #-}

-- | Like 'add', but for a monadic analysis
--
-- Wraps 'EG.addM' in 'StateT': the 'ClassId' it produces is the result of the
-- computation, and the e-graph it produces becomes the new state.
addM :: (AM.AnalysisM m anl l, Language l) => ENode l -> EGraphMT anl l m ClassId
addM = StateT . EG.addM
{-# INLINE addM #-}

-- | Like 'merge', but for a monadic analysis
--
-- Wraps 'EG.mergeM' in 'StateT': the 'ClassId' it produces is the result of
-- the computation, and the e-graph it produces becomes the new state.
mergeM :: (AM.AnalysisM m anl l, Language l) => ClassId -> ClassId -> EGraphMT anl l m ClassId
mergeM a b = StateT (EG.mergeM a b)
{-# INLINE mergeM #-}

-- | Like 'rebuild', but for a monadic analysis
--
-- Runs 'EG.rebuildM' on the current e-graph in the underlying monad, stores
-- the e-graph it produces as the new state, and returns @()@.
rebuildM :: (AM.AnalysisM m anl l, Language l) => EGraphMT anl l m ()
-- @((),)@ is a tuple section (TupleSections): it pairs the rebuilt e-graph
-- with the unit result expected by 'StateT'.
rebuildM = StateT (fmap ((),) . EG.rebuildM)
{-# INLINE rebuildM #-}