{-# LANGUAGE Trustworthy #-}

module Language.Hasmtlib.Internal.Sharing
  ( Sharing(..)
  , SharingMode(..)
  , runSharing, share
  )
where

import Language.Hasmtlib.Internal.Uniplate1
import Language.Hasmtlib.Type.MonadSMT
import Language.Hasmtlib.Type.SMTSort
import Language.Hasmtlib.Type.Expr
import Data.GADT.Compare
import Data.Some.Constraint
import Data.HashMap.Lazy
import Data.Default
import Data.Kind
import Control.Monad.State
import Control.Lens
import System.Mem.StableName
import System.IO.Unsafe
import Unsafe.Coerce

-- | Mode used for sharing.
data SharingMode =
    None            -- ^ Common expressions are not shared at all
  | StableNames     -- ^ Expressions that resolve to the same 'StableName' are shared
  deriving Show

-- | The default 'SharingMode' is 'None': common expressions are not shared.
instance Default SharingMode where
  def = None

-- | States that can share expressions by comparing their 'StableName's.
class Sharing s where
  -- | A constraint on the monad used when asserting the shared node in 'assertSharedNode'.
  type SharingMonad s :: (Type -> Type) -> Constraint

  -- | A 'Lens'' on a mapping from the 'StableName' of an expression node to
  --   the auxiliary variable ('Expr') that shares it.
  stableMap :: Lens' s (HashMap (StableName ()) (SomeKnownSMTSort Expr))

  -- | Asserts that a node-expression is represented by its auxiliary node-variable:
  --   @nodeVar === nodeExpr@ for some @nodeExpr :: Expr t@.
  --   Also gives access to the 'StableName' of the original expression.
  assertSharedNode :: (MonadState s m, SharingMonad s m) => StableName () -> Expr BoolSort -> m ()

  -- | Sets the mode used for sharing common expressions.
  --   Note that @'def' :: 'SharingMode'@ is 'None', i.e. no sharing.
  setSharingMode :: MonadState s m => SharingMode -> m ()

-- | Shares all possible sub-expressions in the given expression.
--   Replaces each non-leaf node in the expression tree with an auxiliary variable.
--   All nodes @x@, @y@ where @makeStableName x == makeStableName y@ are replaced
--   with the same auxiliary variable. Therefore this creates a DAG.
--
--   With 'None' the expression is returned unchanged.
runSharing :: (KnownSMTSort t, MonadSMT s m, Sharing s, SharingMonad s m) => SharingMode -> Expr t -> m (Expr t)
runSharing None        = return
runSharing StableNames = lazyParaM1 (
    \origExpr expr ->
      if isLeaf origExpr
      then return origExpr          -- leaves are kept as they are
      else case sortSing' origExpr of
        -- Every alternative is identical on purpose: matching on the sort
        -- singleton brings 'Equatable (Expr t)' into scope for that concrete
        -- sort, which 'share' requires.
        SBoolSort      -> share origExpr expr
        SIntSort       -> share origExpr expr
        SRealSort      -> share origExpr expr
        SBvSort _ _    -> share origExpr expr
        SArraySort _ _ -> share origExpr expr
        SStringSort    -> share origExpr expr)

-- | Returns an auxiliary variable representing this expression node.
--   If a shared auxiliary variable is already recorded for the node's
--   'StableName' (and has the same sort), returns that.
--   Otherwise runs the given action to obtain the node, creates a fresh
--   auxiliary variable for it via 'makeNode' and returns that variable.
--
--   Quantified expressions ('ForAll', 'Exists') are returned as-is: sharing
--   them would move the quantified variable out of its scope.
share :: (Equatable (Expr t), KnownSMTSort t, MonadSMT s m, Sharing s, SharingMonad s m) => Expr t -> m (Expr t) -> m (Expr t)
share expr@(ForAll _ _) _ = return expr
share expr@(Exists _ _) _ = return expr
share origExpr expr = do
  -- 'unsafePerformIO' only serves to read the heap identity of 'origExpr';
  -- the resulting name is used purely as a lookup key.
  let sn = unsafePerformIO (makeStableName' origExpr)
  known <- use (stableMap.at sn)
  case known of
    -- The map stores existentially-sorted expressions, so the sort has to be
    -- checked before a recorded variable can be reused as an 'Expr t'.
    Just (Some1 shared)
      | Just Refl <- geq (sortSing' origExpr) (sortSing' shared) -> return shared
    _ -> expr >>= makeNode sn

-- | Creates a fresh auxiliary variable @nodeVar@ for the given node expression.
--   It then asserts @nodeVar === nodeExpr@ via 'assertSharedNode' and records
--   @nodeVar@ under the given 'StableName' in 'stableMap', so later occurrences
--   of the same node reuse it. Returns @nodeVar@.
makeNode :: (Equatable (Expr t), KnownSMTSort t, MonadSMT s m, Sharing s, SharingMonad s m) => StableName () -> Expr t -> m (Expr t)
makeNode sn nodeExpr = do
  nodeVar <- var
  assertSharedNode sn $ nodeVar === nodeExpr
  stableMap.at sn ?= Some1 nodeVar
  return nodeVar

-- | Like 'makeStableName', but it forces its argument to weak head normal form
--   first and erases the type parameter of the resulting 'StableName'.
--
--   Forcing matters because 'makeStableName' may return different names for a
--   thunk and for the value it evaluates to.
--   Erasing the type (via 'unsafeCoerce' on the phantom parameter) lets names
--   of expressions of different sorts share one 'HashMap' key type. Equality on
--   the erased names still only compares object identity.
makeStableName' :: a -> IO (StableName ())
makeStableName' x = x `seq` fmap unsafeCoerce (makeStableName x)