module System.Nix.Store.Remote.Socket where

import Control.Monad.Except (MonadError, throwError)
import Control.Monad.IO.Class (MonadIO(..))
import Data.ByteString (ByteString)
import Data.Serialize.Get (Get, Result(..))
import Data.Serialize.Put (Put, runPut)
import Network.Socket.ByteString (recv, sendAll)
import System.Nix.Store.Remote.MonadStore (MonadRemoteStore(..), RemoteStoreError(..))
import System.Nix.Store.Remote.Serializer (NixSerializer, runP, runSerialT)
import System.Nix.Store.Remote.Types (ProtoStoreConfig)

import qualified Control.Exception
import qualified Data.ByteString
import qualified Data.Serializer
import qualified Data.Serialize.Get

-- | Run a cereal 'Get' parser incrementally.
--
-- The first chunk is obtained from @getsome@ and fed to the parser; every
-- time the parser reports 'Partial', another chunk is requested from
-- @getsome@ and parsing resumes.
--
-- Errors (via 'throwError'):
--
-- * 'RemoteStoreError_GenericIncrementalLeftovers' when the parser finishes
--   but did not consume all of the supplied input; carries the 'show'n
--   result and the unconsumed bytes.
-- * 'RemoteStoreError_GenericIncrementalFail' when the parser fails; carries
--   the parser's message and the remaining input.
genericIncremental
  :: ( MonadIO m
     , MonadError RemoteStoreError m
     , Show a
     )
  => m ByteString -- ^ action producing the next chunk of input
  -> Get a        -- ^ parser to run
  -> m a
genericIncremental getsome parser =
  getsome >>= go . decoder
 where
  decoder = Data.Serialize.Get.runGetPartial parser

  -- Parser finished: succeed only when no input is left over.
  go (Done x leftover)
    | Data.ByteString.null leftover = pure x
    | otherwise =
        throwError
        $ RemoteStoreError_GenericIncrementalLeftovers
            (show x)
            leftover

  -- Parser needs more input: fetch another chunk and continue.
  go (Partial k) = do
    chunk <- getsome
    go (k chunk)

  go (Fail msg leftover) =
    throwError
    $ RemoteStoreError_GenericIncrementalFail
        msg
        leftover

-- | Receive one chunk from the store socket, requesting 8 bytes from 'recv'.
--
-- Errors (via 'throwError'):
--
-- * 'RemoteStoreError_IOException' when 'recv' throws; the caught
--   exception is wrapped unchanged.
-- * 'RemoteStoreError_Disconnected' when 'recv' returns an empty
--   'ByteString'.
sockGet8
  :: MonadRemoteStore m
  => m ByteString
sockGet8 = do
  soc <- getStoreSocket
  eresult <- liftIO $ Control.Exception.try $ recv soc 8
  case eresult of
    Left e ->
      throwError $ RemoteStoreError_IOException e

    Right result
      -- An empty read is reported as a disconnect rather than as data.
      | Data.ByteString.null result ->
          throwError RemoteStoreError_Disconnected
      | otherwise ->
          pure result

-- | Serialize a cereal 'Put' with 'runPut' and send the resulting bytes
-- over the store socket using 'sendAll'.
sockPut
  :: MonadRemoteStore m
  => Put
  -> m ()
sockPut p = do
  soc <- getStoreSocket
  liftIO $ sendAll soc $ runPut p

-- | Serialize a value with the given 'NixSerializer' (using the current
-- 'ProtoStoreConfig' from 'getConfig') and send the bytes over the store
-- socket.
--
-- If serialization yields @Left e@, nothing is sent and @e@ is rethrown
-- with 'throwError'.
sockPutS
  :: ( MonadRemoteStore m
     , MonadError e m
     )
  => NixSerializer ProtoStoreConfig e a -- ^ serializer to run
  -> a                                  -- ^ value to send
  -> m ()
sockPutS s a = do
  cfg <- getConfig
  sock <- getStoreSocket
  either
    throwError
    (liftIO . sendAll sock)
    (runP s cfg a)

-- | Read a value from the store socket using the getter half of the given
-- 'NixSerializer'.
--
-- Input is pulled chunk by chunk through 'sockGet8' and parsed with
-- 'genericIncremental'; the serializer runs under the current
-- 'ProtoStoreConfig' from 'getConfig'. A @Left e@ produced by the
-- serializer is rethrown with 'throwError'.
sockGetS
  :: ( MonadRemoteStore m
     , MonadError e m
     , Show a
     , Show e
     )
  => NixSerializer ProtoStoreConfig e a
  -> m a
sockGetS s = do
  cfg <- getConfig
  res <-
    genericIncremental sockGet8
      $ runSerialT cfg
      $ Data.Serializer.getS s
  either throwError pure res