module Data.ByteString.Lazy.Num
    ( numCompare
    , NumBS(..)
    ) where

import qualified Data.ByteString.Lazy as L
import Data.Bits
import Data.Binary (encode)
import Data.List (foldl', mapAccumL)
import Data.Word

-- | An unsigned integer stored as a little-endian lazy 'L.ByteString'
-- (byte 0 is the least significant).  Fixed-width operations built on
-- 'byteStrOp' keep the width of the longer operand.
--
-- The derived 'Eq' and 'Ord' compare the raw bytes, not the numeric value:
-- @NBS (L.pack [1])@ and @NBS (L.pack [1,0])@ are numerically equal but not
-- '==', and @NBS (L.pack [1]) > NBS (L.pack [0,1])@ even though 1 < 256.
-- Use 'numCompare' for a numeric ordering.
newtype NumBS = NBS {unNBS :: L.ByteString}
  deriving (Eq, Ord, Show)

-- | Arithmetic on little-endian byte strings.  '+', '-' and 'negate' are
-- fixed-width: they keep the width of the longer operand and drop the final
-- carry.  '*' goes through 'Integer' and yields a minimal-width result.
instance Num NumBS where
    a + b = byteStrOp (+) a b
    (*) = asInteger2 (*)
    a - b = byteStrOp (-) a b
    -- Two's complement within the width of @a@: zero of that width minus @a@.
    negate a = NBS (L.replicate (L.length (unNBS a)) 0) - a
    -- Every value is non-negative.
    abs a = a
    -- 0 for zero (in any width), 1 otherwise.  Both results use the same
    -- encoding as 'fromInteger', so @signum x == 0@ and @signum x == 1@
    -- agree with the derived structural 'Eq'.
    signum a
      | L.all (== 0) (unNBS a) = NBS L.empty
      | otherwise              = NBS (L.pack [1])
    -- Minimal-width little-endian encoding; 0 becomes the empty string.
    -- A negative number is encoded as 'negate' of its magnitude, i.e. the
    -- two's complement in the magnitude's width.
    fromInteger i
      | i < 0     = negate (fromInteger (negate i))
      | otherwise = NBS (L.pack (toBytes i))
      where
        toBytes :: Integer -> [Word8]
        toBytes 0 = []
        toBytes n = fromIntegral (n .&. 0xFF) : toBytes (n `shiftR` 8)

-- | Division is performed on 'Integer' and re-encoded with 'fromInteger', so
-- results have minimal width.  Dividing by zero raises the same exception
-- as it does for 'Integer'.
instance Integral NumBS where
    quot = asInteger2 quot
    rem = asInteger2 rem
    div = asInteger2 div
    mod = asInteger2 mod
    quotRem a b =
        let (x, y) = quotRem (toInteger a) (toInteger b)
        in (fromInteger x, fromInteger y)
    divMod a b =
        let (x, y) = divMod (toInteger a) (toInteger b)
        in (fromInteger x, fromInteger y)
    -- Horner's rule starting from the most significant (last) byte.  Every
    -- step shifts by a constant 8 bits, so there is no absolute bit offset
    -- that would have to fit in an 'Int', and no lazy tuple accumulator
    -- building up thunks.
    toInteger (NBS a) = L.foldr addByte 0 a
      where
        addByte :: Word8 -> Integer -> Integer
        addByte w acc = (acc `shiftL` 8) .|. fromIntegral w

-- | Exact conversion through the value's 'Integer' representation.
-- An explicit 'toInteger' avoids relying on type defaulting for the
-- intermediate type.
instance Real NumBS where
    toRational = toRational . toInteger

-- | Enumeration stays within the (zero-padded) width of its arguments.
-- Sequences stop instead of wrapping past the all-0xFF value or below zero.
instance Enum NumBS where
    succ a = if isMaxBound a then error "succ maxBound" else a + 1
    pred a = if isMinBound a then error "pred minBound" else a - 1
    -- Little-endian like every other NumBS.  'Data.Binary.encode' would give
    -- an 8-byte big-endian string whose little-endian value is not @i@.
    toEnum i = fromIntegral i
    fromEnum a = fromIntegral a
    enumFrom a = if isMaxBound a then [a] else a : enumFrom (succ a)
    -- @[start, next ..]@: the step is the difference between the two
    -- arguments (not @next@ itself), and the sequence descends when
    -- @next < start@.
    enumFromThen start next = normalized go start next
      where
        go s n = case numCompare n s of
            LT -> descend (s - n) s
            _  -> ascend (n - s) s
        -- Stop after the last element whose successor x + d would overflow
        -- the common width (every byte 0xFF).
        ascend d = up
          where
            limit = NBS (L.replicate (L.length (unNBS d)) 0xFF) - d
            up x
              | numCompare x limit == GT = [x]
              | otherwise                = x : up (x + d)
        -- Stop after the last element whose successor x - d would be
        -- negative.
        descend d x
          | numCompare x d == LT = [x]
          | otherwise            = x : descend d (x - d)
    enumFromTo start end = normalized go start end
      where go s e | numCompare s e == GT = []
                   | otherwise =
                       if isMaxBound s then [s] else s : enumFromTo (succ s) e
    -- Bound the sequence from above when ascending, from below when
    -- descending.
    enumFromThenTo s n e = takeWhile inRange (enumFromThen s n)
      where
        descending = numCompare n s == LT
        inRange x
          | descending = numCompare x e /= LT
          | otherwise  = numCompare x e /= GT

-- | True when the value is the largest of its width: every byte is 0xFF, so
-- adding 1 within that width would wrap around.  The empty string (which
-- 'fromInteger' produces for 0) is not a maximum, because adding 1 to it
-- pads it to one byte instead of wrapping.
isMaxBound :: NumBS -> Bool
isMaxBound (NBS bs) = not (L.null bs) && L.all (== 0xFF) bs

-- | True when the value is zero, i.e. no byte is non-zero.  This includes
-- the empty string.
isMinBound :: NumBS -> Bool
isMinBound (NBS bytes) = not (L.any (/= 0) bytes)

-- | Compare two values numerically, whatever their widths.
--
-- Both operands are zero-padded to the same length and compared byte by
-- byte; the most significant (last) differing byte decides.  Values that
-- differ only in high-order zero padding compare 'EQ'.  Use this instead of
-- the derived 'Ord' instance, which compares the raw bytes lexicographically.
numCompare :: NumBS -> NumBS -> Ordering
numCompare a b = normalized byBytes a b
  where
    -- 'Ordering' is a monoid whose '<>' keeps the first non-'EQ' result, so
    -- folding most-significant-first picks the highest differing byte.
    byBytes (NBS x) (NBS y) = mconcat (reverse (L.zipWith compare x y))

-- | Combine two values byte by byte, least significant byte first.  An 'Int'
-- carry is threaded from 'combWords' to the next byte.  The operands are
-- first zero-padded to the same length, and the result has that length.
byteStrOp :: (Int -> Int -> Int) -> NumBS -> NumBS -> NumBS
byteStrOp op (NBS a) (NBS b) = NBS (L.pack bytes)
  where
    (a', b') = normalize a b
    -- The carry out of the top byte is deliberately dropped: the result
    -- keeps the operands' fixed width.
    (_, bytes) = mapAccumL (combWords op) 0 (L.zip a' b')

-- | Apply @op@ to one pair of bytes plus the incoming carry.  Returns the
-- carry for the next byte and the low 8 bits of the result.
--
-- 'divMod' (not 'quotRem') is used so that a negative intermediate, as
-- produced by subtraction, gives a remainder in @[0, 255]@ and a negative
-- carry (a borrow).  'quotRem' would truncate toward zero, returning carry 0
-- and losing the borrow.
combWords :: (Int -> Int -> Int) -> Int -> (Word8,Word8) -> (Int, Word8)
combWords op carry (a,b) = (c, fromIntegral r)
  where
  p :: Int
  p = (fromIntegral a `op` fromIntegral b) + carry
  (c,r) = p `divMod` 256

-- | Run a binary operation on two values after zero-padding their byte
-- strings to a common length with 'normalize'.
normalized :: (NumBS -> NumBS -> a) ->  -- the operation to run
              NumBS -> NumBS -> a       -- operands to pad before running it
normalized op a b = uncurry op (rewrap (normalize (unNBS a) (unNBS b)))
  where
    -- Lazy pattern: do not force the padding before 'op' needs an operand.
    rewrap ~(x, y) = (NBS x, NBS y)

-- | Append zero bytes to the shorter string so that both have the length of
-- the longer one.  In little-endian order this leaves the numeric value
-- unchanged.
normalize :: L.ByteString -> L.ByteString -> (L.ByteString, L.ByteString)
normalize a b = (padTo a, padTo b)
  where
  width = max (L.length a) (L.length b)
  padTo s = s `L.append` L.replicate (width - L.length s) 0

-- | Lift a unary 'Integer' function to 'NumBS'.  The value is converted to
-- 'Integer', the function is applied, and the result is re-encoded with
-- 'fromInteger' (minimal width).
asInteger :: (Integer -> Integer) -> NumBS -> NumBS
asInteger f n = fromInteger (f (toInteger n))

-- | Lift a binary 'Integer' operation to 'NumBS'.  Both arguments are
-- converted with 'toInteger', and the result is re-encoded with
-- 'fromInteger' (minimal width).
asInteger2 :: (Integer -> Integer -> Integer) ->
              NumBS -> NumBS -> NumBS
asInteger2 op a b = fromInteger (toInteger a `op` toInteger b)

-- | Bitwise operations.
--
-- * '.&.', '.|.', 'xor' and 'complement' keep the (padded) byte width.
-- * Shifts and single-bit updates go through 'Integer' and return a
--   minimal-width result.
-- * 'rotate' rotates within the bit width of its argument.
instance Bits NumBS where
    -- 'byteStrOp' zero-pads its operands itself.
    (.&.) = byteStrOp (.&.)
    (.|.) = byteStrOp (.|.)
    xor = byteStrOp xor
    complement (NBS a) = NBS $ L.map complement a
    a `shift` i = asInteger (`shift` i) a
    -- Rotate within the argument's own width (8 bits per byte), keeping that
    -- width.  Delegating to 'Integer' would only shift, because 'Integer'
    -- has no fixed width.
    a `rotate` i
      | width == 0 = a
      | otherwise  = NBS (padTo len (unNBS (fromInteger rotated)))
      where
        len = L.length (unNBS a)
        width = fromIntegral len * 8 :: Int
        r = i `mod` width
        n = toInteger a
        rotated = ((n `shiftL` r) .|. (n `shiftR` (width - r)))
                    .&. (bit width - 1)
        padTo k bs = L.append bs (L.replicate (k - L.length bs) 0)
    bit i =
        let (d,m) = i `quotRem` 8
        in NBS $ L.snoc (L.replicate (fromIntegral d) 0) (bit m)
    setBit a i = asInteger (`setBit` i) a
    clearBit a i = asInteger (`clearBit` i) a
    complementBit a i = asInteger (`complementBit` i) a
    testBit a i =
        let (d, m) = i `quotRem` 8
        in case L.unpack (L.drop (fromIntegral d) (unNBS a)) of
             []    -> False
             (w:_) -> testBit w m
    bitSize = (*) 8 . fromIntegral . L.length . unNBS
    -- Required by 'Bits'.  The width here is a property of each value (as
    -- in 'bitSize' above), not of the type.
    bitSizeMaybe (NBS a) = Just (8 * fromIntegral (L.length a))
    isSigned _ = False
    -- Required by 'Bits': total set bits across all bytes.
    popCount (NBS a) = L.foldl' (\acc w -> acc + popCount w) 0 a
    shiftL a i = shift a i
    shiftR a i = shift a (-i)
    rotateL a i = rotate a i
    rotateR a i = rotate a (-i)