module Data.Layout.Vector (
      Codec
    , compile
    , StorableVector (..)
    , encodeVectors
    , decodeVector
    ) where
import           Control.Monad (when)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.ByteString.Internal (ByteString(..))
import qualified Data.Vector.Storable as V
import           Data.Word (Word8, Word32)
import           Foreign.C (CInt (..))
import           Foreign.ForeignPtr ()
import           Foreign.ForeignPtr (ForeignPtr, withForeignPtr, castForeignPtr)
import           Foreign.Marshal.Alloc (alloca)
import           Foreign.Ptr (Ptr, plusPtr, castPtr)
import           Foreign.Storable (Storable, peek, poke, sizeOf)
import           System.IO.Unsafe (unsafePerformIO)
import           Data.Layout.ForeignPtr (mallocPlainForeignPtrBytes)
import qualified Data.Layout.Language as L
import           Data.Layout.Internal (Layout(..), ByteOrder(..))
-- | Wraps a storable vector and hides its element type, so vectors with
-- different element types can be passed together (see 'encodeVectors').
-- Pattern matching on 'SV' brings the 'Storable' dictionary back into scope.
data StorableVector where
    SV :: Storable a => V.Vector a -> StorableVector
-- | Encode a list of vectors into a single 'ByteString', each vector with
-- its own 'Codec'.
--
-- The output size is taken from the FIRST pair: its repetition count times
-- its codec's 'encodedSize'. Every pair is then encoded starting at byte 0
-- of that same buffer. NOTE(review): this presumably relies on the codecs
-- describing disjoint slots of one shared record layout; confirm with
-- callers.
--
-- The C codecs are not told how large the buffer is. So every pair is
-- checked before anything is written: a pair that would need more bytes
-- than the buffer holds is rejected with 'error' instead of overrunning
-- the allocation. An empty list encodes to 'B.empty'.
encodeVectors :: [(Codec, StorableVector)] -> ByteString
encodeVectors [] = B.empty
encodeVectors xs@((codec, SV vec):_) = unsafePerformIO $ do
    -- reject any pair that would overrun the buffer, before allocating it
    mapM_ checkFits xs
    -- allocate the target buffer
    bp <- mallocPlainForeignPtrBytes bstrBytes
    -- encode each vector into the same buffer, starting at offset 0
    mapM_ (go (DstPtr bp 0)) xs
    return (PS bp 0 bstrBytes)
  where
    bstrBytes = encodeReps codec vec * encodedSize codec
    go dp (c, SV v) = encodeVector c dp v

    -- Bytes a pair will write = its repetitions * its codec's encoded size.
    checkFits :: (Codec, StorableVector) -> IO ()
    checkFits (c, SV v) =
        let needed = encodeReps c v * encodedSize c
        in when (needed > bstrBytes) $ error $ concat
             [ "Data.Layout.Vector.encodeVectors: "
             , "a vector needs ", show needed, " bytes of output, but the "
             , "output buffer (sized from the first vector) holds only "
             , show bstrBytes, " bytes." ]
-- | Encode one storable vector into a destination buffer with a codec.
--
-- First checks that the codec's value size equals the element size of the
-- vector. Then runs the codec's encoder for as many repetitions as the
-- vector's bytes represent, reading straight from the vector's memory.
encodeVector :: forall a. Storable a => Codec -> DstPtr -> V.Vector a -> IO ()
encodeVector codec dstPtr vec =
    checkValueSize "encodeVector" codec vec >> encode codec reps dstPtr src
  where
    reps = encodeReps codec vec
    -- View the elements as raw bytes. The offset is 0 because
    -- 'V.unsafeToForeignPtr0' already points at the first element.
    src = SrcPtr (castForeignPtr (fst (V.unsafeToForeignPtr0 vec))) 0
-- | Number of codec repetitions held by a vector. The vector's byte size
-- must be an exact multiple of the codec's 'decodedSize'.
encodeReps :: Storable a => Codec -> V.Vector a -> Int
encodeReps codec vec =
    let vecBytes    = vectorSize vec
        bytesPerRep = decodedSize codec
    in repetitions "Vector" vecBytes bytesPerRep
