{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Grisette.Internal.Core.Data.Class.SafeLinearArith
( ArithException (..),
SafeLinearArith (..),
)
where
import Control.Exception (ArithException (DivideByZero, Overflow, Underflow))
import Control.Monad.Except (MonadError (throwError))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeNats (KnownNat, type (<=))
import Grisette.Internal.Core.Control.Monad.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.LogicalOp
( LogicalOp ((.&&), (.||)),
)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SEq (SEq ((./=), (.==)))
import Grisette.Internal.Core.Data.Class.SOrd (SOrd ((.<), (.>), (.>=)))
import Grisette.Internal.Core.Data.Class.SimpleMergeable
( mrgIf,
)
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.Core.Data.Class.TryMerge
( TryMerge,
mrgSingle,
)
import Grisette.Internal.SymPrim.BV
( IntN,
WordN,
)
import Grisette.Internal.SymPrim.SymBV
( SymIntN,
SymWordN,
)
import Grisette.Internal.SymPrim.SymInteger (SymInteger)
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)
-- | Linear arithmetic (addition, negation, subtraction) that reports
-- out-of-range results through 'MonadError' instead of wrapping silently.
--
-- The instances in this module use 'ArithException' as the error type:
-- bounded types throw 'Overflow' or 'Underflow' when the exact result does
-- not fit, while the unbounded integer types ('Integer', 'SymInteger')
-- never throw.
--
-- The 'Mergeable' and 'TryMerge' superclasses let instances build their
-- results with merge-aware combinators such as 'mrgSingle'.
class
  (MonadError e m, TryMerge m, Mergeable a) =>
  SafeLinearArith e a m
  where
  -- | Checked addition of two values.
  safeAdd :: a -> a -> m a

  -- | Checked negation of a value.
  safeNeg :: a -> m a

  -- | Checked subtraction: the second argument is subtracted from the
  -- first.
  safeSub :: a -> a -> m a
-- | Arbitrary-precision 'Integer' arithmetic cannot overflow, so every
-- operation returns the exact result (merged with 'mrgSingle') and never
-- throws.
instance
  (MonadError ArithException m, TryMerge m) =>
  SafeLinearArith ArithException Integer m
  where
  safeAdd x y = mrgSingle (x + y)
  safeNeg = mrgSingle . negate
  safeSub x y = mrgSingle (x - y)
-- Shared method bodies for the bounded signed concrete instances
-- (Int8, Int16, Int32, Int64, Int and IntN n, see the invocations below).
-- Each result is computed with the ordinary (+) and (-) first, and the
-- signs of the operands and of the result are then inspected to detect
-- wraparound:
--   * safeAdd: two positive operands with a negative sum -> Overflow;
--     two negative operands with a non-negative sum -> Underflow.
--   * safeSub: a non-negative minus a negative value giving a negative
--     result -> Overflow; a negative minus a positive value giving a
--     positive result -> Underflow.
--   * safeNeg: negating minBound -> Overflow.
-- NOTE: this is a CPP macro. Its continuation lines are joined into a
-- single line on expansion, so the methods are separated by explicit
-- semicolons and no Haskell line comment may appear inside it.
#define SAFE_LINARITH_SIGNED_CONCRETE_BODY \
safeAdd l r = let res = l + r in \
if l > 0 && r > 0 && res < 0 \
then mrgThrowError Overflow \
else if l < 0 && r < 0 && res >= 0 \
then mrgThrowError Underflow \
else mrgReturn res;\
safeSub l r = let res = l - r in \
if l >= 0 && r < 0 && res < 0 \
then mrgThrowError Overflow \
else if l < 0 && r > 0 && res > 0 \
then mrgThrowError Underflow \
else mrgReturn res;\
safeNeg v = if v == minBound then mrgThrowError Overflow else mrgReturn $ -v
-- Generates the checked-arithmetic instance for a fixed-width signed type
-- of kind Type (for example Int8). The method bodies are the shared signed
-- bodies defined just above.
#define SAFE_LINARITH_SIGNED_CONCRETE(type) \
instance \
(MonadError ArithException m, TryMerge m) => \
SafeLinearArith ArithException type m \
where \
SAFE_LINARITH_SIGNED_CONCRETE_BODY
-- Like the signed instance generator above, but for a width-indexed signed
-- bit-vector type constructor (used with IntN). The generated instance
-- additionally requires KnownNat n and 1 <= n on the width n.
#define SAFE_LINARITH_SIGNED_BV_CONCRETE(type) \
instance \
(MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) => \
SafeLinearArith ArithException (type n) m \
where \
SAFE_LINARITH_SIGNED_CONCRETE_BODY
-- Shared method bodies for the bounded unsigned concrete instances
-- (Word8, Word16, Word32, Word64, Word and WordN n, see the invocations
-- below):
--   * safeAdd: a wrapped sum that is smaller than either operand means
--     the addition wrapped around -> Overflow.
--   * safeSub: subtracting a larger value from a smaller one -> Underflow.
--   * safeNeg: any non-zero input -> Underflow; zero is returned negated,
--     which is still zero.
-- Like the signed bodies, this CPP macro expands onto a single line, hence
-- the explicit semicolons and the absence of inner line comments.
#define SAFE_LINARITH_UNSIGNED_CONCRETE_BODY \
safeAdd l r = let res = l + r in \
if l > res || r > res \
then mrgThrowError Overflow \
else mrgReturn res;\
safeSub l r = \
if r > l \
then mrgThrowError Underflow \
else mrgReturn $ l - r;\
safeNeg v = if v /= 0 then mrgThrowError Underflow else mrgReturn $ -v
-- Generates the checked-arithmetic instance for a fixed-width unsigned
-- type of kind Type (for example Word8). The method bodies are the shared
-- unsigned bodies defined just above.
#define SAFE_LINARITH_UNSIGNED_CONCRETE(type) \
instance \
(MonadError ArithException m, TryMerge m) => \
SafeLinearArith ArithException type m \
where \
SAFE_LINARITH_UNSIGNED_CONCRETE_BODY
-- Like the unsigned instance generator above, but for a width-indexed
-- unsigned bit-vector type constructor (used with WordN). The generated
-- instance additionally requires KnownNat n and 1 <= n on the width n.
#define SAFE_LINARITH_UNSIGNED_BV_CONCRETE(type) \
instance \
(MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) => \
SafeLinearArith ArithException (type n) m \
where \
SAFE_LINARITH_UNSIGNED_CONCRETE_BODY
-- Checked-arithmetic instances for the bounded concrete types, generated
-- by the CPP macros above. (The former always-true preprocessor guard
-- around this group has been removed; it had no effect.)

