{- |
Module: Capnp.IO
Description: Utilities for reading and writing values to handles.

This module provides utilities for reading and writing values to and
from file 'Handle's and network 'Socket's.
-}
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
module Capnp.IO
    ( sGetMsg
    , sPutMsg
    , M.hGetMsg
    , M.getMsg
    , M.hPutMsg
    , M.putMsg

    , hGetParsed
    , sGetParsed
    , getParsed
    , hPutParsed
    , sPutParsed
    , putParsed
    , hGetRaw
    , getRaw
    , sGetRaw
    ) where

import Data.Bits

import Control.Exception         (throwIO)
import Control.Monad.Trans.Class (lift)
import Network.Simple.TCP        (Socket, recv, sendLazy)
import System.IO                 (Handle, stdin, stdout)
import System.IO.Error           (eofErrorType, mkIOError)

import qualified Data.ByteString         as BS
import qualified Data.ByteString.Builder as BB

import Capnp.Bits           (WordCount, wordsToBytes)
import Capnp.Convert
    (msgToLBS, msgToParsed, msgToRaw, parsedToBuilder, parsedToLBS)
import Capnp.Message        (Mutability(..))
import Capnp.New.Classes    (Parse)
import Capnp.TraversalLimit (evalLimitT)

import qualified Capnp.Message as M
import qualified Capnp.Repr    as R

