module Dahdit.Iface
  ( BinaryTarget (..)
  , getTarget
  , putTarget
  , MutBinaryTarget (..)
  , mutPutTargetOffset
  , mutPutTarget
  , decode
  , decodeFile
  , encode
  , encodeFile
  , mutEncode
  )
where

import Control.Monad.Primitive (PrimBase, PrimMonad (..))
import Control.Monad.ST (runST)
import Dahdit.Binary (Binary (..))
import Dahdit.Free (Get, Put)
import Dahdit.Mem (allocBAMem, allocPtrMem, freezeBAMem, freezeBSMem, freezeSBSMem, freezeVecMem, mutAllocBAMem, mutAllocVecMem, mutFreezeBAMem, mutFreezeVecMem, viewBSMem, viewSBSMem, viewVecMem)
import Dahdit.Run (GetError, runCount, runGetInternal, runPutInternal)
import Dahdit.Sizes (ByteCount (..), ByteSized (..))
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.ByteString.Short (ShortByteString)
import qualified Data.ByteString.Short as BSS
import Data.Coerce (coerce)
import Data.Primitive.ByteArray (ByteArray, MutableByteArray, sizeofByteArray)
import Data.Vector.Storable (Vector)
import qualified Data.Vector.Storable as VS
import Data.Vector.Storable.Mutable (MVector)
import Data.Word (Word8)

-- | Abstracts over the sources we can read from / sinks we can render to.
--
-- A single type @z@ serves both directions: 'getTargetOffset' decodes out of
-- it and 'putTargetUnsafe' renders into a fresh value of it.
class BinaryTarget z where
  -- | Put an action to the sink with the given length.
  -- Prefer 'putTarget' to safely count capacity, or use 'encode' to use byte size.
  putTargetUnsafe
    :: Put
    -- ^ Action that writes the bytes.
    -> ByteCount
    -- ^ Capacity in bytes handed to the action; supplied by the caller, not counted.
    -> z

  -- | Get a value from the source given a starting offset, returning a result and final offset.
  -- On error, the offset will indicate where in the source the error occurred.
  getTargetOffset
    :: ByteCount
    -- ^ Offset at which reading starts.
    -> Get a
    -- ^ Decoder to run.
    -> z
    -- ^ Source to read from.
    -> (Either GetError a, ByteCount)

-- | Get a value from the source, returning a result and final offset.
--
-- This is 'getTargetOffset' starting at offset @0@.
getTarget :: BinaryTarget z => Get a -> z -> (Either GetError a, ByteCount)
getTarget = getTargetOffset 0

-- | Put an action to the sink with calculated capacity.
--
-- The capacity handed to 'putTargetUnsafe' is @'runCount' p@, i.e. it is
-- derived from the action itself rather than supplied by the caller.
putTarget :: BinaryTarget z => Put -> z
putTarget p = putTargetUnsafe p (runCount p)

-- | Abstracts over mutable sinks of type @z@ that can be rendered into from
-- within the monad @m@.
class MutBinaryTarget m z where
  -- | Run a 'Put' into the mutable sink at the given offset, with a length the
  -- caller supplies. Compare 'mutPutTargetOffset', which computes the length
  -- with 'runCount' instead.
  --
  -- 'mutEncode' documents the result as the number of bytes filled.
  -- NOTE(review): confirm whether it is a count or an end offset when the
  -- starting offset is non-zero.
  mutPutTargetOffsetUnsafe
    :: ByteCount
    -- ^ Offset into the sink at which writing starts.
    -> Put
    -- ^ Action that writes the bytes.
    -> ByteCount
    -- ^ Length in bytes given to the action.
    -> z
    -- ^ Mutable sink to write into.
    -> m ByteCount

-- | Put an action into a mutable sink starting at the given offset.
--
-- The length handed to 'mutPutTargetOffsetUnsafe' is @'runCount' p@, so the
-- caller does not have to supply it.
mutPutTargetOffset :: MutBinaryTarget m z => ByteCount -> Put -> z -> m ByteCount
mutPutTargetOffset off p = mutPutTargetOffsetUnsafe off p (runCount p)

-- | Put an action into a mutable sink starting at offset @0@, with the length
-- computed by 'runCount' (see 'mutPutTargetOffset').
mutPutTarget :: MutBinaryTarget m z => Put -> z -> m ByteCount
mutPutTarget = mutPutTargetOffset 0

