{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SymRotate
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SymRotate
  ( SymRotate (..),
    DefaultFiniteBitsSymRotate (..),
  )
where

import Data.Bits (Bits (isSigned, rotate), FiniteBits (finiteBitSize))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.Word (Word16, Word32, Word64, Word8)

-- | Rotation in which the rotation amount has the same type as the value being
-- rotated, unlike 'rotate', which always takes an 'Int' amount.
--
-- Every input is accepted, including amounts whose magnitude is at or beyond
-- the bit width of the value.
class (Bits a) => SymRotate a where
  -- | @'symRotate' x n@ rotates @x@ by @n@ bits, in the direction 'rotate'
  -- would use for the same amount.
  symRotate :: a -> a -> a

  -- | @'symRotateNegated' x n@ rotates @x@ in the opposite direction to
  -- @'symRotate' x n@. It is a separate method because the value range is
  -- asymmetric: negating the minimum value of a signed type overflows, so
  -- @'symRotate' x ('negate' n)@ is not a safe substitute.
  symRotateNegated :: a -> a -> a

-- | 'Int' delegates straight to 'rotate'. This relies on 'rotate' for 'Int'
-- accepting any amount: negative amounts, and amounts at or beyond the bit
-- width.
instance SymRotate Int where
  symRotate = rotate
  symRotateNegated a s
    -- For any amount other than minBound, negation is safe.
    | s /= minBound = rotate a (negate s)
    -- Negating minBound overflows back to minBound. Offset by one full bit
    -- width first: rotating by an extra 'finiteBitSize' bits does not change
    -- the result, and the offset amount can be negated safely.
    | otherwise = rotate a (negate (s + finiteBitSize s))

-- | A newtype wrapper for deriving 'SymRotate' (with @DerivingVia@) on types
-- that have 'FiniteBits' instances. The 'SymRotate' instance for this wrapper
-- additionally requires 'Integral'.
--
-- 'Eq' and 'Bits' are inherited unchanged from the wrapped type.
newtype DefaultFiniteBitsSymRotate a = DefaultFiniteBitsSymRotate
  { unDefaultFiniteBitsSymRotate :: a
  }
  deriving newtype (Eq, Bits)

-- | Generic rotation for fixed-width integral types.
--
-- The amount is reduced modulo the bit width before it is passed to 'rotate',
-- so the result lies in @[0, finiteBitSize)@ and does not depend on how the
-- underlying 'rotate' treats out-of-range amounts. The reduction happens in
-- the type @a@ itself rather than in 'Int', so large unsigned amounts are not
-- truncated before the modulo.
--
-- Signed types of width 1 or 2 are special-cased, because the bit width
-- itself is not representable in them: @fromIntegral 2@ wraps to @-2@ in a
-- 2-bit signed type.
instance
  (Integral a, FiniteBits a) =>
  SymRotate (DefaultFiniteBitsSymRotate a)
  where
  symRotate (DefaultFiniteBitsSymRotate a) (DefaultFiniteBitsSymRotate s)
    -- Every rotation of a 1-bit value is the identity.
    | isSigned a && width == 1 = DefaultFiniteBitsSymRotate a
    | isSigned a && width == 2 =
        DefaultFiniteBitsSymRotate $ rotate a (fromIntegral s)
    | otherwise =
        DefaultFiniteBitsSymRotate $
          rotate a (fromIntegral (s `mod` fromIntegral width))
    where
      width = finiteBitSize a
  symRotateNegated (DefaultFiniteBitsSymRotate a) (DefaultFiniteBitsSymRotate s)
    | isSigned a && width == 1 = DefaultFiniteBitsSymRotate a
    | isSigned a && width == 2 =
        DefaultFiniteBitsSymRotate $ rotate a (negate (fromIntegral s))
    | otherwise = DefaultFiniteBitsSymRotate $ rotate a (fromIntegral amount)
    where
      width = finiteBitSize a
      -- The bit width as a value of type @a@. It is only used on the branch
      -- above, where the width is representable in @a@.
      bs = fromIntegral width
      -- Non-negative because the divisor bs is positive.
      smodbs = s `mod` bs
      -- Rotating right by smodbs is the same as rotating left by
      -- (bs - smodbs). A zero remainder maps to 0, not bs, so the amount
      -- passed to 'rotate' stays within [0, width).
      amount
        | smodbs == 0 = 0
        | otherwise = bs - smodbs

-- Fixed-width signed integers ('Int' has its own hand-written instance).
deriving via (DefaultFiniteBitsSymRotate Int8) instance SymRotate Int8
deriving via (DefaultFiniteBitsSymRotate Int16) instance SymRotate Int16
deriving via (DefaultFiniteBitsSymRotate Int32) instance SymRotate Int32
deriving via (DefaultFiniteBitsSymRotate Int64) instance SymRotate Int64

-- Fixed-width unsigned integers.
deriving via (DefaultFiniteBitsSymRotate Word8) instance SymRotate Word8
deriving via (DefaultFiniteBitsSymRotate Word16) instance SymRotate Word16
deriving via (DefaultFiniteBitsSymRotate Word32) instance SymRotate Word32
deriving via (DefaultFiniteBitsSymRotate Word64) instance SymRotate Word64
deriving via (DefaultFiniteBitsSymRotate Word) instance SymRotate Word