{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      : Simplex.Messaging.Transport
-- Copyright   : (c) simplex.chat
-- License     : AGPL-3
--
-- Maintainer  : chat@simplex.chat
-- Stability   : experimental
-- Portability : non-portable
--
-- This module defines a basic TCP server and client, and the SMP protocol encrypted transport over TCP.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
module Simplex.Messaging.Transport
  ( -- * Transport connection class
    Transport (..),
    TProxy (..),
    ATransport (..),

    -- * Transport over TCP
    runTransportServer,
    runTransportClient,

    -- * TCP transport
    TCP (..),

    -- * SMP encrypted transport
    THandle (..),
    TransportError (..),
    serverHandshake,
    clientHandshake,
    tPutEncrypted,
    tGetEncrypted,
    serializeTransportError,
    transportErrorP,

    -- * Trim trailing CR
    trimCR,
  )
where

import Control.Applicative ((<|>))
import Control.Monad.Except
import Control.Monad.IO.Unlift
import Control.Monad.Trans.Except (throwE)
import Crypto.Cipher.Types (AuthTag)
import Data.Attoparsec.ByteString.Char8 (Parser)
import qualified Data.Attoparsec.ByteString.Char8 as A
import Data.Bifunctor (first)
import Data.ByteArray (xor)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import Data.Functor (($>))
import Data.Set (Set)
import qualified Data.Set as S
import Data.Word (Word32)
import GHC.Generics (Generic)
import GHC.IO.Exception (IOErrorType (..))
import GHC.IO.Handle.Internals (ioe_EOF)
import Generic.Random (genericArbitraryU)
import Network.Socket
import Network.Transport.Internal (decodeNum16, decodeNum32, encodeEnum16, encodeEnum32, encodeWord32)
import qualified Simplex.Messaging.Crypto as C
import Simplex.Messaging.Parsers (parse, parseAll, parseRead1)
import Simplex.Messaging.Util (bshow, liftError)
import System.IO
import System.IO.Error
import Test.QuickCheck (Arbitrary (..))
import UnliftIO.Concurrent
import UnliftIO.Exception (Exception, IOException)
import qualified UnliftIO.Exception as E
import UnliftIO.STM

-- * Transport connection class

-- | A connection type that can carry the SMP protocol.
class Transport c where
  -- | This transport type packed existentially.
  transport :: ATransport
  transport = ATransport (TProxy :: TProxy c)

  -- | Name of the transport.
  transportName :: TProxy c -> String

  -- | Upgrade client socket to connection (used in the server)
  getServerConnection :: Socket -> IO c

  -- | Upgrade server socket to connection (used in the client)
  getClientConnection :: Socket -> IO c

  -- | Close connection
  closeConnection :: c -> IO ()

  -- | Read fixed number of bytes from connection
  cGet :: c -> Int -> IO ByteString

  -- | Write bytes to connection
  cPut :: c -> ByteString -> IO ()

  -- | Receive ByteString from connection, allowing LF or CRLF termination.
  getLn :: c -> IO ByteString

  -- | Send ByteString to connection terminating it with CRLF.
  putLn :: c -> ByteString -> IO ()
  putLn conn line = cPut conn (line <> "\r\n")

-- | Value-less tag selecting a 'Transport' instance by its type.
data TProxy c where
  TProxy :: TProxy c

-- | A 'Transport' instance packed existentially, so the transport can be chosen at runtime.
data ATransport where
  ATransport :: Transport c => TProxy c -> ATransport

-- * Transport over TCP

-- | Run transport server (plain TCP or WebSockets) on the passed TCP port and signal via the
-- passed TMVar when the server has started ('True') and when it has stopped ('False').
--
-- Each accepted connection is passed to the passed function in its own thread; the
-- connection is closed when that function returns or throws. Closing the server kills
-- all client threads that are still running.
runTransportServer :: (Transport c, MonadUnliftIO m) => TMVar Bool -> ServiceName -> (c -> m ()) -> m ()
runTransportServer started port server = do
  -- ThreadIds of the client threads that are still running
  clients <- newTVarIO S.empty
  E.bracket (liftIO $ startTCPServer started port) (liftIO . closeServer clients) \sock -> forever $ do
    c <- liftIO $ acceptConnection sock
    registered <- newEmptyTMVarIO
    -- Masked so that no asynchronous exception can arrive between forking the client
    -- thread and recording its ThreadId (otherwise closeServer would not kill it).
    -- The client action itself runs with the caller's masking state via 'restore'.
    E.mask \restore -> do
      tid <- forkFinally (restore $ server c) (const . liftIO $ closeClient clients registered c)
      atomically $ modifyTVar' clients (S.insert tid) >> putTMVar registered ()
  where
    -- Runs in the client thread when it finishes: closes the connection and removes the
    -- thread from the set, so finished threads are not retained for the server's lifetime.
    closeClient :: Transport c => TVar (Set ThreadId) -> TMVar () -> c -> IO ()
    closeClient clients registered c =
      closeConnection c `E.finally` do
        tid <- myThreadId
        -- wait until the accept loop has inserted this ThreadId, otherwise a client that
        -- finishes very quickly would be inserted after its own removal and stay in the set
        atomically $ takeTMVar registered >> modifyTVar' clients (S.delete tid)
    closeServer :: TVar (Set ThreadId) -> Socket -> IO ()
    closeServer clients sock = do
      readTVarIO clients >>= mapM_ killThread
      close sock
      void . atomically $ tryPutTMVar started False
    acceptConnection :: Transport c => Socket -> IO c
    acceptConnection sock = accept sock >>= getServerConnection . fst

-- | Open a socket listening on the passed port (all local addresses, as resolved by
-- 'getAddrInfo' with 'AI_PASSIVE'), then put 'True' into the passed TMVar.
--
-- The socket is closed if any setup step fails, and an 'IOError' is thrown if the
-- port resolves to no address.
startTCPServer :: TMVar Bool -> ServiceName -> IO Socket
startTCPServer started port = withSocketsDo $ resolve >>= open >>= setStarted
  where
    resolve :: IO AddrInfo
    resolve = do
      let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream}
      getAddrInfo (Just hints) Nothing (Just port) >>= \case
        addr : _ -> pure addr
        -- total alternative to `head`: report a missing address as an IOError
        [] -> E.throwIO $ mkIOError NoSuchThing "no address" Nothing Nothing
    open :: AddrInfo -> IO Socket
    open addr =
      -- close the socket if any setup step throws, instead of leaking its descriptor
      E.bracketOnError (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)) close \sock -> do
        setSocketOption sock ReuseAddr 1
        withFdSocket sock setCloseOnExecIfNeeded
        bind sock $ addrAddress addr
        listen sock 1024
        pure sock
    setStarted :: Socket -> IO Socket
    setStarted sock = atomically (tryPutTMVar started True) >> pure sock

