{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}

-- | Encoding and decoding of the primitive values used on the SSH wire
-- (32-bit lengths, strings, names, booleans, mpints, exit codes).
-- The data type representations follow RFC 4251, section 5.
module Network.SSH.Encoding where

import           Control.Applicative
import           Control.Monad              ( when )
import qualified Control.Monad.Fail         as Fail
import qualified Data.ByteArray             as BA
import qualified Data.ByteString            as BS
import qualified Data.ByteString.Short      as SBS
import qualified Data.Serialize.Get         as G
import           Data.Word
import           System.Exit

import qualified Network.SSH.Builder        as B
import           Network.SSH.Name

-- | The parser monad used throughout this module (cereal's 'G.Get').
type Get = G.Get

-- | Types that have a wire encoding.
--
-- 'put' is polymorphic in the output builder, so one encoder can target any
-- 'B.Builder' instance; 'get' is the corresponding parser.
class Encoding a where
    put :: forall b. B.Builder b => a -> b
    get :: Get a

-- | Render a builder via its 'B.ByteArrayBuilder' instance into a strict
-- 'BS.ByteString'.
runPut :: B.ByteArrayBuilder -> BS.ByteString
runPut = B.toByteArray
{-# INLINEABLE runPut #-}

-- | Run the 'Encoding' parser for @a@ over the given bytes. A parse error
-- is reported through 'Fail.fail' in the caller's monad.
--
-- NOTE(review): whether unconsumed trailing input is rejected depends on
-- 'G.runGet'; nothing here checks for it — confirm callers expect that.
runGet :: (Fail.MonadFail m, Encoding a) => BS.ByteString -> m a
runGet bs = case G.runGet get bs of
    Left  e -> Fail.fail e
    Right a -> pure a
{-# INLINEABLE runGet #-}

-- | Encode an exit status as a big-endian 32-bit word.
--
-- 'ExitSuccess' is written as 0. @'ExitFailure' x@ is written as @x@
-- converted to 'Word32' with 'fromIntegral', so negative codes wrap around.
putExitCode :: B.Builder b => ExitCode -> b
putExitCode = \case
    ExitSuccess   -> B.word32BE 0
    ExitFailure x -> B.word32BE (fromIntegral x)
{-# INLINEABLE putExitCode #-}

-- | Decode the format written by 'putExitCode'.
--
-- 0 becomes 'ExitSuccess'; any other value becomes 'ExitFailure'.
getExitCode :: Get ExitCode
getExitCode = getWord32 >>= \case
    0 -> pure ExitSuccess
    x -> pure (ExitFailure $ fromIntegral x)
{-# INLINEABLE getExitCode #-}

-- | Read a big-endian 32-bit length @n@, then run the given parser isolated
-- to the following @n@ bytes.
--
-- See 'G.isolate' for how a parser that does not consume exactly @n@ bytes
-- is treated.
getFramed :: Get a -> Get a
getFramed g = do
    w <- getWord32
    G.isolate (fromIntegral w) g
{-# INLINEABLE getFramed #-}

-- | Write a single byte.
putWord8 :: B.Builder b => Word8 -> b
putWord8 = B.word8
{-# INLINEABLE putWord8 #-}

-- | Read a single byte.
getWord8 :: Get Word8
getWord8 = G.getWord8
{-# INLINEABLE getWord8 #-}

-- | Read one byte and fail unless it equals the expected value.
--
-- The failure message names both the expected and the actual byte, so that
-- parse errors are diagnosable.
expectWord8 :: Word8 -> Get ()
expectWord8 expected = do
    actual <- getWord8
    when (actual /= expected) $
        fail $ "expectWord8: expected " ++ show expected
            ++ ", got " ++ show actual
{-# INLINEABLE expectWord8 #-}

-- | Read a big-endian 32-bit word.
getWord32 :: Get Word32
getWord32 = G.getWord32be
{-# INLINEABLE getWord32 #-}

-- | Write raw bytes with no length prefix.
putBytes :: (B.Builder b, BA.ByteArrayAccess ba) => ba -> b
putBytes = B.byteArray
{-# INLINEABLE putBytes #-}

-- | Read exactly the given number of raw bytes (no length prefix) and
-- convert them to the requested byte-array type.
getBytes :: BA.ByteArray ba => Word32 -> Get ba
getBytes i = BA.convert <$> G.getByteString (fromIntegral i)
{-# INLINEABLE getBytes #-}

-- | Length of a 'BS.ByteString' as a 'Word32'.
lenByteString :: BS.ByteString -> Word32
lenByteString = fromIntegral . BA.length
{-# INLINEABLE lenByteString #-}

-- | Write a 'BS.ByteString' with no length prefix.
putByteString :: B.Builder b => BS.ByteString -> b
putByteString = B.byteString
{-# INLINEABLE putByteString #-}

-- | Read exactly the given number of bytes (no length prefix).
getByteString :: Word32 -> Get BS.ByteString
getByteString = G.getByteString . fromIntegral
{-# INLINEABLE getByteString #-}

-- | Consume all of the remaining input.
getRemainingByteString :: Get BS.ByteString
getRemainingByteString = G.remaining >>= G.getBytes
{-# INLINEABLE getRemainingByteString #-}

-- | Write an SSH @string@: a big-endian 32-bit length followed by the bytes.
putString :: (B.Builder b, BA.ByteArrayAccess ba) => ba -> b
putString ba = B.word32BE (fromIntegral $ BA.length ba) <> putBytes ba
{-# INLINEABLE putString #-}

-- | 'putString' specialised to 'SBS.ShortByteString'.
putShortString :: B.Builder b => SBS.ShortByteString -> b
putShortString bs =
    B.word32BE (fromIntegral $ SBS.length bs) <> B.shortByteString bs
{-# INLINEABLE putShortString #-}

-- | 'getString' returning a 'SBS.ShortByteString'.
getShortString :: Get SBS.ShortByteString
getShortString = SBS.toShort <$> getString
{-# INLINEABLE getShortString #-}

-- | Read an SSH @string@: a big-endian 32-bit length followed by that many
-- bytes.
getString :: BA.ByteArray ba => Get ba
getString = getWord32 >>= getBytes
{-# INLINEABLE getString #-}

-- | Read a 'Name' encoded as an SSH @string@.
getName :: Get Name
getName = Name <$> getShortString
{-# INLINEABLE getName #-}

-- | Write a 'Name' as an SSH @string@.
putName :: B.Builder b => Name -> b
putName (Name n) = putShortString n
{-# INLINEABLE putName #-}

-- | Write a boolean as a single byte: 0 for 'False', 1 for 'True'.
putBool :: B.Builder b => Bool -> b
putBool False = putWord8 0
putBool True  = putWord8 1
{-# INLINEABLE putBool #-}

-- | Read a boolean byte.
--
-- 0 is 'False' and every non-zero value is 'True', as RFC 4251 section 5
-- requires ("All non-zero values MUST be interpreted as TRUE"). Values other
-- than 1 were previously rejected.
getBool :: Get Bool
getBool = (/= 0) <$> getWord8
{-# INLINEABLE getBool #-}

-- | Expect a boolean TRUE, i.e. any non-zero byte.
--
-- This is consistent with 'getBool' and stays disjoint from 'getFalse'.
getTrue :: Get ()
getTrue = getWord8 >>= \b ->
    when (b == 0) $
        fail "getTrue: expected a non-zero byte (TRUE), got 0"
{-# INLINEABLE getTrue #-}

-- | Expect a boolean FALSE (the byte 0).
getFalse :: Get ()
getFalse = expectWord8 0
{-# INLINEABLE getFalse #-}

-- | Write a byte array as an SSH @mpint@, treating the bytes as an unsigned
-- big-endian magnitude.
--
-- * Leading zero bytes are dropped.
-- * If the first remaining byte has its high bit set, a 0x00 byte is
--   prepended so the value is not read as negative.
-- * The payload is prefixed with its big-endian 32-bit length.
-- * An empty or all-zero input encodes the value zero as a zero-length
--   string (@00 00 00 00@), as RFC 4251 section 5 requires. Previously
--   nothing at all was written in that case, not even the length prefix.
putAsMPInt :: (B.Builder b, BA.ByteArrayAccess ba) => ba -> b
putAsMPInt ba = skipZeros 0
  where
    baLen = BA.length ba

    -- Skip leading zero bytes, then emit the length prefix and the
    -- first significant byte (sign-padded when needed).
    skipZeros i
        | i >= baLen =
            B.word32BE 0
        | BA.index ba i == 0 =
            skipZeros (i + 1)
        | BA.index ba i >= 128 =
            B.word32BE (fromIntegral $ baLen - i + 1)
                <> putWord8 0
                <> putWord8 (BA.index ba i)
                <> rest (i + 1)
        | otherwise =
            B.word32BE (fromIntegral $ baLen - i)
                <> putWord8 (BA.index ba i)
                <> rest (i + 1)

    -- Copy the remaining bytes verbatim.
    rest i
        | i >= baLen = mempty
        | otherwise  = putWord8 (BA.index ba i) <> rest (i + 1)
{-# INLINEABLE putAsMPInt #-}