-- Signed fixed-width types: wraparound is detected from operand and
-- result signs.
SAFE_LINARITH_SIGNED_CONCRETE(Int8)
SAFE_LINARITH_SIGNED_CONCRETE(Int16)
SAFE_LINARITH_SIGNED_CONCRETE(Int32)
SAFE_LINARITH_SIGNED_CONCRETE(Int64)
SAFE_LINARITH_SIGNED_CONCRETE(Int)
SAFE_LINARITH_SIGNED_BV_CONCRETE(IntN)

-- Unsigned fixed-width types: wraparound is detected by comparing the
-- result against the operands.
SAFE_LINARITH_UNSIGNED_CONCRETE(Word8)
SAFE_LINARITH_UNSIGNED_CONCRETE(Word16)
SAFE_LINARITH_UNSIGNED_CONCRETE(Word32)
SAFE_LINARITH_UNSIGNED_CONCRETE(Word64)
SAFE_LINARITH_UNSIGNED_CONCRETE(Word)
SAFE_LINARITH_UNSIGNED_BV_CONCRETE(WordN)
-- | Symbolic unbounded integers cannot overflow, so each operation returns
-- the plain symbolic arithmetic result (merged with 'mrgSingle') and never
-- throws.
instance
  (MonadError ArithException m, TryMerge m) =>
  SafeLinearArith ArithException SymInteger m
  where
  safeAdd lhs rhs = mrgSingle $ lhs + rhs
  safeNeg = mrgSingle . negate
  safeSub lhs rhs = mrgSingle $ lhs - rhs
-- | Checked linear arithmetic on symbolic signed bit-vectors.
--
-- The wrapped result is computed first; the signs of the operands and of
-- the result are then compared symbolically with 'mrgIf':
--
-- * 'safeAdd': two positive operands with a negative sum throw 'Overflow';
--   two negative operands with a non-negative sum throw 'Underflow'.
-- * 'safeSub': a non-negative minus a negative value giving a negative
--   result throws 'Overflow'; a negative minus a positive value giving a
--   positive result throws 'Underflow'.
-- * 'safeNeg': negating 'minBound' throws 'Overflow'.
--
-- Every success branch is built with 'mrgSingle', matching the merge-aware
-- style used by the other branches and instances in this module.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeLinearArith ArithException (SymIntN n) m
  where
  safeAdd ls rs =
    mrgIf
      (ls .> 0)
      -- Left operand positive: overflow only if the right one is also
      -- positive and the wrapped sum came out negative.
      (mrgIf (rs .> 0 .&& res .< 0) (throwError Overflow) (mrgSingle res))
      -- Left operand non-positive: underflow only if both operands are
      -- negative and the wrapped sum came out non-negative.
      ( mrgIf
          (ls .< 0 .&& rs .< 0 .&& res .>= 0)
          (throwError Underflow)
          (mrgSingle res)
      )
    where
      res = ls + rs
  safeNeg v =
    mrgIf
      -- minBound has no positive counterpart in the same width.
      (v .== con minBound)
      (throwError Overflow)
      (mrgSingle $ -v)
  safeSub ls rs =
    mrgIf
      (ls .>= 0)
      -- Non-negative minus negative: overflow shows up as a negative result.
      (mrgIf (rs .< 0 .&& res .< 0) (throwError Overflow) (mrgSingle res))
      -- Negative minus positive: underflow shows up as a positive result.
      ( mrgIf
          (ls .< 0 .&& rs .> 0 .&& res .> 0)
          (throwError Underflow)
          (mrgSingle res)
      )
    where
      res = ls - rs
-- | Checked linear arithmetic on symbolic unsigned bit-vectors.
--
-- * 'safeAdd' throws 'Overflow' when the wrapped sum is smaller than
--   either operand, which means the addition wrapped around.
-- * 'safeNeg' throws 'Underflow' for any non-zero input; otherwise the
--   input (zero on that path) is returned.
-- * 'safeSub' throws 'Underflow' when the subtrahend exceeds the minuend.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeLinearArith ArithException (SymWordN n) m
  where
  safeAdd x y =
    let total = x + y
        wrapped = x .> total .|| y .> total
     in mrgIf wrapped (throwError Overflow) (mrgSingle total)
  safeNeg x =
    let nonZero = x ./= 0
     in mrgIf nonZero (throwError Underflow) (mrgSingle x)
  safeSub x y =
    let borrow = y .> x
        diff = x - y
     in mrgIf borrow (throwError Underflow) (mrgSingle diff)