-- | Connect to the passed TCP host:port and pass the connection to the client action.
--
-- The connection is closed when the action returns or throws.
runTransportClient :: Transport c => MonadUnliftIO m => HostName -> ServiceName -> (c -> m a) -> m a
runTransportClient host port client =
  -- bracket (instead of acquire-then-finally) leaves no window in which an
  -- asynchronous exception could leak the freshly opened connection
  E.bracket (liftIO $ startTCPClient host port) (liftIO . closeConnection) client

-- | Connect to the passed host and port and upgrade the socket with 'getClientConnection'.
--
-- The resolved addresses are tried in order. If all of them fail, the last connection
-- error is rethrown, or a "no address" 'IOError' when nothing was resolved.
startTCPClient :: forall c. Transport c => HostName -> ServiceName -> IO c
startTCPClient host port = withSocketsDo $ resolve >>= tryOpen err
  where
    err :: IOException
    err = mkIOError NoSuchThing "no address" Nothing Nothing

    resolve :: IO [AddrInfo]
    resolve =
      let hints = defaultHints {addrSocketType = Stream}
       in getAddrInfo (Just hints) (Just host) (Just port)

    -- the first argument is the error to rethrow when no addresses remain
    tryOpen :: IOException -> [AddrInfo] -> IO c
    tryOpen e [] = E.throwIO e
    tryOpen _ (addr : as) = E.try (open addr) >>= either (`tryOpen` as) pure

    open :: AddrInfo -> IO c
    open addr = do
      -- close the socket when connect fails, so each failed attempt does not leak a
      -- descriptor before the next address is tried
      sock <- E.bracketOnError (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr)) close \s ->
        s <$ connect s (addrAddress addr)
      getClientConnection sock

-- * TCP transport

-- | Plain TCP transport: a wrapper around the socket's 'Handle'.
newtype TCP = TCP
  { -- | handle created from the connected socket
    tcpHandle :: Handle
  }

