{-# LANGUAGE DerivingVia, DeriveGeneric, TypeApplications, ExistentialQuantification #-}
-- | Basic client/server round-trip tests for the Curryer RPC library.
--
-- Every test starts a server on port 0 of 'localHostAddr', waits for the
-- server to publish the address it bound to, connects a client, runs its
-- assertions, and then tears both ends down again (also when an assertion
-- fails).
module Curryer.Test.Basic where

import Test.Tasty
import Test.Tasty.HUnit
import Codec.Winery
import GHC.Generics
import Control.Concurrent.MVar
import Network.Socket (PortNumber, SockAddr(..))
import Control.Concurrent.Async
import Control.Monad
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import Data.List
import Network.RPC.Curryer.Server
import Network.RPC.Curryer.Client

-- TODO: add test for nested calls

-- | All test cases of this module.
testTree :: TestTree
testTree =
  testGroup
    "basic"
    [ testCase "simple" testSimpleCall
    , testCase "client async" testAsyncServerCall
    , testCase "server async" testAsyncClientCall
    , testCase "client sync timeout" testSyncClientCallTimeout
    , testCase "server-side exception" testSyncException
    , testCase "multi-threaded client" testMultithreadedClient
    , testCase "server state" testServerState
    , testCase "request handler throws timeout" testRequestHandlerThrowTimeout
    ]

-- | Request: add the two numbers; the handler replies with their sum.
data AddTwoNumbersReq = AddTwoNumbersReq Int Int
  deriving (Generic, Show)
  deriving Serialise via WineryVariant AddTwoNumbersReq

-- | Request: the handler puts the string into the test's 'MVar', if any.
data TestCallMeBackReq = TestCallMeBackReq String
  deriving (Generic, Show, Eq)
  deriving Serialise via WineryVariant TestCallMeBackReq

-- | Request handled by an 'AsyncRequestHandler' that puts the string into
-- the test's 'MVar', if any.
data TestAsyncReq = TestAsyncReq String
  deriving (Generic, Show)
  deriving Serialise via WineryVariant TestAsyncReq

-- | Request: the handler sleeps for the given number of microseconds.
data DelayMicrosecondsReq = DelayMicrosecondsReq Int
  deriving (Generic, Show)
  deriving Serialise via WineryVariant DelayMicrosecondsReq

-- | Request: the handler echoes the string back.
data RoundtripStringReq = RoundtripStringReq String
  deriving (Generic, Show)
  deriving Serialise via WineryVariant RoundtripStringReq

-- | Request: the handler calls 'error'.
data ThrowServerSideExceptionReq = ThrowServerSideExceptionReq
  deriving (Generic, Show)
  deriving Serialise via WineryVariant ThrowServerSideExceptionReq

-- | Request: the handler increments the server state (see 'testServerState').
data ChangeServerState = ChangeServerState
  deriving (Generic, Show)
  deriving Serialise via WineryVariant ChangeServerState

-- | Used for the server -> client async request: the server answers an
-- async hello from the client by sending an async hello back.
data AsyncHelloReq = AsyncHelloReq String
  deriving (Generic, Show)
  deriving Serialise via WineryVariant AsyncHelloReq

-- | Request: the handler throws @TimeoutException@
-- (see 'testRequestHandlerThrowTimeout').
data ThrowTimeout = ThrowTimeout
  deriving (Generic, Show)
  deriving Serialise via WineryVariant ThrowTimeout

-- | The request handlers shared by most tests.
--
-- The optional 'MVar' is how the handlers for 'TestCallMeBackReq' and
-- 'TestAsyncReq' report the string they received back to the test; with
-- 'Nothing' those handlers do nothing.
testServerRequestHandlers :: Maybe (MVar String) -> RequestHandlers ()
testServerRequestHandlers mAsyncMVar =
  [ RequestHandler $ \_ (AddTwoNumbersReq x y) -> pure (x + y)
  , RequestHandler $ \_ (TestCallMeBackReq s) ->
      maybe (pure ()) (\mvar -> putMVar mvar s) mAsyncMVar
  , AsyncRequestHandler $ \_ (TestAsyncReq v) ->
      maybe (pure ()) (\mvar -> putMVar mvar v) mAsyncMVar
    -- an async hello to the server generates an async hello to the client
  , AsyncRequestHandler $ \sState (AsyncHelloReq s) ->
      sendMessage (connectionSocket sState) (AsyncHelloReq s)
  , RequestHandler $ \_ (RoundtripStringReq s) -> pure s
  , RequestHandler $ \_ (DelayMicrosecondsReq ms) -> threadDelay ms
  , RequestHandler $ \_ ThrowServerSideExceptionReq -> do
      _ <- error "test server exception"
      pure ()
  ]

-- | The server state used by tests that do not care about server state.
emptyServerState :: ()
emptyServerState = ()

-- | Run a server in a background thread for the duration of an action.
--
-- @startServer@ receives the 'MVar' in which the server announces the
-- address it bound to (the last argument of @serve@). Once the address is
-- available, the action is run with the server's port. The server thread is
-- cancelled when the action finishes, whether it returns or throws, so a
-- failing assertion no longer leaves a server running.
--
-- Fails the test if the announced address is not a 'SockAddrInet'.
withServerPort :: (Maybe (MVar SockAddr) -> IO a) -> (PortNumber -> IO b) -> IO b
withServerPort startServer action = do
  readyVar <- newEmptyMVar
  withAsync (startServer (Just readyVar)) $ \_ -> do
    -- wait for server to be ready
    addr <- takeMVar readyVar
    case addr of
      SockAddrInet port _ -> action port
      other -> assertFailure ("unexpected server address: " <> show other)

-- | 'withServerPort' specialised to 'testServerRequestHandlers' and
-- 'emptyServerState', listening on port 0 of 'localHostAddr'.
withDefaultServer :: Maybe (MVar String) -> (PortNumber -> IO b) -> IO b
withDefaultServer mAsyncMVar =
  withServerPort (serve (testServerRequestHandlers mAsyncMVar) emptyServerState localHostAddr 0)

-- | Test a simple client-to-server round-trip function execution.
testSimpleCall :: Assertion
testSimpleCall =
  withDefaultServer Nothing $ \port ->
    bracket (connect [] localHostAddr port) close $ \conn ->
      -- make five AddTwo calls in a row on the same connection
      replicateM_ 5 $ do
        x <- call conn (AddTwoNumbersReq 1 1)
        assertEqual "server request+response" (Right (2 :: Int)) x

-- | Test that the client can process a server-initiated asynchronous
-- callback: the client sends an async 'AsyncHelloReq', the server sends one
-- back, and the client-side handler hands the string to the test.
testAsyncServerCall :: Assertion
testAsyncServerCall = do
  receivedAsyncMessageVar <- newEmptyMVar
  let clientAsyncHandlers =
        [ClientAsyncRequestHandler (\(AsyncHelloReq s) -> putMVar receivedAsyncMessageVar s)]
  withDefaultServer (Just receivedAsyncMessageVar) $ \port ->
    bracket (connect clientAsyncHandlers localHostAddr port) close $ \conn -> do
      Right () <- asyncCall conn (AsyncHelloReq "welcome")
      asyncMessage <- takeMVar receivedAsyncMessageVar
      assertEqual "async message" "welcome" asyncMessage

-- | Test that the client can make a non-blocking call.
--
-- NOTE(review): this sends 'TestCallMeBackReq', which is registered with
-- @RequestHandler@, while 'TestAsyncReq' (registered with
-- @AsyncRequestHandler@) is not used by any test in this module -- confirm
-- which request was intended here.
testAsyncClientCall :: Assertion
testAsyncClientCall = do
  receivedAsyncMessageVar <- newEmptyMVar
  withDefaultServer (Just receivedAsyncMessageVar) $ \port ->
    bracket (connect [] localHostAddr port) close $ \conn -> do
      -- send an async message, then wait for the server-side handler to
      -- confirm receipt through the MVar
      Right () <- asyncCall conn (TestCallMeBackReq "hi server")
      asyncMessage <- takeMVar receivedAsyncMessageVar
      assertEqual "async message" "hi server" asyncMessage

-- | Test that a synchronous call gives up with @TimeoutError@ when the
-- handler takes longer than the client-side timeout.
--
-- The handler sleeps for 1000 microseconds; the timeout argument is 500
-- (presumably also microseconds -- confirm against @callTimeout@).
testSyncClientCallTimeout :: Assertion
testSyncClientCallTimeout =
  withDefaultServer Nothing $ \port ->
    bracket (connect [] localHostAddr port) close $ \conn -> do
      x <- callTimeout @_ @Int (Just 500) conn (DelayMicrosecondsReq 1000)
      assertEqual "client sync timeout" (Left TimeoutError) x

-- | Test that an exception raised inside a request handler is reported to
-- the caller as an @ExceptionError@ carrying the exception's message.
testSyncException :: Assertion
testSyncException =
  withDefaultServer Nothing $ \port ->
    bracket (connect [] localHostAddr port) close $ \conn -> do
      ret <- call conn ThrowServerSideExceptionReq
      case ret of
        Left (ExceptionError actualExc) ->
          assertBool "server-side exception" ("test server exception" `isPrefixOf` actualExc)
        Left otherErr -> assertFailure ("server-side error " <> show otherErr)
        Right () -> assertFailure "missed exception"

-- | Throw large messages (> PIPE_BUF) at the server from multiple client
-- threads sharing one connection to exercise the socket lock.
testMultithreadedClient :: Assertion
testMultithreadedClient =
  withDefaultServer Nothing $ \port ->
    bracket (connect [] localHostAddr port) close $ \conn -> do
      let bigString = replicate (1024 * 1000) 'x'
      -- ten threads issue the call at the same time; an assertion failure
      -- in any of them is rethrown here and fails the test
      replicateConcurrently_ 10 $ do
        ret <- call conn (RoundtripStringReq bigString)
        assertEqual "big string multithread" (Right bigString) ret

-- | Test that a request handler can modify the server state: the state
-- starts at 1 and the 'ChangeServerState' handler increments it.
testServerState :: Assertion
testServerState = do
  let serverHandlers =
        [ RequestHandler
            (\sState ChangeServerState ->
               atomically $ modifyTVar' (connectionServerState sState) (+ 1))
        ]
  serverState <- newTVarIO @Int 1
  withServerPort (serve serverHandlers serverState localHostAddr 0) $ \port ->
    bracket (connect [] localHostAddr port) close $ \conn -> do
      ret <- call conn ChangeServerState
      assertEqual "server ret" (Right ()) ret
      serverState' <- readTVarIO serverState
      assertEqual "server state" 2 serverState'

-- | Test that the request handler can throw a @TimeoutException@, which is
-- converted into a @TimeoutError@ response.
testRequestHandlerThrowTimeout :: Assertion
testRequestHandlerThrowTimeout = do
  let serverHandlers =
        [ RequestHandler
            -- the unreachable 'pure' only fixes the handler's reply type
            (\_ ThrowTimeout -> throwIO TimeoutException >> pure (1 :: Int))
        ]
  withServerPort (serve serverHandlers () localHostAddr 0) $ \port ->
    bracket (connect [] localHostAddr port) close $ \conn -> do
      ret <- call @_ @Int conn ThrowTimeout
      assertEqual "handler timeout exception" (Left TimeoutError) ret