{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Grisette.Internal.Core.Data.Class.SafeSymRotate (SafeSymRotate (..)) where
import Control.Exception (ArithException (Overflow))
import Control.Monad.Error.Class (MonadError)
import Data.Bits (Bits (rotateL, rotateR), FiniteBits (finiteBitSize))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeLits (KnownNat, type (<=))
import Grisette.Internal.Core.Control.Monad.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SOrd (SOrd ((.<)))
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge)
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.Prim.Term
( PEvalRotateTerm
( pevalRotateLeftTerm,
pevalRotateRightTerm
),
)
import Grisette.Internal.SymPrim.SymBV
( SymIntN (SymIntN),
SymWordN (SymWordN),
)
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)
-- | Rotation operations that report failure through the 'MonadError'
-- context @m@ (with error type @e@) instead of returning a bare value.
--
-- For both methods the first argument is the value being rotated and the
-- second argument is the rotation amount; the result has the same type as the
-- operands and is produced inside @m@.
class (MonadError e m, TryMerge m, Mergeable a) => SafeSymRotate e a m where
  -- | Rotate the first argument to the left by the second argument.
  safeSymRotateL :: a -> a -> m a

  -- | Rotate the first argument to the right by the second argument.
  safeSymRotateR :: a -> a -> m a
-- | Left rotation shared by the concrete fixed-width integral instances.
--
-- The first argument is the value to rotate, the second is the rotation
-- amount. A negative amount is rejected with 'Overflow'; any other amount is
-- reduced modulo the bit width before being passed to 'rotateL'.
safeSymRotateLConcreteNum ::
  ( MonadError ArithException m,
    TryMerge m,
    Integral a,
    FiniteBits a,
    Mergeable a
  ) =>
  a ->
  a ->
  m a
safeSymRotateLConcreteNum _ s | s < 0 = mrgThrowError Overflow
safeSymRotateLConcreteNum a s =
  -- The remainder is taken in the operand type first, so the value converted
  -- to Int is already smaller than the bit width.
  mrgReturn $ rotateL a (fromIntegral $ s `rem` fromIntegral (finiteBitSize s))
-- | Right rotation shared by the concrete fixed-width integral instances.
--
-- The first argument is the value to rotate, the second is the rotation
-- amount. A negative amount is rejected with 'Overflow'; any other amount is
-- reduced modulo the bit width before being passed to 'rotateR'.
safeSymRotateRConcreteNum ::
  ( MonadError ArithException m,
    TryMerge m,
    Integral a,
    FiniteBits a,
    Mergeable a
  ) =>
  a ->
  a ->
  m a
safeSymRotateRConcreteNum _ s | s < 0 = mrgThrowError Overflow
safeSymRotateRConcreteNum a s =
  -- The remainder is taken in the operand type first, so the value converted
  -- to Int is already smaller than the bit width.
  mrgReturn $ rotateR a (fromIntegral $ s `rem` fromIntegral (finiteBitSize s))
-- Instance template for the fixed-width integer types: both methods delegate
-- to the concrete helpers defined in this module. The macro body expands onto
-- a single line, hence the explicit semicolon between the two bindings.
-- The final line of the definition must not end in a backslash, otherwise the
-- preprocessor directive that follows is absorbed into the macro body.
#define SAFE_SYM_ROTATE_CONCRETE(T) \
  instance (MonadError ArithException m, TryMerge m) => \
    SafeSymRotate ArithException T m where \
    safeSymRotateL = safeSymRotateLConcreteNum; \
    safeSymRotateR = safeSymRotateRConcreteNum

-- NOTE(review): the always-true conditional has no visible effect on what is
-- compiled; presumably it exists for tooling reasons. Confirm before removing.
#if 1
SAFE_SYM_ROTATE_CONCRETE(Word8)
SAFE_SYM_ROTATE_CONCRETE(Word16)
SAFE_SYM_ROTATE_CONCRETE(Word32)
SAFE_SYM_ROTATE_CONCRETE(Word64)
SAFE_SYM_ROTATE_CONCRETE(Word)
SAFE_SYM_ROTATE_CONCRETE(Int8)
SAFE_SYM_ROTATE_CONCRETE(Int16)
SAFE_SYM_ROTATE_CONCRETE(Int32)
SAFE_SYM_ROTATE_CONCRETE(Int64)
SAFE_SYM_ROTATE_CONCRETE(Int)
#endif
-- | Rotation of concrete 'WordN' bit-vectors. Both methods delegate to the
-- shared concrete helpers, which reduce the rotation amount modulo the bit
-- width and throw 'Overflow' for an amount below zero.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (WordN n) m
  where
  safeSymRotateL = safeSymRotateLConcreteNum
  safeSymRotateR = safeSymRotateRConcreteNum
-- | Rotation of concrete 'IntN' bit-vectors. Both methods delegate to the
-- shared concrete helpers, which reduce the rotation amount modulo the bit
-- width and throw 'Overflow' for an amount below zero.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (IntN n) m
  where
  safeSymRotateL = safeSymRotateLConcreteNum
  safeSymRotateR = safeSymRotateRConcreteNum
-- | Rotation of 'SymWordN' values. The underlying terms are unwrapped,
-- combined with the partial-evaluation rotate constructors, and rewrapped.
-- This instance has no error branch: the rotated value is always returned.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (SymWordN n) m
  where
  safeSymRotateL (SymWordN ta) (SymWordN tr) =
    mrgReturn $ SymWordN $ pevalRotateLeftTerm ta tr
  safeSymRotateR (SymWordN ta) (SymWordN tr) =
    mrgReturn $ SymWordN $ pevalRotateRightTerm ta tr
-- | Rotation of 'SymIntN' values. The sign of the rotation amount is tested
-- symbolically with '.<', so the result is an 'mrgIf' over two branches:
-- throwing 'Overflow' when the amount is below zero, and otherwise returning
-- the rotated term. This is why the instance requires 'MonadUnion' rather
-- than only 'TryMerge'.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (SymIntN n) m
  where
  safeSymRotateL (SymIntN ta) r@(SymIntN tr) =
    mrgIf
      (r .< 0)
      (mrgThrowError Overflow)
      (mrgReturn $ SymIntN $ pevalRotateLeftTerm ta tr)
  safeSymRotateR (SymIntN ta) r@(SymIntN tr) =
    mrgIf
      (r .< 0)
      (mrgThrowError Overflow)
      (mrgReturn $ SymIntN $ pevalRotateRightTerm ta tr)