module Feldspar.Core.Functions.Integral where

import qualified Prelude
import Data.Int
import Data.Word

import Feldspar.Prelude
import Feldspar.Range
import Feldspar.Core.Types
import Feldspar.Core.Representation
import Feldspar.Core.Constructs
import Feldspar.Core.Functions.Logic
import Feldspar.Core.Functions.Eq
import Feldspar.Core.Functions.Ord
import Feldspar.Core.Functions.Num
import Feldspar.Core.Functions.Bits

-- | Feldspar counterpart of the standard 'Prelude.Integral' class.
--
-- All methods operate on 'Data' values. An instance only has to supply
-- 'rem'; 'quot', 'div', 'mod' and '(^)' fall back to the generic
-- definitions given here.
class (Numeric a, BoundedInt a, Bits a, Ord a) => Integral a where
  quot, rem, div, mod, (^) :: Data a -> Data a -> Data a

  -- Defaults (no default exists for 'rem').
  quot = defaultQuot
  div  = defaultDiv
  mod  = defaultMod
  (^)  = optExp fullProp

-- TODO Should (^) really be in this class? Here the exponent is forced to
-- have the same type as the base, whereas the standard function has type
--
--     (Num a, Integral b) => a -> b -> a

-- | Generic 'quot'. Lifts 'Prelude.quot' into a Feldspar function node named
-- @"quot"@, using 'fullProp' as the size-propagation function.
defaultQuot :: Integral a => Data a -> Data a -> Data a
defaultQuot x y = function2 "quot" fullProp Prelude.quot x y

-- | 'quot' with range propagation. Builds the same @"quot"@ node as
-- 'defaultQuot', but computes the result size from the operand ranges with
-- 'rangeQuot'.
optQuot :: (Integral a, BoundedInt a, Size a ~ Range a) =>
          Data a -> Data a -> Data a
optQuot = function2 "quot" rangeQuot Prelude.quot

-- | Floor division expressed in terms of 'quot' and 'rem'.
--
-- 'quot' truncates towards zero, while 'div' rounds towards negative
-- infinity. The two disagree only when the division is inexact (non-zero
-- remainder) and the operands have opposite signs. In that case the
-- truncated quotient is one too large and is decremented.
defaultDiv :: Integral a => Data a -> Data a -> Data a
defaultDiv x y = (inexact && oppositeSigns) ? (q - 1, q)
  where
    q             = quot x y
    inexact       = rem x y /= 0
    oppositeSigns = (x > 0 && y < 0) || (x < 0 && y > 0)

-- | Floor modulus expressed in terms of 'rem'.
--
-- As with 'Prelude.rem' and 'Prelude.mod', the remainder follows the sign of
-- the dividend while the modulus follows the sign of the divisor. When the
-- remainder is non-zero and the operands have opposite signs, adding the
-- divisor converts the one into the other.
defaultMod :: Integral a => Data a -> Data a -> Data a
defaultMod x y = (inexact && oppositeSigns) ? (r + y, r)
  where
    r             = rem x y
    inexact       = r /= 0
    oppositeSigns = (x > 0 && y < 0) || (x < 0 && y > 0)

-- | 'rem' with range propagation. Builds a @"rem"@ node whose result size is
-- computed from the operand ranges by 'rangeRem'.
--
-- A tempting short-cut is deliberately NOT applied: returning @x@ directly
-- when @abs rx `rangeLess` abs ry@ (@rx@, @ry@ being the operand ranges).
-- That optimization is invalid if @x == (-128)@ and the type is 'Int8'
-- (presumably because @abs@ of the minimum value overflows).
--
-- TODO Use as default implementation of 'rem', when equality is allowed as
--      super class constraint (i.e. Size a ~ Range a).
optRem :: (Integral a, BoundedInt a, Size a ~ Range a) =>
          Data a -> Data a -> Data a
optRem x y = function2 "rem" rangeRem Prelude.rem x y

-- | 'mod' with range propagation.
--
-- Computes the floor modulus exactly as 'defaultMod' does (reused here
-- rather than duplicated). The result is then restricted with 'cap' to the
-- range that 'rangeMod' derives from the operand ranges.
optMod :: (Integral a, BoundedInt b, Size a ~ Range b) =>
          Data a -> Data a -> Data a
optMod x y = cap (rangeMod (dataSize x) (dataSize y)) (defaultMod x y)

