module Transit.Internal.Pipeline
  ( sendPipeline
  , receivePipeline
  
  , assembleRecordC
  , decryptC
  , encryptC
  )
where
import Protolude
import Crypto.Hash (SHA256(..))
import Data.Conduit ((.|))
import Data.ByteString.Builder(toLazyByteString, word32BE)
import Data.Binary.Get (getWord32be, runGet)
import qualified Crypto.Hash as Hash
import qualified Conduit as C
import qualified Data.Conduit.Network as CN
import qualified Data.Conduit.Binary as CB
import qualified Data.Binary.Builder as BB
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Crypto.Saltine.Core.SecretBox as SecretBox
import qualified Crypto.Saltine.Class as Saltine
import Transit.Internal.Network (TCPEndpoint(..), TransitEndpoint(..))
import Transit.Internal.Crypto (encrypt, decrypt, PlainText(..), CipherText(..), CryptoError(..))
-- | Stream the file at @fp@ to the peer over a transit endpoint.
--
-- The file's bytes pass through 'sha256PassThroughC', which hashes them
-- while forwarding them. They then go to 'encryptC', whose records are
-- written to the endpoint's socket.
--
-- The result pairs the textual ('show') form of the SHA-256 digest of the
-- plaintext file with the socket sink's unit result.
sendPipeline :: C.MonadResource m =>
                FilePath
             -> TransitEndpoint
             -> C.ConduitM a c m (Text, ())
sendPipeline fp (TransitEndpoint (TCPEndpoint sock _) sessionKey _) =
  C.sourceFile fp .| C.fuseBoth sha256PassThroughC encryptAndSend
  where
    -- Encrypt each chunk into a length-prefixed record and push it out.
    encryptAndSend = encryptC sessionKey .| CN.sinkSocket sock
-- | Receive up to @len@ plaintext bytes from the peer and store them at @fp@.
--
-- The pipeline works in four stages:
--
-- * socket bytes are regrouped into records by 'assembleRecordC';
-- * the records are decrypted by 'decryptC';
-- * the plaintext is cut off after @len@ bytes by 'CB.isolate';
-- * the result is hashed while being written with 'C.sinkFileCautious'.
--
-- The result pairs the textual ('show') form of the SHA-256 digest of the
-- stored plaintext with the file sink's unit result.
--
-- NOTE(review): nothing in this function itself checks that a full @len@
-- bytes arrived. Confirm that callers verify the returned digest.
receivePipeline :: C.MonadResource m =>
                   FilePath
                -> Int
                -> TransitEndpoint
                -> C.ConduitM a c m (Text, ())
receivePipeline fp len (TransitEndpoint (TCPEndpoint sock _) sessionKey _) =
  CN.sourceSocket sock
    .| assembleRecordC
    .| decryptC sessionKey
    .| CB.isolate len
    .| hashAndStore
  where
    -- Hash what flows through while writing it to disk; keep both results.
    hashAndStore = C.fuseBoth sha256PassThroughC (C.sinkFileCautious fp)
-- | Encrypt every incoming chunk into a length-prefixed record.
--
-- Each chunk is sealed with 'encrypt' under @key@. The nonce starts at
-- 'Saltine.zero' and is advanced with 'Saltine.nudge' once per chunk.
--
-- A chunk produces two downstream 'ByteString's:
--
-- * the ciphertext length as a 4-byte big-endian word;
-- * the ciphertext itself.
--
-- An encryption failure is rethrown with 'throwIO'.
encryptC :: MonadIO m => SecretBox.Key -> C.ConduitT ByteString ByteString m ()
encryptC key = sealFrom Saltine.zero
  where
    -- Wait for the next chunk; stop quietly when upstream is finished.
    sealFrom nonce = C.await >>= maybe (return ()) (sealChunk nonce)

    -- Encrypt one chunk under the current nonce, failing loudly on error.
    sealChunk nonce chunk =
      either throwIO (emitRecord nonce) (encrypt key nonce (PlainText chunk))

    -- Emit the length header and the ciphertext, then move on to the next nonce.
    emitRecord nonce (CipherText sealed) = do
      let header = toLazyByteString (word32BE (fromIntegral (BS.length sealed)))
      C.yield (toS header)
      C.yield sealed
      sealFrom (Saltine.nudge nonce)