-- | All operations work directly on the wrapped 'Handle'; client and server sockets
-- are upgraded the same way.
instance Transport TCP where
  transportName _ = "TCP"
  getServerConnection sock = TCP <$> getSocketHandle sock
  getClientConnection = getServerConnection
  closeConnection (TCP h) = hClose h
  cGet (TCP h) = B.hGet h
  cPut (TCP h) = B.hPut h
  -- the ByteString line is returned without a trailing CR, so both LF and CRLF work
  getLn (TCP h) = trimCR <$> B.hGetLine h

-- | Convert a connected socket into a read/write 'Handle' in binary mode, with CRLF
-- newline mode and line buffering.
getSocketHandle :: Socket -> IO Handle
getSocketHandle conn = do
  h <- socketToHandle conn ReadWriteMode
  -- the handle settings are applied in this order
  sequence_
    [ hSetBinaryMode h True,
      hSetNewlineMode h NewlineMode {inputNL = CRLF, outputNL = CRLF},
      hSetBuffering h LineBuffering
    ]
  pure h

-- | Drop a single trailing carriage return (@\\r@), if there is one.
-- Any other input, including the empty string, is returned unchanged.
trimCR :: ByteString -> ByteString
trimCR s = case B.unsnoc s of
  Just (rest, '\r') -> rest
  _ -> s

-- * SMP encrypted transport

-- | SMP protocol version made of four numeric components.
-- The derived 'Ord' compares the components left to right.
data SMPVersion = SMPVersion Int Int Int Int
  deriving (Eq, Ord)

-- | The first two components of a version.
major :: SMPVersion -> (Int, Int)
major v = case v of
  SMPVersion x y _ _ -> (x, y)

-- | The current SMP version, 0.3.2.0.
currentSMPVersion :: SMPVersion
currentSMPVersion = SMPVersion 0 3 2 0

-- | Render a version as its four components, each rendered with 'bshow', joined by dots.
serializeSMPVersion :: SMPVersion -> ByteString
serializeSMPVersion (SMPVersion a b c d) = B.intercalate "." $ map bshow [a, b, c, d]

-- | Parse a version written as four dot-separated decimals, e.g. @0.3.2.0@.
smpVersionP :: Parser SMPVersion
smpVersionP = do
  a <- component
  b <- component
  c <- component
  SMPVersion a b c <$> A.decimal
  where
    -- a decimal followed by its separating dot
    component :: Parser Int
    component = A.decimal <* A.char '.'

-- | The handle for SMP encrypted transport connection over 'Transport'.
data THandle c = THandle
  { -- | underlying transport connection
    connection :: c,
    -- | key used to encrypt outgoing blocks
    sndKey :: SessionKey,
    -- | key used to decrypt incoming blocks
    rcvKey :: SessionKey,
    -- | number of bytes in one transport block (auth tag included)
    blockSize :: Int
  }

-- | Symmetric key material for one direction of the encrypted transport.
data SessionKey = SessionKey
  { -- | AES key
    aesKey :: C.Key,
    -- | base initialization vector
    baseIV :: C.IV,
    -- | mutable counter; NOTE(review): presumably combined with baseIV per block — confirm in encryptBlock/decryptBlock
    counter :: TVar Word32
  }

-- | Client handshake data: the block size and the two session keys.
data ClientHandshake = ClientHandshake
  { -- | transport block size
    blockSize :: Int,
    -- | session key for the sending direction
    sndKey :: SessionKey,
    -- | session key for the receiving direction
    rcvKey :: SessionKey
  }

-- | Error of SMP encrypted transport over TCP.
data TransportError
  = -- | error parsing transport block
    TEBadBlock
  | -- | block encryption error
    TEEncrypt
  | -- | block decryption error
    TEDecrypt
  | -- | transport handshake error
    TEHandshake HandshakeError
  deriving
    ( Eq,
      Generic,
      Read,
      Show,
      -- default-method instance, derived via DeriveAnyClass
      Exception
    )

-- | Transport handshake error.
data HandshakeError
  = -- | encryption error
    ENCRYPT
  | -- | decryption error
    DECRYPT
  | -- | error parsing protocol version
    VERSION
  | -- | error parsing RSA key
    RSA_KEY
  | -- | error parsing server transport header or invalid block size
    HEADER
  | -- | error parsing AES keys
    AES_KEYS
  | -- | not matching RSA key hash
    BAD_HASH
  | -- | lower major agent version than protocol version
    MAJOR_VERSION
  | -- | TCP transport terminated
    TERMINATED
  deriving
    ( Eq,
      Generic,
      Read,
      Show,
      -- default-method instance, derived via DeriveAnyClass
      Exception
    )

