{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Grisette.Core.Data.Class.SafeSymShift
( SafeSymShift (..),
)
where
import Control.Exception (ArithException (Overflow))
import Control.Monad.Error.Class (MonadError)
import Data.Bits (Bits (shiftL, shiftR), FiniteBits (finiteBitSize))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeLits (KnownNat, type (<=))
import Grisette.Core.Control.Monad.Union (MonadUnion)
import Grisette.Core.Data.BV (IntN, WordN)
import Grisette.Core.Data.Class.LogicalOp
( LogicalOp ((.&&), (.||)),
)
import Grisette.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Core.Data.Class.SOrd
( SOrd ((.<), (.>=)),
)
import Grisette.Core.Data.Class.SimpleMergeable
( UnionLike,
mrgIf,
)
import Grisette.Core.Data.Class.SymShift (SymShift)
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits
( pevalShiftLeftTerm,
pevalShiftRightTerm,
)
import Grisette.IR.SymPrim.Data.SymPrim (SymIntN (SymIntN), SymWordN (SymWordN))
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)
-- | Shift operations that report invalid shift amounts through 'MonadError'
-- rather than silently producing a value.
--
-- The functional dependency @a -> e@ fixes one native error type @e@ for each
-- shifted type @a@.
--
-- Every operation exists in two flavours:
--
-- * the unprimed method throws the native error @e@ unchanged;
-- * the primed method takes a conversion @e -> e'@ that is applied to the
--   error before it is thrown, so that the operation can run in a monad with
--   a different error type.
--
-- Only the primed methods have to be implemented (see the @MINIMAL@ pragma);
-- each unprimed method defaults to its primed counterpart applied to 'id'.
--
-- As implemented by the instances in this module, the non-strict variants
-- only reject negative shift amounts, whereas the strict variants additionally
-- reject shift amounts greater than or equal to the bit width of the value.
class (SymShift a) => SafeSymShift e a | a -> e where
  -- | Left shift of the first argument by the second argument, throwing the
  -- native error type.
  safeSymShiftL :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymShiftL = safeSymShiftL' id

  -- | Right shift of the first argument by the second argument, throwing the
  -- native error type.
  safeSymShiftR :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymShiftR = safeSymShiftR' id

  -- | Left shift with a user-supplied conversion applied to the error.
  safeSymShiftL' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Right shift with a user-supplied conversion applied to the error.
  safeSymShiftR' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Strict left shift, throwing the native error type.
  safeSymStrictShiftL :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymStrictShiftL = safeSymStrictShiftL' id

  -- | Strict right shift, throwing the native error type.
  safeSymStrictShiftR :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymStrictShiftR = safeSymStrictShiftR' id

  -- | Strict left shift with a user-supplied conversion applied to the error.
  safeSymStrictShiftL' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Strict right shift with a user-supplied conversion applied to the error.
  safeSymStrictShiftR' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  {-# MINIMAL
    safeSymShiftL',
    safeSymShiftR',
    safeSymStrictShiftL',
    safeSymStrictShiftR'
    #-}
-- | Checked left shift for concrete, fixed-width integral types.
--
-- Arguments, in order:
--
-- * the error to throw when the shift amount is rejected;
-- * whether a shift amount greater than or equal to the bit width is
--   tolerated (@True@) or treated as an error (@False@);
-- * the value to shift;
-- * the shift amount.
--
-- Behaviour:
--
-- * a negative shift amount always throws the supplied error;
-- * a shift amount greater than or equal to @finiteBitSize@ of the value
--   yields @0@ when tolerated, and throws the supplied error otherwise;
-- * any other shift amount yields @shiftL@ of the value.
safeSymShiftLConcreteNum ::
  (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) =>
  e ->
  Bool ->
  a ->
  a ->
  m a
safeSymShiftLConcreteNum e allowLargeShiftAmount a s
  | s < 0 = mrgThrowError e
  -- Compare as Integer: the shift amount may not be representable as an Int,
  -- and the bit width may not be representable in the (possibly narrow) type.
  | toInteger s >= toInteger (finiteBitSize a) =
      if allowLargeShiftAmount then mrgReturn 0 else mrgThrowError e
  -- Here 0 <= s < bit width, so the conversion to Int is lossless.
  | otherwise = mrgReturn $ shiftL a (fromIntegral s)