-- | Like 'M.hGetMsg', except that it reads from a 'Socket' instead of a
-- 'Handle'.
--
-- The framing header and every segment are read in full, so the call
-- blocks until a whole message has arrived. The supplied 'WordCount' is
-- passed to 'evalLimitT' around 'M.readMessage'.
--
-- Throws an 'IOError' of type 'eofErrorType' ("Remote socket closed") if
-- the remote end closes the connection before the message is complete.
sGetMsg :: Socket -> WordCount -> IO (M.Message 'Const)
sGetMsg socket limit =
    evalLimitT limit $ M.readMessage (lift read32) (lift . readSegment)
  where
    -- Decode a little-endian 32-bit word (byte 0 is least significant).
    read32 = do
        bytes <- recvFull 4
        -- 'recvFull' returns exactly 4 bytes here, so every index is in
        -- range.
        pure $
            (fromIntegral (bytes `BS.index` 0) `shiftL`  0) .|.
            (fromIntegral (bytes `BS.index` 1) `shiftL`  8) .|.
            (fromIntegral (bytes `BS.index` 2) `shiftL` 16) .|.
            (fromIntegral (bytes `BS.index` 3) `shiftL` 24)

    -- Read one segment whose size (in words) came from the frame header.
    readSegment !nWords =
        M.fromByteString <$> recvFull (fromIntegral $ wordsToBytes nWords)

    -- Like 'recv', but (1) always returns exactly @count@ bytes, (2) reads
    -- from the enclosing @socket@ rather than taking it as an argument, and
    -- (3) throws an EOF 'IOError' when the connection is closed early.
    --
    -- A request for zero bytes returns an empty string without touching the
    -- socket. The header may declare a zero-length segment, and (per the
    -- network-simple documentation) 'recv' reports an empty read as
    -- 'Nothing', which would otherwise be mistaken for a closed connection.
    recvFull :: Int -> IO BS.ByteString
    recvFull !count = collect [] count

    -- Receive until @remaining@ more bytes have arrived. Chunks are kept in
    -- reverse order and joined once at the end, rather than re-copying the
    -- already-received prefix after every partial read.
    collect :: [BS.ByteString] -> Int -> IO BS.ByteString
    collect chunks !remaining
        | remaining <= 0 = pure $ BS.concat (reverse chunks)
        | otherwise = do
            maybeBytes <- recv socket remaining
            case maybeBytes of
                Nothing ->
                    throwIO $
                        mkIOError eofErrorType "Remote socket closed" Nothing Nothing
                Just bytes ->
                    collect (bytes : chunks) (remaining - BS.length bytes)

-- | Like 'M.hPutMsg', except that it writes to a 'Socket' instead of a
-- 'Handle'. The message is serialized with 'msgToLBS' and transmitted with
-- 'sendLazy'.
sPutMsg :: Socket -> M.Message 'Const -> IO ()
sPutMsg socket msg = sendLazy socket (msgToLBS msg)

-- | Read a message from the handle and return its root struct in parsed
-- form.
--
-- The supplied 'WordCount' is used twice: as the read limit for
-- 'M.hGetMsg', and as the 'evalLimitT' traversal limit while parsing.
hGetParsed :: forall a pa. (R.IsStruct a, Parse a pa) => Handle -> WordCount -> IO pa
hGetParsed handle limit =
    M.hGetMsg handle limit >>= evalLimitT limit . msgToParsed @a

-- | Read a message from the socket and return its root struct in parsed
-- form.
--
-- The supplied 'WordCount' is used twice: as the read limit for
-- 'sGetMsg', and as the 'evalLimitT' traversal limit while parsing.
sGetParsed :: forall a pa. (R.IsStruct a, Parse a pa) => Socket -> WordCount -> IO pa
sGetParsed socket limit =
    sGetMsg socket limit >>= evalLimitT limit . msgToParsed @a

-- | Read a struct from stdin in its parsed form, using the supplied read
-- limit. This is 'hGetParsed' specialised to 'stdin'.
getParsed :: forall a pa. (R.IsStruct a, Parse a pa) => WordCount -> IO pa
getParsed limit = hGetParsed @a stdin limit

-- | Serialize the parsed form of a struct and write it to the handle.
--
-- Encoding runs under 'evalLimitT' with 'maxBound' as the limit; the
-- resulting 'BB.Builder' is written with 'BB.hPutBuilder'.
hPutParsed :: (R.IsStruct a, Parse a pa) => Handle -> pa -> IO ()
hPutParsed h value =
    evalLimitT maxBound (parsedToBuilder value) >>= BB.hPutBuilder h

-- | Serialize the parsed form of a struct and write it to stdout. This is
-- 'hPutParsed' specialised to 'stdout'.
putParsed :: (R.IsStruct a, Parse a pa) => pa -> IO ()
putParsed value = hPutParsed stdout value

-- | Serialize the parsed form of a struct and send it over the socket.
--
-- Encoding runs under 'evalLimitT' with 'maxBound' as the limit; the lazy
-- 'ByteString' produced by 'parsedToLBS' is transmitted with 'sendLazy'.
sPutParsed :: (R.IsStruct a, Parse a pa) => Socket -> pa -> IO ()
sPutParsed socket value =
    evalLimitT maxBound (parsedToLBS value) >>= sendLazy socket

-- | Read a message from the handle and return its root struct pointer
-- without parsing it.
--
-- The supplied 'WordCount' is the read limit for 'M.hGetMsg' and also the
-- 'evalLimitT' limit used while locating the root.
hGetRaw :: R.IsStruct a => Handle -> WordCount -> IO (R.Raw a 'Const)
hGetRaw h limit =
    M.hGetMsg h limit >>= evalLimitT limit . msgToRaw

-- | Read a message from stdin and return its root struct pointer, using
-- the supplied read limit. This is 'hGetRaw' specialised to 'stdin'.
getRaw :: R.IsStruct a => WordCount -> IO (R.Raw a 'Const)
getRaw limit = hGetRaw stdin limit

-- | Read a message from the socket and return its root struct pointer
-- without parsing it.
--
-- The supplied 'WordCount' is the read limit for 'sGetMsg' and also the
-- 'evalLimitT' limit used while locating the root.
sGetRaw :: R.IsStruct a => Socket -> WordCount -> IO (R.Raw a 'Const)
sGetRaw socket limit =
    sGetMsg socket limit >>= evalLimitT limit . msgToRaw