module CLaSH.Sized.Internal.Index
  ( 
    Index (..)
    
    
  , pack#
  , unpack#
    
  , eq#
  , neq#
    
  , lt#
  , ge#
  , gt#
  , le#
    
  , enumFrom#
  , enumFromThen#
  , enumFromTo#
  , enumFromThenTo#
    
  , maxBound#
    
  , (+#)
  , (-#)
  , (*#)
  , fromInteger#
    
  , plus#
  , minus#
  , times#
    
  , quot#
  , rem#
  , toInteger#
    
  , resize#
  )
where
import Data.Data                  (Data)
import Data.Proxy                 (Proxy (..))
import GHC.TypeLits               (KnownNat, Nat, type (+), type (-), type (*),
                                   natVal)
import Text.Read                  (Read (..), ReadPrec)

import Control.DeepSeq            (NFData (..))
import Data.Default               (Default (..))
import GHC.TypeLits.Extra         (CLog)
import Language.Haskell.TH        (TypeQ, appT, conT, litT, numTyLit, sigE)
import Language.Haskell.TH.Syntax (Lift(..))
import Test.QuickCheck.Arbitrary  (Arbitrary (..), CoArbitrary (..),
                                   arbitraryBoundedIntegral,
                                   coarbitraryIntegral, shrinkIntegral)

import CLaSH.Class.BitPack            (BitPack (..))
import CLaSH.Class.Num                (ExtendingNum (..))
import CLaSH.Class.Resize             (Resize (..))
import CLaSH.Sized.Internal.BitVector (BitVector (BV))
-- | Unsigned integer with the exclusive upper bound @n@. The checked
-- constructors in this module ('fromInteger#' and friends) only produce
-- values in @[0 .. n-1]@, and the 'BitPack' instance stores them in
-- @'CLog' 2 n@ bits.
newtype Index (n :: Nat) = I
  { unsafeToInteger :: Integer
    -- ^ The raw underlying value. Building or reading an 'Index' through
    -- 'I' / 'unsafeToInteger' directly performs no range check.
  }
  deriving (Data)
-- | Fully evaluating an 'Index' means fully evaluating its wrapped 'Integer'.
instance NFData (Index n) where
  rnf = rnf . unsafeToInteger
  
  
  
-- | An 'Index' @n@ packs into @'CLog' 2 n@ bits.
instance KnownNat n => BitPack (Index n) where
  type BitSize (Index n) = CLog 2 n
  pack   = pack#
  unpack = unpack#

-- | Reinterpret an 'Index' as a 'BitVector'; the underlying 'Integer' is
-- passed through unchanged.
pack# :: Index n -> BitVector (CLog 2 n)
pack# = BV . unsafeToInteger

-- | Reinterpret a 'BitVector' as an 'Index' @n@. Goes through
-- 'fromInteger_INLINE', so a bit pattern whose value is not below @n@
-- raises an error.
unpack# :: KnownNat n => BitVector (CLog 2 n) -> Index n
unpack# (BV bits) = fromInteger_INLINE bits
-- | Equality compares the underlying 'Integer' values.
instance Eq (Index n) where
  (==) = eq#
  (/=) = neq#

-- | 'True' when both indices wrap the same 'Integer'.
eq# :: Index n -> Index n -> Bool
eq# (I lhs) (I rhs) = lhs == rhs

-- | 'True' when the indices wrap different 'Integer's.
neq# :: Index n -> Index n -> Bool
neq# (I lhs) (I rhs) = lhs /= rhs
-- | Ordering follows the underlying 'Integer' values.
instance Ord (Index n) where
  (<)  = lt#
  (>=) = ge#
  (>)  = gt#
  (<=) = le#

-- | Strict and non-strict comparisons on the wrapped 'Integer', each phrased
-- via a single 'compare' of the two values.
lt#,ge#,gt#,le# :: Index n -> Index n -> Bool
lt# (I lhs) (I rhs) = compare lhs rhs == LT
ge# (I lhs) (I rhs) = compare lhs rhs /= LT
gt# (I lhs) (I rhs) = compare lhs rhs == GT
le# (I lhs) (I rhs) = compare lhs rhs /= GT
-- | 'Enum' instance for 'Index'. 'succ', 'pred' and 'toEnum' are range
-- checked and raise an error outside @[0 .. n-1]@.
--
-- 'Index' is also 'Bounded', so, following the Haskell Report, 'enumFrom'
-- and 'enumFromThen' stop at 'maxBound' (or at 'minBound' for a descending
-- step). Delegating to the unbounded 'enumFrom#' / 'enumFromThen#' would
-- yield an infinite list containing values outside the valid range.
instance KnownNat n => Enum (Index n) where
  succ           = (+# fromInteger# 1)
  pred           = (-# fromInteger# 1)
  toEnum         = fromInteger# . toInteger
  fromEnum       = fromEnum . toInteger#
  enumFrom x     = enumFromTo# x maxBound#
  enumFromThen x y
    | y >= x     = enumFromThenTo# x y maxBound#
    | otherwise  = enumFromThenTo# x y minBound
  enumFromTo     = enumFromTo#
  enumFromThenTo = enumFromThenTo#
-- | List-enumeration helpers that step over the underlying 'Integer' and
-- rewrap every element with 'I'. They perform no bound checks:
-- 'enumFrom#' and 'enumFromThen#' always produce infinite lists.
-- NOTE(review): the results can contain values outside @[0 .. n-1]@.
enumFrom#       :: Index n -> [Index n]
enumFromThen#   :: Index n -> Index n -> [Index n]
enumFromTo#     :: Index n -> Index n -> [Index n]
enumFromThenTo# :: Index n -> Index n -> Index n -> [Index n]
enumFrom# (I from)                       = I <$> [from ..]
enumFromThen# (I from) (I next)          = I <$> [from, next ..]
enumFromTo# (I from) (I upto)            = I <$> [from .. upto]
enumFromThenTo# (I from) (I next) (I upto) = I <$> [from, next .. upto]
-- | An 'Index' @n@ spans @[0 .. n-1]@.
instance KnownNat n => Bounded (Index n) where
  minBound = fromInteger# 0
  maxBound = maxBound#

-- | The largest 'Index' @n@, namely @n - 1@. The result itself serves as the
-- proxy for 'natVal', which ties the type-level @n@ to the value.
maxBound# :: KnownNat n => Index n
maxBound# = let res = I (natVal res - 1) in res
-- | 'Num' instance for 'Index'. '+', '-', '*' and 'fromInteger' are range
-- checked and raise an error when the result leaves @[0 .. n-1]@.
instance KnownNat n => Num (Index n) where
  (+)         = (+#)
  (-)         = (-#)
  (*)         = (*#)
  -- 'negate' mirrors the value within the range: @negate x = maxBound - x@.
  -- It is not an additive inverse: @x + negate x == maxBound@.
  negate      = (maxBound# -#)
  abs         = id
  signum i    = if i == 0 then 0 else 1
  fromInteger = fromInteger#
-- | Checked addition, subtraction and multiplication. Each computes on the
-- underlying 'Integer' and converts back through 'fromInteger_INLINE', which
-- raises an error if the result is outside @[0 .. n-1]@.
(+#),(-#),(*#) :: KnownNat n => Index n -> Index n -> Index n
(+#) (I a) (I b) = fromInteger_INLINE $ a + b
(-#) (I a) (I b) = fromInteger_INLINE $ a - b
(*#) (I a) (I b) = fromInteger_INLINE $ a * b
-- | Convert an 'Integer' to an 'Index' @n@; raises an error when the value
-- is outside @[0 .. n-1]@.
fromInteger# :: KnownNat n => Integer -> Index n
fromInteger# = fromInteger_INLINE

-- | Range-checked conversion shared by the arithmetic in this module.
-- Returns @I i@ when @0 <= i < n@, and otherwise raises an error that
-- reports the value and the valid range.
fromInteger_INLINE :: forall n . KnownNat n => Integer -> Index n
fromInteger_INLINE i = bound `seq` if i' == i then I i else err
  where
    bound = natVal (Proxy :: Proxy n)
    -- Reducing modulo the bound leaves in-range values unchanged, so any
    -- mismatch (including every negative input) means out of range.
    i'    = i `mod` bound
    err   = error ("CLaSH.Sized.Index: result " ++ show i ++
                   " is out of bounds: [0.." ++ show (bound - 1) ++ "]")
-- | Bound-extending arithmetic. The result bound is chosen so that a sum or
-- product of in-range operands always fits: @m + n - 1@ for addition and
-- subtraction, @(m-1)*(n-1) + 1@ for multiplication. 'minus' still raises an
-- error when the difference is negative.
instance ExtendingNum (Index m) (Index n) where
  type AResult (Index m) (Index n) = Index (m + n - 1)
  plus  = plus#
  minus = minus#
  type MResult (Index m) (Index n) = Index (((m - 1) * (n - 1)) + 1)
  times = times#
-- | Addition / subtraction into the widened bound @m + n - 1@.
--
-- For in-range operands the largest sum, @(m-1) + (n-1)@, is below the new
-- bound, so 'plus#' needs no check. 'minus#' raises an error when the
-- difference is negative.
plus#, minus# :: Index m -> Index n -> Index (m + n - 1)
plus# (I a) (I b) = I (a + b)
minus# (I a) (I b)
  | z < 0     = error ("CLaSH.Sized.Index.minus: result " ++ show z ++
                       " is smaller than 0")
  | otherwise = I z
  where
    z = a - b
-- | Multiplication into the widened bound @(m-1)*(n-1) + 1@. For in-range
-- operands the largest product, @(m-1)*(n-1)@, is below that bound, so no
-- range check is performed.
times# :: Index m -> Index n -> Index (((m - 1) * (n - 1)) + 1)
times# (I a) (I b) = I (a * b)
-- | Exact conversion via the underlying 'Integer'.
instance KnownNat n => Real (Index n) where
  toRational (I i) = toRational i
-- | Integer division on indices. The @div@/@mod@ family reuses 'quot#' and
-- 'rem#'. That is only equivalent for non-negative operands, which the
-- checked constructors ('fromInteger_INLINE') guarantee.
instance KnownNat n => Integral (Index n) where
  quot        = quot#
  rem         = rem#
  div         = quot#
  mod         = rem#
  quotRem x y = (quot# x y, rem# x y)
  divMod      = quotRem
  toInteger   = toInteger#
-- | Division and remainder on the underlying 'Integer'. 'quot#' uses 'div',
-- which agrees with truncating division for non-negative operands. A zero
-- divisor fails just as 'Integer' division does.
quot#,rem# :: Index n -> Index n -> Index n
quot# (I num) (I den) = I (div num den)
rem#  (I num) (I den) = I (rem num den)
-- | The underlying 'Integer' value of an 'Index'.
toInteger# :: Index n -> Integer
toInteger# = unsafeToInteger
-- | Every resize operation on 'Index' is the same range-checked rebound.
instance Resize Index where
  resize     = resize#
  zeroExtend = resize#
  signExtend = resize#
  truncateB  = resize#

-- | Move an 'Index' @n@ to the bound @m@. Goes through
-- 'fromInteger_INLINE', so a value that is not below @m@ raises an error.
resize# :: KnownNat m => Index n -> Index m
resize# = fromInteger_INLINE . unsafeToInteger
-- | Lift an 'Index' into Template Haskell. The splice is a call to
-- 'fromInteger#' annotated with the concrete type @Index n@.
instance KnownNat n => Lift (Index n) where
  lift idx@(I v) = sigE expr ty
    where
      expr = [| fromInteger# v |]
      ty   = decIndex (natVal idx)

-- | The Template Haskell type @Index n@ for a given literal bound @n@.
decIndex :: Integer -> TypeQ
decIndex n = conT ''Index `appT` litT (numTyLit n)
-- | Shows the bare underlying number, without the constructor.
instance Show (Index n) where
  show = show . unsafeToInteger
  
-- | Reads an 'Integer' and converts it with the range-checked 'fromInteger',
-- which raises an error for values outside @[0 .. n-1]@.
--
-- The literal is read as an unbounded 'Integer' rather than a 'Word'.
-- 'Word's 'Read' instance wraps modulo the word size, which would silently
-- corrupt large values for an 'Index' @n@ whose bound exceeds a machine
-- word.
instance KnownNat n => Read (Index n) where
  readPrec = fromInteger <$> (readPrec :: ReadPrec Integer)
-- | The default 'Index' is the lower bound, zero.
instance KnownNat n => Default (Index n) where
  def = minBound
-- | QuickCheck generation over the full range @[0 .. n-1]@.
instance KnownNat n => Arbitrary (Index n) where
  arbitrary = arbitraryBoundedIntegral
  shrink    = shrinkIndex

-- | Shrinking for 'Index'. For bounds below 3 the candidates are spelled out
-- by hand: the only shrink of 1 is 0, and 0 has none. Larger bounds defer to
-- 'shrinkIntegral'.
-- NOTE(review): presumably 'shrinkIntegral' is avoided for small bounds
-- because it builds literals such as 2 that are out of range there and
-- would make 'fromInteger#' raise an error; confirm.
shrinkIndex :: KnownNat n => Index n -> [Index n]
shrinkIndex x
  | natVal x >= 3    = shrinkIntegral x
  | toInteger x == 1 = [0]
  | otherwise        = []

-- | Perturbs a generator by the integral value of the index.
instance KnownNat n => CoArbitrary (Index n) where
  coarbitrary idx = coarbitraryIntegral idx