{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Rank2Types #-}
{-|
  Hand-rolled lenses on e-graphs and e-classes which come in quite handy, are
  heavily used in 'Data.Equality.Graph', and are the only exported way of
  editing the structure of the e-graph. If you want to write some complex
  'Analysis' you'll probably need these.
 -}
module Data.Equality.Graph.Lens where

import qualified Data.IntMap.Strict as IM
import qualified Data.Set as S

import Data.Functor.Identity
import Data.Functor.Const
import Data.Monoid

import Data.Equality.Utils.SizedList
import Data.Equality.Graph.Internal
import Data.Equality.Graph.Classes.Id
import Data.Equality.Graph.Nodes
import Data.Equality.Graph.Classes
import Data.Equality.Graph.ReprUnionFind

-- | A 'Lens' as defined in lens: focuses on an @a@ inside an @s@, where
-- replacing the focus with a @b@ produces a @t@
type Lens s t a b = forall f. Functor f => (a -> f b) -> (s -> f t)
-- | A 'Lens'' as defined in lens: a 'Lens' whose update keeps the types
type Lens' s a = Lens s s a a
-- | A 'Traversal' as defined in lens: like a 'Lens', but it may visit many
-- targets, which is why it needs 'Applicative' instead of just 'Functor'
type Traversal s t a b = forall f. Applicative f => (a -> f b) -> (s -> f t)

-- outdated comment for "getClass":
--
-- Get an e-class from an e-graph given its e-class id
--
-- Returns the canonical id of the class and the class itself
--
-- We'll find its canonical representation and then get it from the e-classes map
--
-- Invariant: The e-class exists.

-- | Lens for the e-class at the representative of the given id in an e-graph
--
-- The id is first mapped to its canonical (representative) id with
-- 'findRepr', so all ids of the same equivalence class focus on the same
-- e-class. Writing through the lens stores the new e-class under that
-- canonical id.
--
-- Calls 'error' when the e-class doesn't exist. The lookup is lazy, so the
-- error is only raised if the old e-class is actually demanded (a modifying
-- function that ignores its argument never triggers it).
_class :: ClassId -> Lens' (EGraph a l) (EClass a l)
_class i afa s =
    let canon_id = findRepr i (unionFind s)
        -- Explicit lookup so that a missing class reports which id failed,
        -- instead of the generic message raised by 'IM.!'
        cur_class = case IM.lookup canon_id (classes s) of
          Just c  -> c
          Nothing -> error $ "Data.Equality.Graph.Lens._class: no e-class for id "
                          ++ show i ++ " (canonical id " ++ show canon_id ++ ")"
     in (\c' -> s { classes = IM.insert canon_id c' (classes s) }) <$> afa cur_class
{-# INLINE _class #-}

-- | Lens for the memo of e-nodes in an e-graph, that is, a mapping from
-- e-nodes to the e-class they're represented in
--
-- Reads the 'memo' field and writes it back with a record update; no other
-- field of the e-graph is touched.
_memo :: Lens' (EGraph a l) (NodeMap l ClassId)
_memo afa egr = (\m1 -> egr {memo = m1}) <$> afa (memo egr)
{-# INLINE _memo #-}

-- | Traversal for the existing classes in an e-graph
--
-- Visits every e-class stored in the e-graph's 'classes' map (using the
-- map's 'traverse'). The traversal may change the analysis data type of the
-- classes, producing an e-graph over the new type.
_classes :: Traversal (EGraph a l) (EGraph b l) (EClass a l) (EClass b l)
_classes afb egr = (\m1 -> egr {classes = m1}) <$> traverse afb (classes egr)
{-# INLINE _classes #-}

-- | Indexed traversal for the existing classes in an e-graph
--
-- Like '_classes', but the function also receives the id each e-class is
-- stored under in the 'classes' map, as the first component of a pair.
-- The function returns only the new e-class; the ids are kept unchanged.
_iclasses :: Traversal (EGraph a l) (EGraph b l) (ClassId, EClass a l) (EClass b l)
_iclasses afb egr =
    (\m1 -> egr {classes = m1}) <$> IM.traverseWithKey (curry afb) (classes egr)
{-# INLINE _iclasses #-}

-- | Lens for the 'Domain' of an e-class (its 'eClassData' field)
--
-- This is a type-changing lens: replacing the data with a value of another
-- type yields an e-class over that type. Id, nodes and parents are copied
-- unchanged.
_data :: Lens (EClass domain l) (EClass domain' l) domain domain'
_data afb EClass{..} =
    (\d1 -> EClass eClassId eClassNodes d1 eClassParents) <$> afb eClassData
{-# INLINE _data #-}

-- | Lens for the parent e-classes of an e-class
--
-- Focuses on the 'eClassParents' field, a sized list of
-- @(ClassId, ENode l)@ pairs. Id, nodes and data are copied unchanged.
_parents :: Lens' (EClass a l) (SList (ClassId, ENode l))
_parents afa EClass{..} =
    EClass eClassId eClassNodes eClassData <$> afa eClassParents
{-# INLINE _parents #-}

-- | Lens for the e-nodes in an e-class
--
-- Focuses on the 'eClassNodes' set. Id, data and parents are copied
-- unchanged.
_nodes :: Lens' (EClass a l) (S.Set (ENode l))
_nodes afa EClass{..} =
    (\ns -> EClass eClassId ns eClassData eClassParents) <$> afa eClassNodes
{-# INLINE _nodes #-}

-- | Like @'view'@ but with the arguments flipped, e.g. @eclass ^. _nodes@
(^.) :: s -> Lens' s a -> a
(^.) s ln = view ln s
infixl 8 ^.
{-# INLINE (^.) #-}

-- | Synonym for @'set'@: @(ln .~ x) s@ replaces the target of @ln@ in @s@
-- with @x@
(.~) :: Lens' s a -> a -> (s -> s)
(.~) = set
infixr 4 .~
{-# INLINE (.~) #-}

-- | Synonym for @'over'@: @(ln %~ f) s@ applies @f@ to every target of @ln@
-- in @s@. Accepts an 'ASetter', so both 'Lens'es and 'Traversal's work.
(%~) :: ASetter s t a b -> (a -> b) -> (s -> t)
(%~) = over
infixr 4 %~
{-# INLINE (%~) #-}

-- | Applies a getter to a value
--
-- Runs the lens in the 'Const' functor, which carries the focused value out
-- unchanged.
view :: Lens' s a -> (s -> a)
view ln = getConst . ln Const
{-# INLINE view #-}

-- | Applies a setter to a value: replaces the target of the lens with @x@
--
-- Implemented as 'over' with a constant function, so the old target is
-- never inspected.
set :: Lens' s a -> a -> (s -> s)
set ln x = over ln (const x)
{-# INLINE set #-}

-- | Applies a function to the target(s) of a setter
--
-- Runs the setter in the 'Identity' functor. Taking an 'ASetter' rather than
-- a 'Lens' lets this work for 'Traversal's as well.
over :: ASetter s t a b -> (a -> b) -> (s -> t)
over ln f = runIdentity . ln (Identity . f)
{-# INLINE over #-}

-- | Basically 'traverse' over a 'Traversal'
--
-- A 'Traversal' already is an effectful traversal function, so this is the
-- identity; it only exists to make call sites read like 'traverse'.
traverseOf :: Traversal s t a b -> forall f. Applicative f => (a -> f b) -> s -> f t
traverseOf t = t
{-# INLINE traverseOf #-}

-- | Returns True if every target of a Traversal satisfies a predicate.
--
-- Folds the targets with the 'All' monoid through the 'Const' applicative,
-- so a traversal with no targets yields 'True' ('All''s identity).
allOf :: Traversal s t a b -> (a -> Bool) -> s -> Bool
allOf trv f = getAll . getConst . trv (Const . All . f)
{-# INLINE allOf #-}

-- * Utilities

-- We use 'ASetter' instead of 'Lens' in '%~' (and 'over') so that type
-- inference can fix the Functor/Applicative to 'Identity'. Otherwise these
-- combinators would only accept 'Lens'es, and we couldn't use them to modify
-- the targets of a 'Traversal'.

-- | The shape every 'Lens' and 'Traversal' takes once its functor is fixed
-- to 'Identity'. 'over' and '%~' take an 'ASetter' rather than a 'Lens', so
-- they can be applied to both 'Lens'es and 'Traversal's and type inference
-- settles on 'Identity' at the call site.
type ASetter s t a b = (a -> Identity b) -> (s -> Identity t)