-- | Decode a 'ByteString' into a freshly allocated storable vector.
--
-- The input length must be an exact multiple of the codec's 'encodedSize',
-- and the codec's value size must equal the size of @a@. The result holds
-- (repetitions * 'valueCount') elements, filled in by the codec's decoder.
decodeVector :: forall a. Storable a => Codec -> ByteString -> V.Vector a
decodeVector codec bstr@(PS srcFP srcOff _) = unsafePerformIO $ do
    checkValueSize "decodeVector" codec (V.empty :: V.Vector a)
    dstFP <- mallocPlainForeignPtrBytes (reps * decodedSize codec)
    decode codec reps (DstPtr dstFP 0) (SrcPtr srcFP srcOff)
    pure (V.unsafeFromForeignPtr0 (castForeignPtr dstFP) (reps * valueCount codec))
  where
    -- whole records contained in the input
    reps = repetitions "ByteString" (B.length bstr) (encodedSize codec)
-- | Total size in bytes of the elements of a storable vector.
vectorSize :: forall a. Storable a => V.Vector a -> Int
vectorSize v = elemBytes * V.length v
  where
    elemBytes = sizeOf (undefined :: a)
-- | Fail (via 'error') when the codec's per-value size differs from the
-- element size of the vector type. Only the vector's type is consulted; its
-- contents are ignored. @fn@ names the caller in the error message.
checkValueSize :: forall a m. (Storable a, Monad m)
               => String -> Codec -> V.Vector a -> m ()
checkValueSize fn codec _ = when mismatch (error errorMsg)
  where
    mismatch   = codecBytes /= elemBytes
    codecBytes = valueSize codec
    elemBytes  = sizeOf (undefined :: a)
    errorMsg =
        "Data.Layout.Vector." ++ fn ++ ": " ++
        "Value size mismatch. The value size of a codec (" ++
        show codecBytes ++ " bytes) did not match the size of " ++
        "individual elements (" ++ show elemBytes ++ " bytes) in " ++
        "the corresponding vector. This means that the wrong type " ++
        "of vector is being used for a given codec."
-- | Number of whole codec records contained in a buffer.
--
-- @repetitions name sourceBytes codecBytes@ is @sourceBytes `quot` codecBytes@.
-- It calls 'error' in two cases:
--
-- * @codecBytes@ is not positive. This replaces a bare divide-by-zero.
-- * @sourceBytes@ is not an exact multiple of @codecBytes@. Trailing bytes
--   are never silently ignored.
--
-- @name@ only labels the source buffer in the error message. This function
-- serves both encoding and decoding, so the message names this function
-- rather than either caller.
repetitions :: String -> Int -> Int -> Int
repetitions sourceName sourceBytes codecBytes
    | codecBytes <= 0 = error sizeMsg
    | leftover /= 0   = error leftoverMsg
    | otherwise       = n
  where
    (n, leftover) = sourceBytes `quotRem` codecBytes
    prefix = "Data.Layout.Vector.repetitions: "
    sizeMsg = concat
      [ prefix
      , "The Codec reports a size of ", show codecBytes
      , " bytes per repetition; it must be positive." ]
    leftoverMsg = concat
      [ prefix
      , "The source ", sourceName, " is not a multiple of "
      , show codecBytes, " bytes, as required by the Codec. "
      , show sourceBytes, " bytes were provided, which leaves "
      , show leftover, " bytes unused." ]
-- | Destination buffer for a codec run, plus the byte offset into it at
-- which writing starts ('runCodec' advances the raw pointer by it).
data DstPtr = DstPtr  !(ForeignPtr Word8)
                      !Int
-- | Source buffer for a codec run, plus the byte offset into it at which
-- reading starts ('runCodec' advances the raw pointer by it).
data SrcPtr = SrcPtr  !(ForeignPtr Word8)
                      !Int
-- | A compiled layout, built by 'compile'. It holds the encoder and decoder
-- actions plus the sizes needed to convert between storable vectors and
-- encoded 'ByteString's.
data Codec = Codec
    { encode      :: Int -> DstPtr -> SrcPtr -> IO ()
      -- ^ Run the encoder for the given number of repetitions: vector
      -- memory (source) into encoded bytes (destination).
    , decode      :: Int -> DstPtr -> SrcPtr -> IO ()
      -- ^ Run the decoder for the given number of repetitions: encoded
      -- bytes (source) into vector memory (destination).
    , encodedSize :: Int
      -- ^ Bytes per repetition in the encoded 'ByteString'.
    , decodedSize :: Int
      -- ^ Bytes per repetition in the decoded vector.
    , valueCount  :: Int
      -- ^ Vector elements produced per repetition when decoding.
    , valueSize   :: Int
      -- ^ Size in bytes of one value; must match the vector's element size.
    }