instance Arbitrary TransportError where
  -- generic generator with uniform constructor weights (generic-random)
  arbitrary = genericArbitraryU

instance Arbitrary HandshakeError where
  -- generic generator with uniform constructor weights (generic-random)
  arbitrary = genericArbitraryU

-- | SMP encrypted transport error parser; the inverse of 'serializeTransportError'.
-- The fixed keywords are tried first; anything else is parsed as a 'HandshakeError'
-- via its 'Read' instance.
transportErrorP :: Parser TransportError
transportErrorP =
  (TEBadBlock <$ A.string "BLOCK")
    <|> (TEEncrypt <$ A.string "AES_ENCRYPT")
    <|> (TEDecrypt <$ A.string "AES_DECRYPT")
    <|> (TEHandshake <$> parseRead1)

-- | Serialize SMP encrypted transport error. Handshake errors are rendered with 'bshow',
-- i.e. as their constructor name.
serializeTransportError :: TransportError -> ByteString
serializeTransportError TEBadBlock = "BLOCK"
serializeTransportError TEEncrypt = "AES_ENCRYPT"
serializeTransportError TEDecrypt = "AES_DECRYPT"
serializeTransportError (TEHandshake e) = bshow e

-- | Encrypt and send block to SMP encrypted transport.
--
-- The block is encrypted with the send key to @blockSize - C.authTagSize@ bytes, so that
-- the authentication tag, sent before the ciphertext, fills the rest of the block.
-- Returns @Left TEEncrypt@ (and sends nothing) if encryption fails.
tPutEncrypted :: Transport c => THandle c -> ByteString -> IO (Either TransportError ())
tPutEncrypted THandle {connection = conn, sndKey, blockSize} block = do
  encrypted <- encryptBlock sndKey (blockSize - C.authTagSize) block
  case encrypted of
    Left _ -> pure (Left TEEncrypt)
    Right (authTag, msg) -> Right <$> cPut conn (C.authTagToBS authTag <> msg)

-- | Receive one block from the connection and decrypt it with the receive key.
--
-- Returns @Left TEDecrypt@ when decryption fails. Throws the EOF 'IOError'
-- ('ioe_EOF') when the connection is closed, detected either as an empty raw
-- read or as a block that decrypts to an empty message.
tGetEncrypted :: Transport c => THandle c -> IO (Either TransportError ByteString)
tGetEncrypted THandle {connection = c, rcvKey, blockSize} =
  cGet c blockSize >>= \case
    -- Assumes cGet returns an empty string at EOF (as Data.ByteString.hGet
    -- does). Decrypting that would presumably fail the auth-tag check, be
    -- misreported as TEDecrypt, and advance the receive IV counter, so
    -- report EOF before attempting decryption.
    "" -> ioe_EOF
    block ->
      decryptBlock rcvKey block >>= \case
        Left _ -> pure $ Left TEDecrypt
        Right "" -> ioe_EOF
        Right msg -> pure $ Right msg

-- | Encrypt a block under a session key.
--
-- Takes the next IV from the key's counter (via 'makeNextIV'), so every call
-- consumes one counter value even if encryption then fails. @size@ is passed
-- through unchanged to 'C.encryptAES'.
encryptBlock :: SessionKey -> Int -> ByteString -> IO (Either C.CryptoError (AuthTag, ByteString))
encryptBlock sessionKey@SessionKey {aesKey} size block =
  makeNextIV sessionKey >>= \nextIV ->
    runExceptT $ C.encryptAES aesKey nextIV size block

-- | Decrypt a received block under a session key.
--
-- The first 'C.authTagSize' bytes of the block are the authentication tag and
-- the rest is the ciphertext. The IV comes from 'makeNextIV', so every call
-- consumes one counter value even if decryption then fails.
decryptBlock :: SessionKey -> ByteString -> IO (Either C.CryptoError ByteString)
decryptBlock sessionKey@SessionKey {aesKey} block = do
  nextIV <- makeNextIV sessionKey
  runExceptT $ C.decryptAES aesKey nextIV cipherText (C.bsToAuthTag tagBytes)
  where
    (tagBytes, cipherText) = B.splitAt C.authTagSize block