-- | Reads via 'runGetSBS' and renders via 'runPutSBS'.
instance BinaryTarget ShortByteString where
  getTargetOffset = runGetSBS
  putTargetUnsafe = runPutSBS

-- | Reads via 'runGetBS' and renders via 'runPutBS'.
instance BinaryTarget ByteString where
  getTargetOffset = runGetBS
  putTargetUnsafe = runPutBS

-- | Reads via 'runGetBA' and renders via 'runPutBA'.
instance BinaryTarget ByteArray where
  getTargetOffset = runGetBA
  putTargetUnsafe = runPutBA

-- | Reads via 'runGetVec' and renders via 'runPutVec'.
instance BinaryTarget (Vector Word8) where
  getTargetOffset = runGetVec
  putTargetUnsafe = runPutVec

-- | Writes into a 'MutableByteArray' via 'runMutPutBA'.
--
-- The state token is related to the monad through the @s ~ 'PrimState' m@
-- constraint because a type family application like @'PrimState' m@ cannot
-- appear directly in an instance head.
instance (PrimBase m, s ~ PrimState m) => MutBinaryTarget m (MutableByteArray s) where
  mutPutTargetOffsetUnsafe = runMutPutBA

-- | Writes into a storable mutable @'MVector' s 'Word8'@ via 'runMutPutVec'.
--
-- As with the 'MutableByteArray' instance, @s ~ 'PrimState' m@ is used because
-- a type family application cannot appear directly in an instance head.
instance (PrimBase m, s ~ PrimState m) => MutBinaryTarget m (MVector s Word8) where
  mutPutTargetOffsetUnsafe = runMutPutVec

-- | Decode a value from a source returning a result and consumed byte count.
--
-- Runs the 'Binary' instance's 'get' via 'getTarget', i.e. from offset @0@,
-- so the returned offset equals the number of bytes consumed.
decode :: (Binary a, BinaryTarget z) => z -> (Either GetError a, ByteCount)
decode = getTarget get

-- | Decode a value from a file.
--
-- The file is read whole by 'runGetFile' and decoded with the 'Binary'
-- instance's 'get' from offset @0@.
decodeFile :: Binary a => FilePath -> IO (Either GetError a, ByteCount)
decodeFile = runGetFile get

-- | Encode a value to a sink.
--
-- The capacity comes from 'byteSize' rather than from counting the 'Put'
-- with 'runCount' (contrast 'putTarget').
encode :: (Binary a, ByteSized a, BinaryTarget z) => a -> z
encode a = putTargetUnsafe (put a) (byteSize a)

-- | Encode a value to a file.
--
-- The value is rendered by 'runPutFile' with capacity taken from 'byteSize'.
encodeFile :: (Binary a, ByteSized a) => a -> FilePath -> IO ()
encodeFile a = runPutFile (put a) (byteSize a)

-- | Encode a value to a mutable buffer, returning number of bytes filled.
--
-- Writing starts at offset @0@ and the length passed to
-- 'mutPutTargetOffsetUnsafe' is 'byteSize' of the value.
mutEncode :: (Binary a, ByteSized a, MutBinaryTarget m z) => a -> z -> m ByteCount
mutEncode a = mutPutTargetOffsetUnsafe 0 (put a) (byteSize a)

-- | Run a decoder over a 'ByteArray' from the given offset.
--
-- The array's full size ('sizeofByteArray') is passed to 'runGetInternal' as
-- the length bound.
runGetBA :: ByteCount -> Get a -> ByteArray -> (Either GetError a, ByteCount)
runGetBA off act ba = runGetInternal off act (coerce (sizeofByteArray ba)) ba

-- | Run a decoder over a 'ShortByteString' from the given offset.
--
-- The string is viewed as memory with 'viewSBSMem', and its length
-- ('BSS.length') is passed to 'runGetInternal' as the length bound.
runGetSBS :: ByteCount -> Get a -> ShortByteString -> (Either GetError a, ByteCount)
runGetSBS off act sbs = runGetInternal off act (coerce (BSS.length sbs)) (viewSBSMem sbs)