-- | Exponentiation with literal short-cuts.
--
-- * If the base is the literal 1, the result is the literal 1.
-- * If the exponent is the literal 1, the base is returned unchanged.
-- * If the exponent is the literal 0, the result is the literal 1.
--
-- Otherwise a @"(^)"@ node is built, and its size is computed by the
-- supplied propagation function @prop@.
optExp :: Integral a =>
          (Size a -> Size a -> Size a)
       -> Data a -> Data a -> Data a
optExp prop m e
    | Just 1 <- baseLit = value 1
    | Just 1 <- expLit  = m
    | Just 0 <- expLit  = value 1
    | otherwise         = function2 "(^)" prop (Prelude.^) m e
  where
    baseLit = viewLiteral m
    expLit  = viewLiteral e

-- | Exponentiation for signed types.
--
-- A literal base of @-1@ is special-cased: @(-1)^e@ is @1@ for even @e@ and
-- @-1@ for odd @e@. This is computed without branching, using the trick
-- "Conditionally negate a value without branching" from Bit Twiddling Hacks.
-- With @lowBit = e .&. 1@ (0 or 1), @(1 `xor` negate lowBit) + lowBit@ gives
-- @1@ when @lowBit@ is 0, and in two's complement gives @-2 + 1 = -1@ when
-- @lowBit@ is 1. The result is capped to the range [-1, 1].
--
-- Every other base is handled by 'optExp' with 'rangeExp' size propagation.
optSignedExp :: (Integral a, Signed a, BoundedInt b, Size a ~ Range b) =>
                Data a -> Data a -> Data a
optSignedExp m e
    | Just (-1) <- viewLiteral m
    = cap (range (-1) 1) $ (1 `xor` negate lowBit) + lowBit
    | otherwise
    = optExp rangeExp m e
  where
    lowBit = e .&. 1

-- | Unsigned instance. For non-negative operands, truncating and flooring
-- division coincide. Both 'quot' and 'div' therefore use the
-- range-propagating 'optQuot' (previously 'quot' fell back to the
-- 'fullProp'-based default), and 'mod' is the same operation as 'rem'.
instance Integral Word8 where
  quot = optQuot
  div  = optQuot
  rem  = optRem
  mod  = rem

-- | Signed instance. 'quot' and 'div' use the class defaults; 'rem', 'mod'
-- and '(^)' use the range-propagating implementations.
instance Integral Int8 where
  (^)     = optSignedExp
  mod     = optMod
  rem x y = optRem x y

-- | Unsigned instance. For non-negative operands, truncating and flooring
-- division coincide. Both 'quot' and 'div' therefore use the
-- range-propagating 'optQuot' (previously 'quot' fell back to the
-- 'fullProp'-based default), and 'mod' is the same operation as 'rem'.
instance Integral Word16 where
  quot = optQuot
  div  = optQuot
  rem  = optRem
  mod  = rem

-- | Signed instance. 'quot' and 'div' use the class defaults; 'rem', 'mod'
-- and '(^)' use the range-propagating implementations.
instance Integral Int16 where
  (^)     = optSignedExp
  mod     = optMod
  rem x y = optRem x y

-- | Unsigned instance. For non-negative operands, truncating and flooring
-- division coincide. Both 'quot' and 'div' therefore use the
-- range-propagating 'optQuot' (previously 'quot' fell back to the
-- 'fullProp'-based default), and 'mod' is the same operation as 'rem'.
instance Integral Word32 where
  quot = optQuot
  div  = optQuot
  rem  = optRem
  mod  = rem

-- | Signed instance. 'quot' and 'div' use the class defaults; 'rem', 'mod'
-- and '(^)' use the range-propagating implementations.
instance Integral Int32 where
  (^)     = optSignedExp
  mod     = optMod
  rem x y = optRem x y

-- | Unsigned instance. For non-negative operands, truncating and flooring
-- division coincide. Both 'quot' and 'div' therefore use the
-- range-propagating 'optQuot' (previously 'quot' fell back to the
-- 'fullProp'-based default), and 'mod' is the same operation as 'rem'.
-- NOTE(review): assumes 'DefaultWord' is an unsigned type, as its name and
-- the original @mod = rem@ suggest.
instance Integral DefaultWord where
  quot = optQuot
  div  = optQuot
  rem  = optRem
  mod  = rem

-- | Signed instance. 'quot' and 'div' use the class defaults; 'rem', 'mod'
-- and '(^)' use the range-propagating implementations.
instance Integral DefaultInt where
  (^)     = optSignedExp
  mod     = optMod
  rem x y = optRem x y