{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SafeSymRotate
-- Copyright   :   (c) Sirui Lu 2023-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SafeSymRotate (SafeSymRotate (..)) where

import Control.Exception (ArithException (Overflow))
import Control.Monad.Error.Class (MonadError)
import Data.Bits (Bits (rotateL, rotateR), FiniteBits (finiteBitSize))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeLits (KnownNat, type (<=))
import Grisette.Internal.Core.Control.Monad.Union (MonadUnion)
import Grisette.Internal.Core.Data.Class.Mergeable (Mergeable)
import Grisette.Internal.Core.Data.Class.SOrd (SOrd ((.<)))
import Grisette.Internal.Core.Data.Class.SimpleMergeable (mrgIf)
import Grisette.Internal.Core.Data.Class.TryMerge (TryMerge)
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.Prim.Term
  ( PEvalRotateTerm
      ( pevalRotateLeftTerm,
        pevalRotateRightTerm
      ),
  )
import Grisette.Internal.SymPrim.SymBV
  ( SymIntN (SymIntN),
    SymWordN (SymWordN),
  )
import Grisette.Lib.Control.Monad (mrgReturn)
import Grisette.Lib.Control.Monad.Except (mrgThrowError)

-- | Safe rotation operations. The operators will reject negative rotate
-- amounts by raising an error of type @e@ in the monad @m@.
class (MonadError e m, TryMerge m, Mergeable a) => SafeSymRotate e a m where
  -- | Rotate the first argument to the left by the amount given as the second
  -- argument.
  safeSymRotateL :: a -> a -> m a
  -- | Rotate the first argument to the right by the amount given as the second
  -- argument.
  safeSymRotateR :: a -> a -> m a

-- | Rotate a concrete value to the left by a non-negative amount.
--
-- A negative rotate amount is rejected with 'Overflow'. Otherwise the amount
-- is reduced with 'rem' by the bit width of the type while it is still a value
-- of type @a@, and only then converted to 'Int' for 'rotateL'. This function
-- therefore handles the case when the rotate amount is out of the range of
-- 'Int' correctly.
safeSymRotateLConcreteNum ::
  ( MonadError ArithException m,
    TryMerge m,
    Integral a,
    FiniteBits a,
    Mergeable a
  ) =>
  a ->
  a ->
  m a
safeSymRotateLConcreteNum a s
  | s < 0 = mrgThrowError Overflow
  | otherwise = mrgReturn $ rotateL a amount
  where
    -- Reduce in type @a@ first so that the conversion to 'Int' cannot
    -- truncate a large rotate amount.
    amount = fromIntegral $ s `rem` fromIntegral (finiteBitSize s)

-- | Rotate a concrete value to the right by a non-negative amount.
--
-- A negative rotate amount is rejected with 'Overflow'. Otherwise the amount
-- is reduced with 'rem' by the bit width of the type while it is still a value
-- of type @a@, and only then converted to 'Int' for 'rotateR'. This function
-- therefore handles the case when the rotate amount is out of the range of
-- 'Int' correctly.
safeSymRotateRConcreteNum ::
  ( MonadError ArithException m,
    TryMerge m,
    Integral a,
    FiniteBits a,
    Mergeable a
  ) =>
  a ->
  a ->
  m a
safeSymRotateRConcreteNum a s
  | s < 0 = mrgThrowError Overflow
  | otherwise = mrgReturn $ rotateR a amount
  where
    -- Reduce in type @a@ first so that the conversion to 'Int' cannot
    -- truncate a large rotate amount.
    amount = fromIntegral $ s `rem` fromIntegral (finiteBitSize s)

-- The macro below expands to a SafeSymRotate instance for the concrete type T,
-- with ArithException as the error type. Both methods delegate to the generic
-- concrete helpers safeSymRotateLConcreteNum and safeSymRotateRConcreteNum.
-- It is instantiated once for each fixed-width and machine-width integer type
-- listed after the definition.
#define SAFE_SYM_ROTATE_CONCRETE(T) \
  instance (MonadError ArithException m, TryMerge m) => \
    SafeSymRotate ArithException T m where \
    safeSymRotateL = safeSymRotateLConcreteNum; \
    safeSymRotateR = safeSymRotateRConcreteNum \

#if 1
SAFE_SYM_ROTATE_CONCRETE(Word8)
SAFE_SYM_ROTATE_CONCRETE(Word16)
SAFE_SYM_ROTATE_CONCRETE(Word32)
SAFE_SYM_ROTATE_CONCRETE(Word64)
SAFE_SYM_ROTATE_CONCRETE(Word)
SAFE_SYM_ROTATE_CONCRETE(Int8)
SAFE_SYM_ROTATE_CONCRETE(Int16)
SAFE_SYM_ROTATE_CONCRETE(Int32)
SAFE_SYM_ROTATE_CONCRETE(Int64)
SAFE_SYM_ROTATE_CONCRETE(Int)
#endif

-- | Rotation of concrete 'WordN' bit-vectors. Both directions delegate to the
-- generic concrete helpers, which raise 'Overflow' on a negative rotate
-- amount.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (WordN n) m
  where
  safeSymRotateL = safeSymRotateLConcreteNum
  safeSymRotateR = safeSymRotateRConcreteNum

-- | Rotation of concrete 'IntN' bit-vectors. Both directions delegate to the
-- generic concrete helpers, which raise 'Overflow' on a negative rotate
-- amount.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (IntN n) m
  where
  safeSymRotateL = safeSymRotateLConcreteNum
  safeSymRotateR = safeSymRotateRConcreteNum

-- | Rotation of symbolic 'SymWordN' bit-vectors. The underlying terms are
-- unwrapped, combined with the partial-evaluating rotate term constructors,
-- and wrapped again. No error branch is produced here.
--
-- NOTE(review): the missing negative-amount check is presumably because the
-- rotate amount is an unsigned bit-vector -- confirm.
instance
  (MonadError ArithException m, TryMerge m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (SymWordN n) m
  where
  safeSymRotateL (SymWordN ta) (SymWordN tr) =
    mrgReturn $ SymWordN $ pevalRotateLeftTerm ta tr
  safeSymRotateR (SymWordN ta) (SymWordN tr) =
    mrgReturn $ SymWordN $ pevalRotateRightTerm ta tr

-- | Rotation of symbolic 'SymIntN' bit-vectors. The result is a merged
-- conditional: when the rotate amount is symbolically less than zero the
-- computation throws 'Overflow'; otherwise it returns the rotated value built
-- from the partial-evaluating rotate term constructors.
instance
  (MonadError ArithException m, MonadUnion m, KnownNat n, 1 <= n) =>
  SafeSymRotate ArithException (SymIntN n) m
  where
  safeSymRotateL (SymIntN ta) r@(SymIntN tr) =
    mrgIf
      (r .< 0)
      (mrgThrowError Overflow)
      (mrgReturn $ SymIntN $ pevalRotateLeftTerm ta tr)
  safeSymRotateR (SymIntN ta) r@(SymIntN tr) =
    mrgIf
      (r .< 0)
      (mrgThrowError Overflow)
      (mrgReturn $ SymIntN $ pevalRotateRightTerm ta tr)