{-# LANGUAGE BangPatterns #-}
module Crypto.RNCryptor.V3.Stream
  ( processStream
  , StreamingState(..)
  ) where

import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.Word
import           Control.Monad.State
import           Crypto.RNCryptor.Types
import           Data.Monoid
import qualified System.IO.Streams as S

--------------------------------------------------------------------------------
-- | The state of the streaming loop in 'processStream'. It drives the
-- computation and records how many leftover bytes were pushed back onto the
-- input when a chunk had to be split because its length was not a multiple
-- of 'blockSize'.
data StreamingState
  = Continue
    -- ^ Read and process the next chunk as usual.
  | FetchLeftOver !Int
    -- ^ This many bytes were pushed back onto the input and must be read
    -- again, ahead of the next chunk.
  | DrainSource
    -- ^ Stop processing incrementally: buffer the rest of the input and hand
    -- it to the finaliser.
  deriving (Eq, Show)

--------------------------------------------------------------------------------
-- | Efficiently transform an incoming stream of bytes.
--
-- Chunks are read from the input one at a time. For each chunk, the longest
-- prefix whose length is a multiple of 'blockSize' is passed to the block
-- function. Its output is written to the output stream, and any remaining
-- bytes are pushed back onto the input to be prefixed to the next chunk.
--
-- Incremental processing stops in two cases:
--
-- * The current chunk is the last one. That chunk goes to the finaliser.
-- * The current chunk and the next one together hold fewer than 4096 bytes.
--   Everything left on the input is then buffered and handed to the
--   finaliser, together with the latest context.
--
-- NOTE(review): the threshold only looks one chunk ahead. If the stream ends
-- with a very short chunk, bytes close to the end may already have gone
-- through the block function. Confirm that callers' finalisers do not need a
-- minimum number of trailing bytes.
processStream :: RNCryptorContext
              -- ^ The RNCryptor context for this operation
              -> S.InputStream ByteString
              -- ^ The input source (most likely stdin)
              -> S.OutputStream ByteString
              -- ^ The output sink (most likely stdout)
              -> (RNCryptorContext -> ByteString -> (RNCryptorContext, ByteString))
              -- ^ The action to perform over each block-aligned chunk. It
              -- returns the updated context and the bytes to write out.
              -> (ByteString -> RNCryptorContext -> IO ())
              -- ^ The finaliser. It receives the unprocessed tail of the
              -- input (possibly empty) and the final context.
              -> IO ()
processStream context inS outS blockFn finaliser = go Continue context
  where
    -- If the current chunk plus the following one total fewer bytes than
    -- this, the rest of the stream is buffered for the finaliser instead of
    -- being processed incrementally.
    drainThreshold :: Int
    drainThreshold = 4096

    -- Length of the input, and how many trailing bytes do not fill a whole
    -- 'blockSize' block.
    slack :: ByteString -> (Int, Int)
    slack input = let bsL = B.length input in (bsL, bsL `mod` blockSize)

    -- Read the next chunk. In the 'FetchLeftOver' state, the bytes pushed
    -- back by the previous step are re-read first and prefixed to it.
    readChunk :: StreamingState -> IO (Maybe ByteString)
    readChunk (FetchLeftOver size) = do
      lo   <- S.readExactly size inS
      next <- S.read inS
      -- Keep the leftover even if the source has no further chunk; those
      -- bytes must still reach the block function or the finaliser.
      return (Just (maybe lo (lo <>) next))
    readChunk _ = S.read inS

    -- Buffer the rest of the input and pass it to the finaliser. Chunks are
    -- collected in reverse order and concatenated once at the end. This
    -- avoids repeated (quadratic) appends to a strict 'ByteString'.
    drain :: [ByteString] -> RNCryptorContext -> IO ()
    drain acc ctx = do
      mChunk <- S.read inS
      case mChunk of
        Nothing -> finaliser (B.concat (reverse acc)) ctx
        Just c  -> drain (c : acc) ctx

    -- Main loop: decide whether to process the current chunk or to switch
    -- to draining the source.
    go :: StreamingState -> RNCryptorContext -> IO ()
    go DrainSource ctx = drain [] ctx
    go dc ctx = do
      mChunk <- readChunk dc
      case mChunk of
        Nothing -> finaliser mempty ctx
        Just v  -> do
          -- Peek, so the following chunk stays available on the input.
          whatsNext <- S.peek inS
          case whatsNext of
            Nothing -> finaliser v ctx
            Just nt
              | B.length v + B.length nt < drainThreshold -> drain [v] ctx
              | otherwise -> processChunk v ctx

    -- Run the block function over the block-aligned prefix of a chunk and
    -- write its output. Any unaligned tail is pushed back onto the input
    -- for the next iteration.
    processChunk :: ByteString -> RNCryptorContext -> IO ()
    processChunk v ctx = do
      let (sz, sl)          = slack v
          (toProcess, rest) = B.splitAt (sz - sl) v
          (newCtx, res)     = blockFn ctx toProcess
      S.write (Just res) outS
      if sl == 0
        then go Continue newCtx
        else do
          S.unRead rest inS
          go (FetchLeftOver sl) newCtx