-- | Checked right shift for concrete, fixed-width integral types.
--
-- Arguments, in order:
--
-- * the error to throw when the shift amount is rejected;
-- * whether a shift amount greater than or equal to the bit width is
--   tolerated (@True@) or treated as an error (@False@);
-- * the value to shift;
-- * the shift amount.
--
-- Behaviour:
--
-- * a negative shift amount always throws the supplied error;
-- * a shift amount greater than or equal to @finiteBitSize@ of the value
--   throws the supplied error when not tolerated; when tolerated it yields
--   the saturated result of the shift, which is @-1@ for a negative value
--   (every bit becomes a copy of the sign bit) and @0@ otherwise;
-- * any other shift amount yields @shiftR@ of the value.
safeSymShiftRConcreteNum ::
  (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) =>
  e ->
  Bool ->
  a ->
  a ->
  m a
safeSymShiftRConcreteNum e allowLargeShiftAmount a s
  | s < 0 = mrgThrowError e
  -- Compare as Integer: the shift amount may not be representable as an Int,
  -- and the bit width may not be representable in the (possibly narrow) type.
  | toInteger s >= toInteger (finiteBitSize a) =
      if allowLargeShiftAmount
        then mrgReturn saturated
        else mrgThrowError e
  -- Here 0 <= s < bit width, so the conversion to Int is lossless.
  | otherwise = mrgReturn $ shiftR a (fromIntegral s)
  where
    -- @shiftR@ sign-extends on signed types, so shifting every bit out of a
    -- negative value leaves all bits set. For unsigned types @a < 0@ never
    -- holds and the result is always 0.
    saturated = if a < 0 then -1 else 0
-- The macro below stamps out one instance per concrete machine integer type.
-- All four primed methods forward to the concrete helpers, passing the
-- converted Overflow error. The non-strict methods pass True (over-wide shift
-- amounts are tolerated) and the strict methods pass False (they are errors).
#define SAFE_SYM_SHIFT_CONCRETE(ty) \
instance SafeSymShift ArithException ty where \
  safeSymShiftL' toErr a s = safeSymShiftLConcreteNum (toErr Overflow) True a s; \
  safeSymShiftR' toErr a s = safeSymShiftRConcreteNum (toErr Overflow) True a s; \
  safeSymStrictShiftL' toErr a s = safeSymShiftLConcreteNum (toErr Overflow) False a s; \
  safeSymStrictShiftR' toErr a s = safeSymShiftRConcreteNum (toErr Overflow) False a s
#if 1
SAFE_SYM_SHIFT_CONCRETE(Word8)
SAFE_SYM_SHIFT_CONCRETE(Word16)
SAFE_SYM_SHIFT_CONCRETE(Word32)
SAFE_SYM_SHIFT_CONCRETE(Word64)
SAFE_SYM_SHIFT_CONCRETE(Word)
SAFE_SYM_SHIFT_CONCRETE(Int8)
SAFE_SYM_SHIFT_CONCRETE(Int16)
SAFE_SYM_SHIFT_CONCRETE(Int32)
SAFE_SYM_SHIFT_CONCRETE(Int64)
SAFE_SYM_SHIFT_CONCRETE(Int)
#endif
-- | Safe shifts for concrete unsigned bit-vectors of width @n@.
--
-- Every method forwards to the concrete helpers with the converted 'Overflow'
-- error. The non-strict methods tolerate shift amounts that are greater than
-- or equal to the bit width (flag @True@); the strict methods reject them
-- (flag @False@).
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (WordN n) where
  safeSymShiftL' f = safeSymShiftLConcreteNum (f Overflow) True
  safeSymShiftR' f = safeSymShiftRConcreteNum (f Overflow) True
  safeSymStrictShiftL' f = safeSymShiftLConcreteNum (f Overflow) False
  safeSymStrictShiftR' f = safeSymShiftRConcreteNum (f Overflow) False
