{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Grisette.Core.Data.Class.SafeSymRotate (SafeSymRotate (..)) where

import Control.Exception (ArithException (Overflow))
import Control.Monad.Error.Class (MonadError)
import Data.Bits (Bits (rotateL, rotateR), FiniteBits (finiteBitSize))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeLits (KnownNat, type (<=))
import Grisette.Core.Control.Monad.Union (MonadUnion)
import Grisette.Core.Data.BV (IntN, WordN)
import Grisette.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Core.Data.Class.SOrd (SOrd ((.<)))
import Grisette.Core.Data.Class.SimpleMergeable (UnionLike, mrgIf)
import Grisette.Core.Data.Class.SymRotate (SymRotate)
import Grisette.IR.SymPrim.Data.Prim.PartialEval.Bits
  ( pevalRotateLeftTerm,
    pevalRotateRightTerm,
  )
import Grisette.IR.SymPrim.Data.SymPrim
  ( SymIntN (SymIntN),
    SymWordN (SymWordN),
  )
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)

-- | Rotation operations that can fail, reporting the failure through
-- 'MonadError' inside a 'UnionLike' context.
--
-- The functional dependency @a -> e@ means the value type fixes the error
-- type @e@. The primed methods take an extra function that converts @e@ into
-- the error type @e'@ of the surrounding monad. The unprimed methods are the
-- special case where no conversion is needed (they pass 'id').
class (SymRotate a) => SafeSymRotate e a | a -> e where
  -- | Rotate the first argument left by the second argument. The first
  -- parameter converts the class error type @e@ into the monad error type
  -- @e'@ before it is thrown.
  safeSymRotateL' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Rotate the first argument right by the second argument, converting the
  -- class error type with the supplied function before throwing.
  safeSymRotateR' ::
    (MonadError e' m, UnionLike m) => (e -> e') -> a -> a -> m a

  -- | Left rotation that throws the class error type unchanged.
  safeSymRotateL :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymRotateL value amount = safeSymRotateL' id value amount

  -- | Right rotation that throws the class error type unchanged.
  safeSymRotateR :: (MonadError e m, UnionLike m) => a -> a -> m a
  safeSymRotateR value amount = safeSymRotateR' id value amount

  {-# MINIMAL safeSymRotateL', safeSymRotateR' #-}

-- | Rotate a concrete fixed-width value to the left by a concrete amount.
-- The given error is thrown if the amount is negative.
--
-- The amount is reduced modulo the bit width while it still has type @a@.
-- Only then is it converted to 'Int'. Converting first could wrap amounts
-- that are out of the range of 'Int' (for example a large 'Word64'), so the
-- reduction has to happen before the conversion.
safeSymRotateLConcreteNum ::
  (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) =>
  e ->
  a ->
  a ->
  m a
safeSymRotateLConcreteNum err value amount
  | amount < 0 = mrgThrowError err
  | otherwise = mrgReturn (rotateL value reducedAmount)
  where
    -- The bit width is expressed in type @a@ so that `rem` is computed
    -- before any conversion to 'Int'.
    bitWidth = fromIntegral (finiteBitSize amount)
    reducedAmount = fromIntegral (amount `rem` bitWidth)

-- | Rotate a concrete fixed-width value to the right by a concrete amount.
-- The given error is thrown if the amount is negative.
--
-- The amount is reduced modulo the bit width while it still has type @a@.
-- Only then is it converted to 'Int'. Converting first could wrap amounts
-- that are out of the range of 'Int' (for example a large 'Word64'), so the
-- reduction has to happen before the conversion.
safeSymRotateRConcreteNum ::
  (MonadError e m, MonadUnion m, Integral a, FiniteBits a, Mergeable a) =>
  e ->
  a ->
  a ->
  m a
safeSymRotateRConcreteNum err value amount
  | amount < 0 = mrgThrowError err
  | otherwise = mrgReturn (rotateR value reducedAmount)
  where
    -- The bit width is expressed in type @a@ so that `rem` is computed
    -- before any conversion to 'Int'.
    bitWidth = fromIntegral (finiteBitSize amount)
    reducedAmount = fromIntegral (amount `rem` bitWidth)

-- SAFE_SYM_ROTATE_CONCRETE(T) expands to a SafeSymRotate ArithException
-- instance for the concrete type T. Both rotations delegate to the concrete
-- helpers above, and a negative rotation amount is reported as the
-- caller-converted Overflow error.
-- Comments must stay outside the macro body, because its lines are joined
-- by backslash continuations.
#define SAFE_SYM_ROTATE_CONCRETE(T) \
  instance SafeSymRotate ArithException T where \
    safeSymRotateL' f = safeSymRotateLConcreteNum (f Overflow); \
    safeSymRotateR' f = safeSymRotateRConcreteNum (f Overflow) \

-- Instances for the base fixed-width word and int types.
-- NOTE(review): the condition below is always true; it looks like a leftover
-- from a version guard. Confirm whether it can be removed.
#if 1
SAFE_SYM_ROTATE_CONCRETE(Word8)
SAFE_SYM_ROTATE_CONCRETE(Word16)
SAFE_SYM_ROTATE_CONCRETE(Word32)
SAFE_SYM_ROTATE_CONCRETE(Word64)
SAFE_SYM_ROTATE_CONCRETE(Word)
SAFE_SYM_ROTATE_CONCRETE(Int8)
SAFE_SYM_ROTATE_CONCRETE(Int16)
SAFE_SYM_ROTATE_CONCRETE(Int32)
SAFE_SYM_ROTATE_CONCRETE(Int64)
SAFE_SYM_ROTATE_CONCRETE(Int)
#endif

-- | Concrete 'WordN' bit-vectors delegate to the shared concrete helpers.
-- A negative rotation amount would be reported as 'Overflow'.
instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (WordN n) where
  -- The error conversion is applied to 'Overflow' up front. For an unsigned
  -- type the negative-amount branch is presumably never taken.
  safeSymRotateL' toError value amount =
    safeSymRotateLConcreteNum (toError Overflow) value amount
  safeSymRotateR' toError value amount =
    safeSymRotateRConcreteNum (toError Overflow) value amount

-- | Concrete 'IntN' bit-vectors delegate to the shared concrete helpers.
-- A negative rotation amount is reported as the converted 'Overflow' error.
instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (IntN n) where
  -- The error conversion is applied to 'Overflow' before the helper decides
  -- whether it needs to throw.
  safeSymRotateL' toError value amount =
    safeSymRotateLConcreteNum (toError Overflow) value amount
  safeSymRotateR' toError value amount =
    safeSymRotateRConcreteNum (toError Overflow) value amount

-- | Symbolic 'SymWordN' values never fail to rotate here. The error
-- conversion is ignored and no negativity check is made, since the amount is
-- presumably unsigned. The rotated term is always returned.
instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (SymWordN n) where
  safeSymRotateL' _ (SymWordN valueTerm) (SymWordN amountTerm) =
    mrgReturn (SymWordN (pevalRotateLeftTerm valueTerm amountTerm))
  safeSymRotateR' _ (SymWordN valueTerm) (SymWordN amountTerm) =
    mrgReturn (SymWordN (pevalRotateRightTerm valueTerm amountTerm))

-- | For symbolic 'SymIntN' values the sign of the amount is only known
-- symbolically. The result is therefore a merged branch ('mrgIf'): the
-- negative path throws the converted 'Overflow' error, and the other path
-- returns the rotated term.
instance (KnownNat n, 1 <= n) => SafeSymRotate ArithException (SymIntN n) where
  safeSymRotateL' toError (SymIntN valueTerm) amount@(SymIntN amountTerm) =
    mrgIf (amount .< 0) negativeAmount rotated
    where
      negativeAmount = mrgThrowError (toError Overflow)
      rotated = mrgReturn (SymIntN (pevalRotateLeftTerm valueTerm amountTerm))
  safeSymRotateR' toError (SymIntN valueTerm) amount@(SymIntN amountTerm) =
    mrgIf (amount .< 0) negativeAmount rotated
    where
      negativeAmount = mrgThrowError (toError Overflow)
      rotated = mrgReturn (SymIntN (pevalRotateRightTerm valueTerm amountTerm))