-- | Compile a 'Layout' into a 'Codec'. The layout is flattened into a
-- 'CopyInfo' once, and that value is shared by the encoder and the decoder.
compile :: Layout -> Codec
compile layout =
    let info = buildCopyInfo layout
    in Codec { encode      = runCodec info c_encode
             , decode      = runCodec info c_decode
             , encodedSize = L.size layout
             , decodedSize = L.valueSizeN layout
             , valueCount  = L.valueCount layout
             , valueSize   = L.valueSize1 layout
             }
-- | Run a C codec function over @reps@ repetitions.
--
-- Both buffers and the offsets vector are pinned for the duration of the
-- call, and the buffer pointers are advanced by their offsets. The C status
-- code is then checked: 0 means success, 1 means an invalid value size, and
-- anything else is an unknown error.
runCodec :: CopyInfo -> CodecFn -> Int -> DstPtr -> SrcPtr -> IO ()
runCodec info codecFn reps (DstPtr dstFP dstOff) (SrcPtr srcFP srcOff) =
    withForeignPtr dstFP $ \dstBase ->
    withForeignPtr srcFP $ \srcBase ->
    V.unsafeWith (ciOffsets info) $ \offsetsPtr ->
        codecFn (dstBase `plusPtr` dstOff)
                (srcBase `plusPtr` srcOff)
                (fromIntegral reps)
                (ciNumOffsets info)
                offsetsPtr
                (ciNumValues info)
                (ciValueSize info)
                (ciSwapBytes info)
            >>= checkStatus
  where
    checkStatus :: CInt -> IO ()
    checkStatus 0 = return ()
    checkStatus 1 = error ("runCodec: invalid value size: " ++ show (ciValueSize info))
    checkStatus _ = error "runCodec: unknown error"
-- | Flattened form of a layout, in the shape the C codec functions take.
-- 'runCodec' passes each field straight through to the C call.
data CopyInfo = CopyInfo
  { ciOffsets    :: V.Vector CInt
    -- ^ skip amounts around the copy runs, as produced by 'buildCopyInfo'
  , ciNumValues  :: CInt
    -- ^ copy-run size divided by 'ciValueSize' (see 'buildCopyInfo')
  , ciValueSize  :: CInt
    -- ^ size in bytes of a single value
  , ciSwapBytes  :: CInt
    -- ^ 1 when 'needsByteSwap' holds for the layout, else 0
  } deriving (Show)
-- | Length of 'ciOffsets', converted to the C integer type.
ciNumOffsets :: CopyInfo -> CInt
ciNumOffsets info = fromIntegral (V.length (ciOffsets info))
-- | One step of a flattened layout, used while building a 'CopyInfo'.
-- The splitting step in 'buildCopyInfo' treats non-negative values as skips
-- (bytes to jump over) and negative values as copies.
type SkipCopyOp = CInt
-- | Flatten a 'Layout' into the 'CopyInfo' consumed by the C codecs.
--
-- It works in three steps:
--
-- 1. Turn the layout into a vector of 'SkipCopyOp's. Skips are non-negative
--    byte counts; copies are NEGATIVE byte counts, so both kinds can live in
--    one vector.
-- 2. Merge adjacent ops of the same kind. Zero-byte skips merge into the
--    surrounding copies.
-- 3. Split the result into the skip offsets and the single copy size that
--    every copy run must share.
--
-- Calls 'error' when copy runs differ in size, or when the layout contains
-- nothing to copy.
buildCopyInfo :: Layout -> CopyInfo
buildCopyInfo layout =
    CopyInfo { ciOffsets, ciNumValues, ciValueSize, ciSwapBytes }
  where
    -- copySize is a positive byte count (see splitOps)
    ciNumValues = copySize `quot` ciValueSize
    ciValueSize = fromIntegral (L.valueSize1 layout)
    ciSwapBytes = if needsByteSwap layout then 1 else 0
    (copySize, ciOffsets) = (splitOps . optimize . toSkipCopyOps) layout

    toSkipCopyOps :: Layout -> V.Vector SkipCopyOp
    toSkipCopyOps = go
      where
        go v@(Value _)   = V.singleton (copyOp (L.valueSize1 v))
        go (Offset n xs) = skipOp n `V.cons` go xs
        go (Repeat n xs) = V.concat (replicate n (go xs))
        -- pad the group out to its declared size
        go (Group  n xs) = go xs `V.snoc` skipOp (n - L.size xs)

        skipOp = fromIntegral

        -- copies are negative so they are distinguishable from skips
        copyOp n = fromIntegral (negate n)

    -- Alternately sum runs of skips (> 0) and runs of copies (<= 0), and
    -- drop runs that sum to zero.
    optimize :: V.Vector SkipCopyOp -> V.Vector SkipCopyOp
    optimize = skips
      where
        skips  = sumWhile (> 0) copies
        copies = sumWhile (<= 0) skips
        sumWhile p k xs
            | V.null xs = V.empty
            | otherwise = let (ys, zs) = V.span p xs
                          in case V.sum ys of
                              0 -> k zs
                              s -> s `V.cons` k zs

    -- Ensure the ops start and end with a skip (inserting zero skips if
    -- needed). Return the copy size as a positive byte count, together with
    -- the skips.
    splitOps :: V.Vector SkipCopyOp -> (CInt, V.Vector CInt)
    splitOps = split . head0 . last0
      where
        head0 xs | not (V.null xs), V.head xs < 0 = 0 `V.cons` xs
                 | otherwise                      = xs
        last0 xs | not (V.null xs), V.last xs < 0 = xs `V.snoc` 0
                 | otherwise                      = xs
        split :: V.Vector SkipCopyOp -> (CInt, V.Vector CInt)
        split xs = case V.find (< 0) xs of
            Nothing   -> error "buildCopyInfo: layout contains no values to copy"
            Just copy -> (negate copy, V.filter (isSkip copy) xs)
        isSkip :: SkipCopyOp -> SkipCopyOp -> Bool
        isSkip copy x
            | x == copy = False -- copy
            | x >= 0    = True  -- skip
            | otherwise = error $
                "buildCopyInfo: invalid copy operation " ++
                "(expected <" ++ show (negate copy) ++ " bytes>," ++
                " actual <" ++ show (negate x) ++ " bytes>)"