-- | Return the IV for the next block under this session key and advance the
-- key's block counter.
--
-- Reading and incrementing the counter happen in one STM transaction, so
-- concurrent callers never obtain the same counter value (until it wraps).
-- The IV is 'baseIV' with its first 4 bytes XOR-ed with the counter encoded by
-- 'encodeWord32'; the remaining bytes of 'baseIV' are kept as they are.
--
-- NOTE(review): the counter is a 'Word32' and wraps after 2^32 blocks, after
-- which IVs repeat under the same key. If 'C.encryptAES' is an AEAD mode such
-- as GCM, the session should be re-keyed before that point.
makeNextIV :: SessionKey -> IO C.IV
makeNextIV SessionKey {baseIV, counter} = atomically $ do
  c <- readTVar counter
  -- Write the incremented value strictly so the shared TVar never holds a thunk.
  writeTVar counter $! c + 1
  pure $ iv c
  where
    (start, rest) = B.splitAt 4 $ C.unIV baseIV
    iv c = C.IV $ (start `xor` encodeWord32 c) <> rest

-- | Server side of the SMP encrypted transport handshake.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
--
-- The numbers in helper names refer to the steps in that document. The server
-- sends its header and public key (1), reads the client's keys encrypted with
-- that public key (4), decrypts and parses them (5), and sends an encrypted
-- welcome block carrying 'currentSMPVersion' (6).
serverHandshake :: forall c. Transport c => c -> C.FullKeyPair -> ExceptT TransportError IO (THandle c)
serverHandshake c (serverPubKey, serverPrivKey) = do
  liftIO sendHeaderAndPublicKey_1
  -- TODO server currently ignores blockSize returned by the client
  -- this is reserved for future support of streams
  ClientHandshake {sndKey = clientSndKey, rcvKey = clientRcvKey} <-
    receiveEncryptedKeys_4 >>= decryptParseKeys_5
  -- The client's receive key is the server's send key and vice versa, so the
  -- client's keys are passed to transportHandle in swapped order.
  th <- liftIO $ transportHandle c clientRcvKey clientSndKey transportBlockSize
  sendWelcome_6 th
  pure th
  where
    sendHeaderAndPublicKey_1 :: IO ()
    sendHeaderAndPublicKey_1 =
      let encodedKey = C.encodePubKey serverPubKey
          header = ServerHeader {blockSize = transportBlockSize, keySize = B.length encodedKey}
       in cPut c (binaryServerHeader header) >> cPut c encodedKey

    -- An empty read means the client terminated the handshake.
    receiveEncryptedKeys_4 :: ExceptT TransportError IO ByteString
    receiveEncryptedKeys_4 = do
      encKeys <- liftIO . cGet c $ C.publicKeySize serverPubKey
      if B.null encKeys
        then throwE $ TEHandshake TERMINATED
        else pure encKeys

    decryptParseKeys_5 :: ByteString -> ExceptT TransportError IO ClientHandshake
    decryptParseKeys_5 encKeys = do
      plainKeys <- liftError (const $ TEHandshake DECRYPT) (C.decryptOAEP serverPrivKey encKeys)
      liftEither $ parseClientHandshake plainKeys

    sendWelcome_6 :: THandle c -> ExceptT TransportError IO ()
    sendWelcome_6 th =
      ExceptT $ tPutEncrypted th (serializeSMPVersion currentSMPVersion <> " ")