-- | Run a decoder over a strict 'ByteString' from the given offset.
--
-- The string is viewed as memory with 'viewBSMem', and its length
-- ('BS.length') is passed to 'runGetInternal' as the length bound.
runGetBS :: ByteCount -> Get a -> ByteString -> (Either GetError a, ByteCount)
runGetBS off act bs = runGetInternal off act (coerce (BS.length bs)) (viewBSMem bs)

-- | Run a decoder over a storable @'Vector' 'Word8'@ from the given offset.
--
-- The vector is viewed as memory with 'viewVecMem', and its length
-- ('VS.length') is passed to 'runGetInternal' as the length bound.
runGetVec :: ByteCount -> Get a -> Vector Word8 -> (Either GetError a, ByteCount)
runGetVec off act vec = runGetInternal off act (coerce (VS.length vec)) (viewVecMem vec)

-- | Read a whole file strictly with 'BS.readFile' and decode it from offset
-- @0@ via 'runGetBS'.
runGetFile :: Get a -> FilePath -> IO (Either GetError a, ByteCount)
runGetFile act fp = runGetBS 0 act <$> BS.readFile fp

-- | Render a 'Put' of the given length into a fresh 'ByteArray'.
--
-- Runs purely in 'ST' from offset @0@, allocating with 'allocBAMem' and
-- freezing the result with 'freezeBAMem'.
runPutBA :: Put -> ByteCount -> ByteArray
runPutBA act len = runST (runPutInternal 0 act len allocBAMem freezeBAMem)

-- | Render a 'Put' of the given length into a fresh 'ShortByteString'.
--
-- Runs purely in 'ST' from offset @0@, allocating with 'allocBAMem' and
-- freezing the result with 'freezeSBSMem'.
runPutSBS :: Put -> ByteCount -> ShortByteString
runPutSBS act len = runST (runPutInternal 0 act len allocBAMem freezeSBSMem)

-- | Render a 'Put' of the given length into a fresh strict 'ByteString'.
--
-- Runs in 'ST' from offset @0@, allocating with 'allocPtrMem' and freezing
-- the result with 'freezeBSMem'.
runPutBS :: Put -> ByteCount -> ByteString
runPutBS act len = runST (runPutInternal 0 act len allocPtrMem freezeBSMem)

-- | Render a 'Put' of the given length into a fresh storable @'Vector' 'Word8'@.
--
-- Runs in 'ST' from offset @0@, allocating with 'allocPtrMem' and freezing
-- the result with 'freezeVecMem'.
runPutVec :: Put -> ByteCount -> Vector Word8
runPutVec act len = runST (runPutInternal 0 act len allocPtrMem freezeVecMem)

-- | Render a 'Put' with the given capacity into a strict 'ByteString' (via
-- 'runPutBS') and write it to the given file with 'BS.writeFile'.
runPutFile :: Put -> ByteCount -> FilePath -> IO ()
runPutFile act cap fp = BS.writeFile fp (runPutBS act cap)

-- | Render a 'Put' into an existing 'MutableByteArray' in the caller's monad,
-- starting at the given offset with the given length.
--
-- The array is wrapped with 'mutAllocBAMem', and 'mutFreezeBAMem' produces
-- the returned 'ByteCount'.
-- NOTE(review): whether @off + len@ is checked against the array's size is
-- decided by 'mutAllocBAMem' / 'runPutInternal'; confirm there.
runMutPutBA :: PrimBase m => ByteCount -> Put -> ByteCount -> MutableByteArray (PrimState m) -> m ByteCount
runMutPutBA off act len marr = runPutInternal off act len (mutAllocBAMem marr) mutFreezeBAMem

-- | Render a 'Put' into an existing storable mutable vector in the caller's
-- monad, starting at the given offset with the given length.
--
-- The vector is wrapped with 'mutAllocVecMem', and 'mutFreezeVecMem'
-- produces the returned 'ByteCount'.
-- NOTE(review): whether @off + len@ is checked against the vector's length is
-- decided by 'mutAllocVecMem' / 'runPutInternal'; confirm there.
runMutPutVec :: PrimBase m => ByteCount -> Put -> ByteCount -> MVector (PrimState m) Word8 -> m ByteCount
runMutPutVec off act len mvec = runPutInternal off act len (mutAllocVecMem mvec) mutFreezeVecMem