{-# LANGUAGE DataKinds, TypeFamilies #-}

{- | Low-level patchscript processing and application.

Patchscripts are applied as a list of @(skip x, write in-place y)@ commands. An
offset-based format is much simpler to use, however. This module processes such
offset patchscripts into a "linear" patchscript, and provides a stream patching
algorithm that can be applied to any forward-seeking byte stream.

Some core types are parameterized over the stream type/patch content. This
enables writing patches in any form (e.g. UTF-8 text), which are then processed
into an applicable patch by transforming edits into a concrete binary
representation (e.g. null-terminated UTF-8 bytestring). See TODO module for
more.
-}

module BytePatch.Linear.Patch
  (
  -- * Patch interface
    MonadFwdByteStream(..)
  , Cfg(..)
  , Error(..)

  -- * Prepared patchers
  , patchPure

  -- * General patcher
  , patch

  ) where

import           BytePatch.Core

import           GHC.Natural
import qualified Data.ByteString         as BS
import qualified Data.ByteString.Lazy    as BL
import qualified Data.ByteString.Builder as BB
import           Control.Monad.State
import           Control.Monad.Reader
import           System.IO               ( Handle, SeekMode(..), hSeek )
import           Optics

type Bytes = BS.ByteString

-- | A byte stream whose cursor only ever moves forward.
--
-- Applying a linear patchscript needs exactly three primitives: skip over
-- bytes without changing them, look at the bytes under the cursor, and write
-- replacement bytes in place.
--
-- TODO also require reporting cursor position (for error reporting)
class Monad m => MonadFwdByteStream m where
    -- | Move the cursor forward by the given number of bytes, leaving those
    --   bytes untouched.
    advance :: Natural -> m ()

    -- | Return the given number of bytes starting at the cursor, without
    --   advancing the cursor. Implementations may return fewer bytes if the
    --   stream ends early.
    readahead :: Natural -> m Bytes

    -- | Write bytes at the cursor position, replacing the same number of
    --   existing bytes.
    overwrite :: Bytes -> m ()

-- | Pure in-memory stream.
--
-- The state is @(input not yet passed over, output built so far)@; the cursor
-- sits at the boundary between the two.
instance Monad m => MonadFwdByteStream (StateT (Bytes, BB.Builder) m) where
    -- Peek at the front of the pending input; the state is left unchanged.
    readahead n = gets (BS.take (fromIntegral n) . fst)

    -- Skipped bytes are copied to the output verbatim.
    advance n = do
        (pending, emitted) <- get
        let (passedOver, pending') = BS.splitAt (fromIntegral n) pending
        put (pending', emitted <> BB.byteString passedOver)

    -- Drop as many pending input bytes as we write, and emit the replacement.
    -- If fewer input bytes remain than are written, the output grows.
    overwrite replacement = do
        (pending, emitted) <- get
        put ( BS.drop (BS.length replacement) pending
            , emitted <> BB.byteString replacement )

-- | Stream backed by a 'Handle'; the cursor is the handle's current position.
--
-- NOTE(review): assumes the handle is seekable and open for both reading and
-- writing. 'overwrite' writes directly to the handle. Confirm at call sites.
instance MonadIO m => MonadFwdByteStream (ReaderT Handle m) where
    -- Read up to @n@ bytes, then seek back by however many were actually read.
    -- 'BS.hGet' returns fewer than @n@ bytes at end of file. Seeking back by
    -- the requested @n@ would therefore leave the cursor *before* the point
    -- where the read started.
    readahead n = do
        hdl <- ask
        liftIO $ do
            bs <- BS.hGet hdl (fromIntegral n)
            hSeek hdl RelativeSeek (negate (fromIntegral (BS.length bs)))
            return bs

    -- Skip forward relative to the current position.
    advance n = do
        hdl <- ask
        liftIO $ hSeek hdl RelativeSeek (fromIntegral n)

    -- Write at the current position; the handle position moves past the
    -- written bytes.
    overwrite bs = do
        hdl <- ask
        liftIO $ BS.hPut hdl bs

-- | Configuration consulted while a patchscript is being applied.
data Cfg = Cfg
  { -- | If we determine that we're repatching an already-patched stream,
    --   continue with a warning instead of failing.
    --
    --   NOTE(review): 'patch' in this module does not currently read this
    --   field.
    cfgWarnIfLikelyReprocessing :: Bool
  , -- | If enabled, allow partial expected bytes checking: the expected bytes
    --   only need to be a prefix of the actual bytes. If disabled, then even
    --   if the expected bytes are a prefix of the actual, fail.
    cfgAllowPartialExpected :: Bool
  } deriving (Show, Eq)

-- | Errors encountered during patch time.
data Error
  = ErrorPatchOverlong
    -- ^ An edit would overwrite past the end of the stream.
  | ErrorPatchUnexpectedNonnull
    -- ^ The edit's metadata asked for trailing null bytes in the region being
    --   overwritten, but some byte in that null region was not @0x00@.
  | ErrorPatchDidNotMatchExpected Bytes Bytes
    -- ^ The bytes being overwritten did not match the edit's expected bytes.
    --   Fields are the actual (null-stripped) bytes, then the expected bytes.
  deriving (Show, Eq)

-- | Apply a linear patchscript to a forward-seeking byte stream.
--
-- Patches are applied in order. For each @'Patch' n ('Edit' bs meta)@:
--
-- 1. The cursor advances @n@ bytes from wherever the previous patch's
--    'overwrite' left it (or from the start position, for the first patch).
-- 2. The bytes that @bs@ would replace are read and validated against @meta@.
-- 3. If they pass, @bs@ is written over them.
--
-- Returns 'Nothing' if every patch was applied, or the first 'Error'
-- encountered. Patches before the failing one have already been written. The
-- failing patch and everything after it have not.
--
-- Validation is done in this order:
--
--   * 'ErrorPatchOverlong': the stream has fewer bytes left than @bs@ would
--     overwrite.
--   * 'ErrorPatchUnexpectedNonnull': 'emNullTerminates' is @'Just' i@ and the
--     to-overwrite bytes from index @i@ onward are not all @0x00@. Those null
--     bytes are stripped before the next check.
--   * 'ErrorPatchDidNotMatchExpected' (actual, expected): 'emExpected' is set
--     and the to-overwrite bytes don't match it. With 'cfgAllowPartialExpected',
--     the expected bytes only need to be a prefix of the actual bytes.
--     Otherwise the two must be equal.
patch
    :: MonadFwdByteStream m
    => Cfg -> [Patch 'FwdSeek Bytes]
    -> m (Maybe Error)
patch cfg = go
  where
    go [] = return Nothing
    go (Patch n (Edit bs meta) : es) = do
        advance n
        bsStream <- readahead $ fromIntegral $ BS.length bs
        case checkEdit bs meta bsStream of
          Just err -> return $ Just err
          Nothing  -> overwrite bs >> go es

    -- Validate the bytes currently in the stream (@bsStream@) before they are
    -- replaced by @bs@. 'Nothing' means it is safe to write.
    checkEdit bs meta bsStream
      -- The stream instances in this module return fewer bytes than requested
      -- from 'readahead' when the stream ends early. A short read therefore
      -- means the edit would run off the end of the stream.
      | BS.length bsStream < BS.length bs = Just ErrorPatchOverlong
      | otherwise =
          case tryStripNulls bsStream (emNullTerminates meta) of
            Nothing        -> Just ErrorPatchUnexpectedNonnull
            Just bsStream' ->
              uncurry ErrorPatchDidNotMatchExpected
                <$> checkExpected bsStream' (emExpected meta)

    -- If a null-padding start index is given, require every byte from that
    -- index onward to be 0x00 and drop them. 'Nothing' means a non-null byte
    -- was found.
    tryStripNulls :: Bytes -> Maybe Int -> Maybe Bytes
    tryStripNulls bs Nothing = Just bs
    tryStripNulls bs (Just nullsFrom)
      | BS.all (== 0x00) bsNulls = Just bs'
      | otherwise                = Nothing
      where (bs', bsNulls) = BS.splitAt nullsFrom bs

    -- If expected bytes are given, compare them against the actual bytes.
    -- Returns @Just (actual, expected)@ on a mismatch.
    checkExpected :: Bytes -> Maybe Bytes -> Maybe (Bytes, Bytes)
    checkExpected _ Nothing = Nothing
    checkExpected bs (Just bsExpected)
      | matches   = Nothing
      | otherwise = Just (bs, bsExpected)
      where
        matches
          -- Partial mode: the expected bytes must be a prefix of the actual.
          | cfgAllowPartialExpected cfg = bsExpected `BS.isPrefixOf` bs
          | otherwise                   = bs == bsExpected

-- | Attempt to apply a patchscript to a 'Data.ByteString.ByteString'.
--
-- The input is run through 'patch' using the pure @StateT (Bytes, Builder)@
-- stream, starting with all input pending and an empty output.
--
-- * On success, the result is the patched output followed by whatever input
--   lay beyond the last patch, unchanged.
-- * On failure, the first 'Error' from 'patch' is returned and the partial
--   output is discarded.
patchPure :: Cfg -> [Patch 'FwdSeek Bytes] -> BS.ByteString -> Either Error BL.ByteString
patchPure cfg ps bs = maybe (Right patched) Left mErr
  where
    (mErr, (bsUntouched, bbPatched)) =
        runState (patch cfg ps) (bs, mempty :: BB.Builder)
    -- Input never reached by the patchscript is passed through as-is.
    patched = BB.toLazyByteString (bbPatched <> BB.byteString bsUntouched)