{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
  ( 
    isPow2
  , lg
  , lgCeil
  , nextMultiple
  , nextPow2Multiple
  , tryIntSqrt
  , tryRationalSqrt
  , roundAway
  , ctz
  , clz
  , rotateLeft
  , rotateRight
  ) where
import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio
import Data.Parameterized.NatRepr
-- | Returns 'True' if the number is a power of two.
--
-- Uses the classic bit trick: clearing the lowest set bit of @x@ via
-- @x .&. (x - 1)@ leaves zero exactly when @x@ has a single bit set.
-- Zero has no bits set, so it would also pass that test; it is rejected
-- explicitly because it is not a power of two.
--
-- NOTE(review): callers are presumably expected to pass non-negative
-- values; the available constraints do not allow a sign check here.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = x /= 0 && x .&. (x - 1) == 0
-- | Floor of the base-2 logarithm of a positive number.
--
-- Counts how many times the argument can be shifted right by one bit
-- before it becomes zero, minus one (the first shift is done up front).
-- Calls 'error' when the argument is zero or negative.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0
  | i0 <= 0   = error "lg given number that is not positive."
  | otherwise = countShifts 0 (i0 `shiftR` 1)
  where
    -- @acc@ is the number of halvings performed so far.
    countShifts acc rest
      | rest == 0 = acc
      | otherwise = countShifts (acc + 1) (rest `shiftR` 1)
-- | Ceiling of the base-2 logarithm.
--
-- By convention @lgCeil 0 = 0@.  For arguments greater than one the result
-- is one more than the floor logarithm of the predecessor.  Calls 'error'
-- for negative arguments.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil i
  | i == 0 || i == 1 = 0
  | i > 1            = lg (i - 1) + 1
  | otherwise        = error "lgCeil given number that is not positive."
-- | Count trailing zeros: the number of consecutive clear bits of @x@
-- starting from bit 0, examining at most @w@ bits.  The result equals the
-- width when none of the low @w@ bits is set.
ctz :: NatRepr w -> Integer -> Integer
ctz w x = scan 0
  where
    width = toInteger (natValue w)

    -- Walk upward from bit 0 until a set bit or the width is reached.
    scan !idx
      | idx >= width                = idx
      | testBit x (fromInteger idx) = idx
      | otherwise                   = scan (idx + 1)
-- | Count leading zeros: the number of consecutive clear bits of @x@
-- starting from bit @w - 1@ and moving downward, examining at most @w@
-- bits.  The result equals the width when none of the low @w@ bits is set.
clz :: NatRepr w -> Integer -> Integer
clz w x = scan 0
  where
    width  = toInteger (natValue w)
    topBit = widthVal w - 1

    -- @count@ is the number of clear bits seen so far below the top.
    scan !count
      | count >= width                          = count
      | testBit x (topBit - fromInteger count)  = count
      | otherwise                               = scan (count + 1)
-- | Rotate a value to the right within a fixed bit width.
--
-- The rotation amount is reduced with 'mod' rather than 'rem', so the
-- reduced amount is never negative: a negative @n@ rotates in the
-- opposite direction instead of producing a negative shift count.
-- Results for non-negative @n@ are unaffected by this choice.
--
-- NOTE(review): 'toUnsigned' presumably truncates to the low @w@ bits;
-- a width of zero would make the modulus zero -- confirm callers never
-- pass it.
rotateRight ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateRight w x n = xor shiftedDown wrappedUp
  where
    x' = toUnsigned w x
    n' = fromInteger (n `mod` intValue w)
    -- Bits that stay inside the word after moving right.
    shiftedDown = shiftR x' n'
    -- Bits that fall off the bottom re-enter at the top.
    wrappedUp = toUnsigned w (shiftL x' (widthVal w - n'))
-- | Rotate a value to the left within a fixed bit width.
--
-- The rotation amount is reduced with 'mod' rather than 'rem', so the
-- reduced amount is never negative: a negative @n@ rotates in the
-- opposite direction instead of producing a negative shift count.
-- Results for non-negative @n@ are unaffected by this choice.
--
-- NOTE(review): 'toUnsigned' presumably truncates to the low @w@ bits;
-- a width of zero would make the modulus zero -- confirm callers never
-- pass it.
rotateLeft ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateLeft w x n = xor wrappedDown shiftedUp
  where
    x' = toUnsigned w x
    n' = fromInteger (n `mod` intValue w)
    -- Bits that fall off the top re-enter at the bottom.
    wrappedDown = shiftR x' (widthVal w - n')
    -- Bits that stay inside the word after moving left.
    shiftedUp = toUnsigned w (shiftL x' n')
-- | @nextMultiple x y@ computes the smallest multiple @m@ of @x@ such that
-- @m >= y@, for positive @x@.  For example @nextMultiple 4 8 = 8@,
-- @nextMultiple 4 7 = 8@ and @nextMultiple 8 6 = 8@.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = multiplesNeeded * x
  where
    -- Ceiling of @y / x@, obtained by biasing the numerator before the
    -- flooring division.
    multiplesNeeded = (y + x - 1) `div` x
-- | @nextPow2Multiple x n@ returns the smallest multiple of @2^n@ that is
-- not less than @x@.  Calls 'error' when either argument is negative.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x < 0 || n < 0 = error "nextPow2Multiple given negative value."
  | otherwise      = (biased `shiftR` n) `shiftL` n
  where
    -- Adding @2^n - 1@ first makes the shift pair below round up rather
    -- than down when it clears the low @n@ bits.
    biased = x + 2 ^ n - 1
-- | Exact integer square root.
--
-- Returns @Just s@ with a non-negative @s@ such that @s * s == n@ when
-- @n@ is a perfect square, and 'Nothing' otherwise.  Negative inputs have
-- no real square root and yield 'Nothing'; they must be rejected before
-- the search below, which is only meaningful for @n >= 4@.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt n
  | n < 0     = Nothing
  | n <= 1    = Just n      -- 0 and 1 are their own square roots.
  | n <= 3    = Nothing     -- 2 and 3 are not perfect squares.
  | otherwise = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- Integer Newton iteration, starting from the guess @n / 2@.
    go x
      | x2 < n    = Nothing   -- Guess fell below the root: not a square.
      | x2 == n   = Just x'   -- Found the exact root.
      | otherwise = go x'     -- Guess is still too large, so try again.
      where
        -- Next guess is floor ((x + n / x) / 2).
        x' = (x + n `div` x) `div` 2
        x2 = x' * x'
-- | Exact rational square root.
--
-- Succeeds only when both the numerator and the denominator of the
-- argument have exact integer square roots according to 'tryIntSqrt';
-- the result is the ratio of those two roots.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r = do
  numRoot <- tryIntSqrt (numerator r)
  denRoot <- tryIntSqrt (denominator r)
  pure (numRoot % denRoot)
-- | Convert a real number to the nearest integer, with ties rounded away
-- from zero.
roundAway :: (RealFrac a) => a -> Integer
roundAway r = truncate (r + halfTowardSign)
  where
    -- Half a unit carrying the sign of @r@ (zero when @r@ is zero), so
    -- that truncation toward zero lands on the nearest integer.
    halfTowardSign = signum r * 0.5