-- | Decrypt a stream of complete records, enforcing in-order nonces.
--
-- Every upstream 'ByteString' is handed to 'decrypt' as one whole record.
-- The nonce that 'decrypt' reports must match the byte-reversed encoding of
-- a local counter. The counter starts at 'Saltine.zero' and is advanced with
-- 'Saltine.nudge' after each record.
--
-- A mismatch, or a re-encoded counter that fails to decode, raises
-- 'BadNonce'. A 'decrypt' failure is rethrown unchanged with 'throwIO'.
decryptC :: MonadIO m => SecretBox.Key -> C.ConduitT ByteString ByteString m ()
decryptC key = openFrom Saltine.zero
  where
    -- Pull the next record, or finish quietly when upstream is done.
    openFrom :: MonadIO m => SecretBox.Nonce -> C.ConduitT ByteString ByteString m ()
    openFrom expected = C.await >>= maybe (return ()) (openRecord expected)

    -- Decrypt one record, validate its nonce, emit the plaintext, continue.
    openRecord
      :: MonadIO m
      => SecretBox.Nonce -> ByteString -> C.ConduitT ByteString ByteString m ()
    openRecord expected record =
      case decrypt key (CipherText record) of
        Left err -> throwIO err
        Right (PlainText plainText, received) -> do
          when (Just received /= reversedNonce expected) $
            throwIO (BadNonce "nonce decoding failed or packets received out of order.")
          C.yield plainText
          openFrom (Saltine.nudge expected)

    -- Re-encode the counter with its bytes in the opposite order.
    reversedNonce n = Saltine.decode (toS (BS.reverse (toS (Saltine.encode n))))
-- | Forward every chunk unchanged while hashing it with SHA-256.
--
-- When upstream is exhausted, the conduit returns the 'show' rendering of
-- the finalised digest. The hash context is forced with '$!' at every step,
-- so no chain of pending updates builds up.
sha256PassThroughC :: (Monad m) => C.ConduitT ByteString ByteString m Text
sha256PassThroughC = hashFrom $! Hash.hashInitWith SHA256
  where
    hashFrom :: Monad m => Hash.Context SHA256 -> C.ConduitT ByteString ByteString m Text
    hashFrom ctx = C.await >>= maybe done forward
      where
        -- No more input: render the digest of everything seen so far.
        done = return $! show (Hash.hashFinalize ctx)
        -- Pass the chunk downstream first, then fold it into the hash.
        forward chunk = do
          C.yield chunk
          hashFrom $! Hash.hashUpdate ctx chunk
-- | Split a raw byte stream into the length-prefixed records 'encryptC' emits.
--
-- Each record on the wire is a 4-byte big-endian length header followed by
-- that many bytes. The header is stripped and the body is yielded as a
-- single 'ByteString', however upstream happened to chunk the bytes. Bytes
-- beyond the current record are pushed back with 'C.leftover' for the next
-- record.
--
-- The conduit finishes once upstream is exhausted:
--
-- * at a record boundary, it stops without yielding anything further;
-- * part-way through a header (1-3 bytes read), that fragment is yielded
--   once and the conduit stops;
-- * part-way through a body, the shorter-than-announced body is yielded
--   and the conduit then stops.
--
-- Truncated data is passed on rather than dropped, so a downstream stage
-- (such as a decryptor) gets the chance to reject it.
assembleRecordC :: Monad m => C.ConduitT ByteString ByteString m ()
assembleRecordC = do
  hdr <- getChunk 4
  case BS.length hdr of
    -- Upstream is finished. Stop here: awaiting again would only ever see
    -- 'Nothing', and looping would yield empty records forever.
    0 -> return ()
    4 -> do
      -- The header is complete, so 'runGet' cannot run out of input here.
      let len = runGet getWord32be (BL.fromStrict hdr)
      packet <- getChunk (fromIntegral len)
      C.yield packet
      -- If the body came up short, the recursive call sees an exhausted
      -- upstream and takes the 0 branch above.
      assembleRecordC
    -- Upstream ended part-way through a header.
    _ -> C.yield hdr
  where
    -- Collect exactly @size@ bytes, or fewer if upstream runs out first.
    getChunk :: Monad m => Int -> C.ConduitT ByteString ByteString m ByteString
    getChunk size = go size BB.empty

    -- Accumulate chunks in a 'BB.Builder' and flatten it only once at the
    -- end, so assembling a record from many small chunks stays linear.
    go :: Monad m => Int -> BB.Builder -> C.ConduitT ByteString ByteString m ByteString
    go size acc = do
      b <- C.await
      case b of
        Nothing -> return (flatten acc)
        Just bs
          | size < BS.length bs -> do
              let (f, l) = BS.splitAt size bs
              C.leftover l
              return (flatten (acc <> BB.fromByteString f))
          | size == BS.length bs -> return (flatten (acc <> BB.fromByteString bs))
          | otherwise ->
              go (size - BS.length bs) (acc <> BB.fromByteString bs)

    flatten :: BB.Builder -> ByteString
    flatten = BL.toStrict . BB.toLazyByteString