-- | Whether values must be byte-swapped for this layout. That is the case
-- only when the layout declares an explicit byte order that differs from
-- the host's.
needsByteSwap :: Layout -> Bool
needsByteSwap layout = swapFor (L.byteOrder layout)
  where
    swapFor BigEndian    = hostIsLittleEndian
    swapFor LittleEndian = hostIsBigEndian
    swapFor NoByteOrder  = False
-- | The lowest-addressed byte of the 32-bit word @0x01020304@ as the host
-- stores it: 4 on a little-endian host, 1 on a big-endian host.
--
-- The 'unsafePerformIO' is benign. The action only writes to and reads from
-- a scoped temporary obtained with 'alloca', so it always yields the same
-- value.
endianCheck :: Word8
endianCheck = unsafePerformIO (alloca probe)
  where
    probe :: Ptr Word32 -> IO Word8
    probe wordPtr = do
        poke wordPtr 0x01020304
        peek (castPtr wordPtr)
-- | True when the host stores the least-significant byte of a word first.
hostIsLittleEndian :: Bool
hostIsLittleEndian = 4 == endianCheck
-- | True when the host stores the most-significant byte of a word first.
hostIsBigEndian :: Bool
hostIsBigEndian = 1 == endianCheck
-- | Signature shared by the C encoder and decoder ('c_encode', 'c_decode').
-- The argument meanings below follow the call made in 'runCodec'.
type CodecFn =
       Ptr Word8 -- ^ destination, already advanced by the 'DstPtr' offset
    -> Ptr Word8 -- ^ source, already advanced by the 'SrcPtr' offset
    -> CInt      -- ^ number of repetitions
    -> CInt      -- ^ number of entries in the offsets array ('ciNumOffsets')
    -> Ptr CInt  -- ^ offsets array ('ciOffsets')
    -> CInt      -- ^ 'ciNumValues'
    -> CInt      -- ^ bytes per value ('ciValueSize')
    -> CInt      -- ^ byte-swap flag ('ciSwapBytes')
    -> IO CInt   -- ^ status: 'runCodec' treats 0 as success, 1 as an invalid value size, anything else as unknown
-- | C encoder (@data_layout_encode@), imported as an @unsafe@ call.
-- 'compile' wraps it with 'runCodec' to build a codec's 'encode'.
foreign import ccall unsafe "data_layout_encode"
    c_encode :: CodecFn
-- | C decoder (@data_layout_decode@), imported as an @unsafe@ call.
-- 'compile' wraps it with 'runCodec' to build a codec's 'decode'.
foreign import ccall unsafe "data_layout_decode"
    c_decode :: CodecFn