{-# LANGUAGE Trustworthy #-}

module Language.Hasmtlib.Internal.Sharing
  ( Sharing(..)
  , SharingMode(..)
  , runSharing, share
  )
where

import Language.Hasmtlib.Internal.Uniplate1
import Language.Hasmtlib.Type.MonadSMT
import Language.Hasmtlib.Type.SMTSort
import Language.Hasmtlib.Type.Expr
import Data.GADT.Compare
import Data.HashMap.Lazy
import Data.Default
import Data.Kind
import Control.Monad.State
import Control.Lens
import System.Mem.StableName
import System.IO.Unsafe
import Unsafe.Coerce

-- | Selects how 'runSharing' treats common sub-expressions.
data SharingMode
  = None          -- ^ Nothing is shared: the expression is left exactly as it is
  | StableNames   -- ^ Sub-expressions with equal 'StableName's are replaced by one shared auxiliary variable
  deriving Show

-- | The default sharing mode is 'None', i.e. no sub-expression is shared.
instance Default SharingMode where
  def = None

-- | States that can share expressions by comparing their 'StableName's.
class Sharing s where
  -- | A constraint on the monad used when asserting the shared node in 'assertSharedNode'.
  type SharingMonad s :: (Type -> Type) -> Constraint

  -- | A 'Lens'' on the mapping from the 'StableName' of an already shared expression
  --   to the auxiliary variable that represents it.
  stableMap :: Lens' s (HashMap (StableName ()) (SomeKnownSMTSort Expr))

  -- | Asserts that a node-expression is represented by its auxiliary node-variable,
  --   i.e. receives the constraint @nodeVar === nodeExpr@.
  --   Also gives access to the 'StableName' of the original expression.
  assertSharedNode :: (MonadState s m, SharingMonad s m) => StableName () -> Expr BoolSort -> m ()

  -- | Sets the mode used for sharing common expressions.
  --   Note that @def :: SharingMode@ is 'None'; which mode a fresh state starts with
  --   is decided by that state's own initialisation.
  setSharingMode :: MonadState s m => SharingMode -> m ()

-- | Shares common sub-expressions of the given expression.
--
--   * With 'None' the expression is returned unchanged.
--
--   * With 'StableNames' the expression tree is traversed with 'lazyParaM1'.
--     Leaves are kept as they are. Every other node is handed to 'share', which may
--     replace it with an auxiliary variable. Nodes @x@ and @y@ where
--     @makeStableName x == makeStableName y@ end up as the same auxiliary variable,
--     so the result describes a DAG rather than a tree.
runSharing :: (KnownSMTSort t, MonadSMT s m, Sharing s, SharingMonad s m) => SharingMode -> Expr t -> m (Expr t)
runSharing mode = case mode of
  None        -> return
  StableNames -> lazyParaM1 (\origExpr expr ->
    if isLeaf origExpr
      then return origExpr
      else
        -- Matching on the sort singleton brings 'Equatable (Expr t)' into scope
        -- for each concrete sort, which 'share' requires.
        case sortSing' origExpr of
          SBoolSort      -> share origExpr expr
          SIntSort       -> share origExpr expr
          SRealSort      -> share origExpr expr
          SBvSort _ _    -> share origExpr expr
          SArraySort _ _ -> share origExpr expr
          SStringSort    -> share origExpr expr)

-- | Returns an auxiliary variable standing for the given expression node.
--
--   The first argument is the original node, whose 'StableName' is the lookup key.
--   The second is the action supplied by 'lazyParaM1' (presumably rebuilding the node
--   from its already shared children).
--
--   * Quantified nodes ('ForAll', 'Exists') are returned unchanged and the action is not run.
--
--   * If 'stableMap' already holds a variable of the same sort for this node's 'StableName',
--     that variable is returned and the action is not run.
--
--   * Otherwise the action is run and its result is bound to a fresh variable by 'makeNode'.
share :: (Equatable (Expr t), KnownSMTSort t, MonadSMT s m, Sharing s, SharingMonad s m) => Expr t -> m (Expr t) -> m (Expr t)
share expr@(ForAll _ _) _ = return expr     -- sharing quantified expression would out-scope quantified var
share expr@(Exists _ _) _ = return expr
share origExpr expr = do
  cached <- use (stableMap . at sn)
  case cached of
    Just (SomeSMTSort expr')
      | Just Refl <- geq (sortSing' origExpr) (sortSing' expr') -> return expr'
    -- No entry, or an entry of a different sort: build a new node.
    _ -> expr >>= makeNode sn
  where
    -- NOTE(review): stable-name creation is treated as pure via 'unsafePerformIO';
    -- sharing quality therefore depends on GHC keeping 'origExpr' as one heap object.
    sn = unsafePerformIO (makeStableName' origExpr)
sn

-- | Introduces a fresh auxiliary variable for @nodeExpr@ and returns it.
--   Before returning, it:
--
--   * passes @nodeVar === nodeExpr@ to 'assertSharedNode' together with @sn@, and
--
--   * records the variable in 'stableMap' under @sn@, so later occurrences with the
--     same 'StableName' can reuse it.
makeNode :: (Equatable (Expr t), KnownSMTSort t, MonadSMT s m, Sharing s, SharingMonad s m) => StableName () -> Expr t -> m (Expr t)
makeNode sn nodeExpr = do
  nodeVar <- var
  let definition = nodeVar === nodeExpr
  assertSharedNode sn definition
  stableMap . at sn ?= SomeSMTSort nodeVar
  return nodeVar

-- | Creates a 'StableName' for a value and coerces its type index to @()@, so that
--   names of differently-typed expressions can be keys of the same map.
--
--   The value is forced to weak head normal form first. Per the "System.Mem.StableName"
--   documentation, an unevaluated thunk and its evaluated result may receive different
--   stable names.
makeStableName' :: a -> IO (StableName ())
makeStableName' x = x `seq` (unsafeCoerce <$> makeStableName x)