-- | Client side of the SMP encrypted transport handshake.
--
-- See https://github.com/simplex-chat/simplexmq/blob/master/protocol/simplex-messaging.md#appendix-a
--
-- The numbers in function names refer to the steps in the document.
-- The client does the following:
--
-- 1. Reads the server header and public key, validating the key against
--    @keyHash@ when one is given.
-- 2. Generates fresh send and receive session keys.
-- 3. Sends those keys encrypted with the server's public key.
-- 4. Reads the encrypted welcome block, rejecting a server whose major SMP
--    version is newer than 'currentSMPVersion'.
clientHandshake :: forall c. Transport c => c -> Maybe C.KeyHash -> ExceptT TransportError IO (THandle c)
clientHandshake c keyHash = do
  (k, blkSize) <- getHeaderAndPublicKey_1_2
  -- TODO currently client always uses the blkSize returned by the server
  keys@ClientHandshake {sndKey, rcvKey} <- liftIO $ generateKeys_3 blkSize
  sendEncryptedKeys_4 k keys
  th <- liftIO $ transportHandle c sndKey rcvKey blkSize
  getWelcome_6 th >>= checkVersion
  pure th
  where
    -- Read the fixed-size header, reject a block size outside
    -- [transportBlockSize, maxTransportBlockSize], read the key bytes,
    -- optionally check their hash, then parse the public key.
    getHeaderAndPublicKey_1_2 :: ExceptT TransportError IO (C.PublicKey, Int)
    getHeaderAndPublicKey_1_2 = do
      header <- liftIO (cGet c serverHeaderSize)
      ServerHeader {blockSize, keySize} <- liftEither $ parse serverHeaderP (TEHandshake HEADER) header
      when (blockSize < transportBlockSize || blockSize > maxTransportBlockSize) $
        throwError $ TEHandshake HEADER
      s <- liftIO $ cGet c keySize
      maybe (pure ()) (validateKeyHash_2 s) keyHash
      key <- liftEither $ parseKey s
      pure (key, blockSize)
    parseKey :: ByteString -> Either TransportError C.PublicKey
    parseKey = first (const $ TEHandshake RSA_KEY) . parseAll C.binaryPubKeyP
    -- The hash is computed over the raw key bytes as received.
    validateKeyHash_2 :: ByteString -> C.KeyHash -> ExceptT TransportError IO ()
    validateKeyHash_2 k (C.KeyHash kHash)
      | C.sha256Hash k == kHash = pure ()
      | otherwise = throwE $ TEHandshake BAD_HASH
    generateKeys_3 :: Int -> IO ClientHandshake
    generateKeys_3 blkSize = ClientHandshake blkSize <$> generateKey <*> generateKey
    generateKey :: IO SessionKey
    generateKey = do
      aesKey <- C.randomAesKey
      baseIV <- C.randomIV
      -- The counter is not serialized in the handshake. This is only a
      -- placeholder that must never be forced; transportHandle creates the
      -- real counter TVars. A descriptive 'error' replaces a bare 'undefined'
      -- so any misuse is diagnosable.
      pure SessionKey {aesKey, baseIV, counter = error "SessionKey counter placeholder: the counter is created by transportHandle"}
    sendEncryptedKeys_4 :: C.PublicKey -> ClientHandshake -> ExceptT TransportError IO ()
    sendEncryptedKeys_4 k keys =
      liftError (const $ TEHandshake ENCRYPT) (C.encryptOAEP k $ serializeClientHandshake keys)
        >>= liftIO . cPut c
    getWelcome_6 :: THandle c -> ExceptT TransportError IO SMPVersion
    getWelcome_6 th = ExceptT $ (>>= parseSMPVersion) <$> tGetEncrypted th
    -- The welcome block is the serialized version followed by a space.
    parseSMPVersion :: ByteString -> Either TransportError SMPVersion
    parseSMPVersion = first (const $ TEHandshake VERSION) . A.parseOnly (smpVersionP <* A.space)
    checkVersion :: SMPVersion -> ExceptT TransportError IO ()
    checkVersion smpVersion =
      when (major smpVersion > major currentSMPVersion) . throwE $
        TEHandshake MAJOR_VERSION

-- | Header the server sends at the very start of the handshake.
data ServerHeader = ServerHeader
  { -- | block size proposed by the server; the client only accepts values in
    -- ['transportBlockSize', 'maxTransportBlockSize']
    blockSize :: Int,
    -- | length in bytes of the encoded public key sent right after the header
    keySize :: Int
  }
  deriving (Show, Eq)

-- | Transport-mode tag that follows the block size in both the server header
-- and the client handshake. It is the only mode 'binaryRsaTransportP' accepts.
binaryRsaTransport :: Int
binaryRsaTransport = 0

-- | Block size (4 KiB) the server announces. It is also the smallest block
-- size a client accepts.
transportBlockSize :: Int
transportBlockSize = 4 * 1024

-- | Largest block size (64 KiB) a client accepts from the server header.
maxTransportBlockSize :: Int
maxTransportBlockSize = 64 * 1024

-- | Size in bytes of the serialized 'ServerHeader'. It is made up of the block
-- size (4 bytes), the transport mode (2 bytes) and the key size (2 bytes),
-- matching what 'serverHeaderP' consumes.
serverHeaderSize :: Int
serverHeaderSize = 4 + 2 + 2

