-- | Description: Conduit pipelines for sending and receiving files and directories
module Transit.Internal.Pipeline
  ( sendPipeline
  , receivePipeline
    -- * for tests
  , assembleRecordC
  , decryptC
  , encryptC
  ) where

import Protolude

import Crypto.Hash (SHA256(..))
import Data.Conduit ((.|))
import Data.ByteString.Builder(toLazyByteString, word32BE)
import Data.Binary.Get (getWord32be, runGet)

import qualified Crypto.Hash as Hash
import qualified Conduit as C
import qualified Data.Conduit.Network as CN
import qualified Data.Conduit.Binary as CB
import qualified Data.Binary.Builder as BB
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Crypto.Saltine.Core.SecretBox as SecretBox
import qualified Crypto.Saltine.Class as Saltine

import Transit.Internal.Network (TCPEndpoint(..), TransitEndpoint(..))
import Transit.Internal.Crypto (encrypt, decrypt, PlainText(..), CipherText(..), CryptoError(..))

-- | Stream the file at @fp@ to the peer behind the given transit endpoint.
--
-- Every chunk read from disk is first folded into a SHA-256 digest and then
-- handed to 'encryptC' (using the endpoint's key), whose output is written to
-- the endpoint's socket. The digest is taken over the plaintext, i.e. before
-- encryption, so that it can be compared with the digest the receiver
-- computes over its decrypted copy.
--
-- Returns the digest paired with the (unit) result of the socket sink.
sendPipeline
  :: C.MonadResource m
  => FilePath
  -> TransitEndpoint
  -> C.ConduitM a c m (Text, ())
sendPipeline fp (TransitEndpoint (TCPEndpoint sock _) sessionKey _) =
  C.sourceFile fp .| hashThenSend
  where
    -- 'C.fuseBoth' keeps the digest produced by the hashing stage alongside
    -- the result of the encrypt-and-send stage.
    hashThenSend = sha256PassThroughC `C.fuseBoth` encryptThenSend
    encryptThenSend = encryptC sessionKey .| CN.sinkSocket sock

-- | Receive the encrypted byte stream from the endpoint's socket, decrypt it,
-- and write the first @len@ bytes of plaintext to @fp@ (via
-- 'C.sinkFileCautious'), computing the SHA-256 digest of the decrypted output
-- along the way. Returns the digest paired with the sink's unit result.
receivePipeline
  :: C.MonadResource m
  => FilePath
  -> Int
  -> TransitEndpoint
  -> C.ConduitM a c m (Text, ())
receivePipeline fp len (TransitEndpoint (TCPEndpoint s _) key _) =
  -- 'CB.isolate' passes on at most @len@ plaintext bytes; once it and the
  -- stages after it finish, '.|' stops pulling from the socket.
  CN.sourceSocket s
    .| assembleRecordC
    .| decryptC key
    .| CB.isolate len
    .| sha256PassThroughC `C.fuseBoth` C.sinkFileCautious fp

-- | A conduit that encrypts each incoming chunk with the given key.
--
-- Every chunk becomes one record, emitted as two downstream elements: a
-- 4-byte big-endian length header (the ciphertext size) followed by the
-- ciphertext itself. The nonce starts at 'Saltine.zero' and is advanced with
-- 'Saltine.nudge' after each chunk. Encryption errors are rethrown with
-- 'throwIO'.
encryptC :: MonadIO m => SecretBox.Key -> C.ConduitT ByteString ByteString m ()
encryptC key = loop Saltine.zero
  where
    loop nonce = do
      b <- C.await
      case b of
        Nothing -> return ()
        Just chunk -> do
          let cipherText = encrypt key nonce (PlainText chunk)
          case cipherText of
            Right (CipherText cipherText') -> do
              let cipherTextSize = toLazyByteString (word32BE (fromIntegral (BS.length cipherText')))
              C.yield (toS cipherTextSize)
              C.yield cipherText'
              loop (Saltine.nudge nonce)
            Left e -> throwIO e

-- | A conduit that decrypts each incoming record with the given key.
--
-- Each upstream element must be one complete encrypted record, e.g. as
-- produced by 'assembleRecordC'. The nonce returned by 'decrypt' for each
-- record must equal the locally tracked sequence number, which starts at
-- 'Saltine.zero' and is advanced with 'Saltine.nudge' per record. A mismatch,
-- or a failure to decode the expected nonce, raises 'BadNonce'. Decryption
-- errors are rethrown with 'throwIO'.
decryptC :: MonadIO m => SecretBox.Key -> C.ConduitT ByteString ByteString m ()
decryptC key = loop Saltine.zero
  where
    loop :: MonadIO m => SecretBox.Nonce -> C.ConduitT ByteString ByteString m ()
    loop seqNum = do
      b <- C.await
      case b of
        Nothing -> return ()
        Just bs ->
          case decrypt key (CipherText bs) of
            Right (PlainText plainText, nonce) -> do
              -- The local counter's bytes are reversed before comparison.
              -- NOTE(review): presumably because 'Saltine.nudge' counts
              -- little-endian while the nonce on the wire is big-endian;
              -- confirm against 'encrypt'/'decrypt'.
              let seqNumLE = BS.reverse $ toS $ Saltine.encode seqNum
                  seqNum' = Saltine.decode (toS seqNumLE)
              if Just nonce /= seqNum'
                then throwIO (BadNonce "nonce decoding failed or packets received out of order.")
                else do
                  C.yield plainText
                  loop (Saltine.nudge seqNum)
            Left e -> throwIO e

-- | A conduit that passes every chunk through unchanged while folding it into
-- a SHA-256 context. When upstream ends it returns the finalized digest,
-- rendered to 'Text' via the digest's 'Show' instance.
sha256PassThroughC :: (Monad m) => C.ConduitT ByteString ByteString m Text
sha256PassThroughC = loop $! Hash.hashInitWith SHA256
  where
    loop :: (Monad m) => Hash.Context SHA256 -> C.ConduitT ByteString ByteString m Text
    loop ctx = do
      b <- C.await
      case b of
        Nothing -> return $! show (Hash.hashFinalize ctx)
        Just bs -> do
          C.yield bs
          -- Force the updated context so no chain of thunks builds up.
          loop $! Hash.hashUpdate ctx bs

-- | Reassemble complete encrypted records from an arbitrarily chunked stream.
--
-- 'decryptC' only succeeds when it is given a complete encrypted record, but
-- upstream stages (such as a socket source) may split or merge records at
-- arbitrary byte boundaries, which we do not control. Each record on the wire
-- starts with a 4-byte big-endian length header, so we read the header first
-- and then gather exactly that many bytes before yielding the record.
--
-- When the input runs out, whatever has been gathered for the current record
-- (possibly nothing) is yielded once and the conduit finishes. Yielding the
-- short or empty record, instead of stopping quietly, means a consumer that
-- expected more data receives a malformed record. The consumer is presumably
-- 'decryptC', which would then fail on it, so the truncated stream is not
-- mistaken for a complete one.
assembleRecordC :: Monad m => C.ConduitT ByteString ByteString m ()
assembleRecordC = do
  hdr <- getChunk 4
  if BS.length hdr < 4
    then C.yield hdr -- input ended at or inside a header: emit the remainder and stop
    else do
      let len = fromIntegral (runGet getWord32be (BL.fromStrict hdr))
      packet <- getChunk len
      C.yield packet
      -- A short packet means the input is exhausted, so no record can follow.
      when (BS.length packet == len) assembleRecordC
  where
    -- Gather exactly @size@ bytes (fewer only if the input ends first). Any
    -- surplus from the last chunk is pushed back with 'C.leftover'.
    getChunk :: Monad m => Int -> C.ConduitT ByteString ByteString m ByteString
    getChunk = go BB.empty

    -- Accumulates into a Builder and materialises it once at the end, so each
    -- input chunk is copied a constant number of times rather than
    -- re-concatenating the growing prefix on every chunk.
    go :: Monad m => BB.Builder -> Int -> C.ConduitT ByteString ByteString m ByteString
    go acc want
      | want <= 0 = return (toStrictBS acc)
      | otherwise = do
          next <- C.await
          case next of
            Nothing -> return (toStrictBS acc)
            Just bs
              | BS.length bs > want -> do
                  let (needed, surplus) = BS.splitAt want bs
                  C.leftover surplus
                  return (toStrictBS (acc <> BB.fromByteString needed))
              | otherwise ->
                  go (acc <> BB.fromByteString bs) (want - BS.length bs)

    toStrictBS :: BB.Builder -> ByteString
    toStrictBS = BL.toStrict . BB.toLazyByteString