-- | Safe shifts for concrete signed bit-vectors of width @n@.
--
-- Every method forwards to the concrete helpers with the converted 'Overflow'
-- error. The non-strict methods tolerate shift amounts that are greater than
-- or equal to the bit width (flag @True@); the strict methods reject them
-- (flag @False@). Negative shift amounts are rejected by the helpers in both
-- cases.
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (IntN n) where
  safeSymShiftL' f = safeSymShiftLConcreteNum (f Overflow) True
  safeSymShiftR' f = safeSymShiftRConcreteNum (f Overflow) True
  safeSymStrictShiftL' f = safeSymShiftLConcreteNum (f Overflow) False
  safeSymStrictShiftR' f = safeSymShiftRConcreteNum (f Overflow) False
-- | Safe shifts for symbolic unsigned bit-vectors of width @n@.
--
-- The non-strict methods never throw: the error conversion is ignored and the
-- shift is delegated directly to the partial-evaluation term builders.
--
-- The strict methods branch symbolically with 'mrgIf': when the shift amount
-- is greater than or equal to the bit width they throw the converted
-- 'Overflow' error, otherwise they build the shift term.
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (SymWordN n) where
  safeSymShiftL' _ (SymWordN a) (SymWordN s) =
    return $ SymWordN $ pevalShiftLeftTerm a s
  safeSymShiftR' _ (SymWordN a) (SymWordN s) =
    return $ SymWordN $ pevalShiftRightTerm a s
  safeSymStrictShiftL' f a@(SymWordN ta) s@(SymWordN ts) =
    mrgIf
      -- The bit width is lifted into the symbolic type for the comparison.
      (s .>= fromIntegral (finiteBitSize a))
      (mrgThrowError $ f Overflow)
      (return $ SymWordN $ pevalShiftLeftTerm ta ts)
  safeSymStrictShiftR' f a@(SymWordN ta) s@(SymWordN ts) =
    mrgIf
      -- The bit width is lifted into the symbolic type for the comparison.
      (s .>= fromIntegral (finiteBitSize a))
      (mrgThrowError $ f Overflow)
      (return $ SymWordN $ pevalShiftRightTerm ta ts)
-- | Safe shifts for symbolic signed bit-vectors of width @n@.
--
-- All methods branch symbolically with 'mrgIf' and throw the converted
-- 'Overflow' error on the failing branch.
--
-- * The non-strict methods fail only when the shift amount is negative.
-- * The strict methods fail when the shift amount is negative, or when it is
--   greater than or equal to the bit width.
instance (KnownNat n, 1 <= n) => SafeSymShift ArithException (SymIntN n) where
  safeSymShiftL' f (SymIntN a) ss@(SymIntN s) =
    mrgIf
      (ss .< 0)
      (mrgThrowError $ f Overflow)
      (return $ SymIntN $ pevalShiftLeftTerm a s)
  safeSymShiftR' f (SymIntN a) ss@(SymIntN s) =
    mrgIf
      (ss .< 0)
      (mrgThrowError $ f Overflow)
      (return $ SymIntN $ pevalShiftRightTerm a s)
  safeSymStrictShiftL' f a@(SymIntN ta) s@(SymIntN ts) =
    mrgIf
      (s .< 0 .|| (bs .>= 0 .&& s .>= bs))
      (mrgThrowError $ f Overflow)
      (return $ SymIntN $ pevalShiftLeftTerm ta ts)
    where
      -- Bit width lifted into the symbolic signed type. The upper-bound test
      -- is only applied when this value is non-negative.
      -- NOTE(review): presumably this guards against very narrow widths where
      -- the bit width itself is not representable in the signed type and
      -- wraps to a negative number -- confirm against the Num instance.
      bs = fromIntegral (finiteBitSize a)
  safeSymStrictShiftR' f a@(SymIntN ta) s@(SymIntN ts) =
    mrgIf
      (s .< 0 .|| (bs .>= 0 .&& s .>= bs))
      (mrgThrowError $ f Overflow)
      (return $ SymIntN $ pevalShiftRightTerm ta ts)
    where
      -- Bit width lifted into the symbolic signed type. The upper-bound test
      -- is only applied when this value is non-negative.
      -- NOTE(review): presumably this guards against very narrow widths where
      -- the bit width itself is not representable in the signed type and
      -- wraps to a negative number -- confirm against the Num instance.
      bs = fromIntegral (finiteBitSize a)