-- | Serialize a 'ServerHeader' for the wire, in this order:
--
-- * block size, via 'encodeEnum32'
-- * transport mode 'binaryRsaTransport', via 'encodeEnum16'
-- * key size, via 'encodeEnum16'
binaryServerHeader :: ServerHeader -> ByteString
binaryServerHeader ServerHeader {blockSize = blkSize, keySize = kSize} =
  mconcat
    [ encodeEnum32 blkSize,
      encodeEnum16 binaryRsaTransport,
      encodeEnum16 kSize
    ]

-- | Parse a 'ServerHeader', the inverse of 'binaryServerHeader'. It reads a
-- 4-byte block size, requires the 'binaryRsaTransport' mode tag, and then
-- reads a 2-byte key size.
serverHeaderP :: Parser ServerHeader
serverHeaderP = do
  blkSize <- int32
  binaryRsaTransportP
  ServerHeader blkSize <$> int16

-- | Serialize the client's handshake for encryption with the server's public
-- key. The fields are written in this order:
--
-- * block size (32-bit)
-- * transport mode (16-bit)
-- * send key as AES key bytes followed by base IV bytes
-- * receive key in the same layout
--
-- Each key's counter is not serialized.
serializeClientHandshake :: ClientHandshake -> ByteString
serializeClientHandshake ClientHandshake {blockSize, sndKey, rcvKey} =
  B.concat
    [ encodeEnum32 blockSize,
      encodeEnum16 binaryRsaTransport,
      keyBytes sndKey,
      keyBytes rcvKey
    ]
  where
    keyBytes :: SessionKey -> ByteString
    keyBytes SessionKey {aesKey = key, baseIV = iv} = C.unKey key <> C.unIV iv

-- | Parse a decrypted client handshake, the inverse of
-- 'serializeClientHandshake'. It reads, in order:
--
-- * a 4-byte block size
-- * the 'binaryRsaTransport' mode tag
-- * the client's send key
-- * the client's receive key
--
-- Each key is an AES key followed by a base IV.
clientHandshakeP :: Parser ClientHandshake
clientHandshakeP = ClientHandshake <$> int32 <* binaryRsaTransportP <*> keyP <*> keyP
  where
    keyP :: Parser SessionKey
    keyP = do
      aesKey <- C.aesKeyP
      baseIV <- C.ivP
      -- The counter is not part of the wire format. This is only a placeholder
      -- that must never be forced; transportHandle creates the real counter
      -- TVars. A descriptive 'error' replaces a bare 'undefined' so any misuse
      -- is diagnosable.
      pure SessionKey {aesKey, baseIV, counter = error "SessionKey counter placeholder: the counter is created by transportHandle"}

-- | Consume exactly 4 bytes and decode them with 'decodeNum32'.
int32 :: Parser Int
int32 = do
  bytes <- A.take 4
  pure $ decodeNum32 bytes

-- | Consume exactly two bytes and decode them as an 'Int' with 'decodeNum16'.
int16 :: Parser Int
int16 = do
  raw <- A.take 2
  pure (decodeNum16 raw)

-- | Read the 2-byte transport mode marker and accept it only when it equals
-- 'binaryRsaTransport'. Any other value fails the parser with
-- "unknown transport mode".
binaryRsaTransportP :: Parser ()
binaryRsaTransportP = do
  mode <- int16
  if mode == binaryRsaTransport
    then pure ()
    else fail "unknown transport mode"

-- | Decode a raw client handshake block with 'clientHandshakeP'. Any parse
-- failure is reported as @TEHandshake AES_KEYS@.
parseClientHandshake :: ByteString -> Either TransportError ClientHandshake
parseClientHandshake = parse clientHandshakeP onFailure
  where
    onFailure = TEHandshake AES_KEYS

-- | Build a 'THandle' around a connection.
--
-- Two fresh counters, both starting at 0, are allocated. The send counter is
-- allocated first, then the receive counter. Each counter is installed into a
-- copy of the corresponding session key, so the send key and the receive key
-- track their counters independently.
transportHandle :: c -> SessionKey -> SessionKey -> Int -> IO (THandle c)
transportHandle c sk rk blockSize = assemble <$> newTVarIO 0 <*> newTVarIO 0
  where
    assemble sndCounter rcvCounter =
      THandle
        { connection = c,
          sndKey = sk {counter = sndCounter},
          rcvKey = rk {counter = rcvCounter},
          blockSize
        }