{-# LANGUAGE OverloadedStrings #-} module Main (main) where import qualified Network.JsonRpc.Server as S import Network.JsonRpc.Server ((:+:) (..)) import Internal ( request, defaultRq, defaultRsp , defaultIdErrRsp, nullIdErrRsp , version, result, rpcErr, method , params, id', array, rspToIdString) import qualified TestParallelism import Data.List (sortBy) import qualified Data.Vector as V import Data.Function (on) import qualified Data.Aeson as A import Data.Aeson ((.=)) import qualified Data.Aeson.Types as A import qualified Data.HashMap.Strict as H import Control.Applicative ((<$>)) import Control.Monad.Trans (liftIO) import Control.Monad.State (State, runState, lift, modify) import Control.Monad.Identity (Identity, runIdentity) import Test.HUnit hiding (State, Test) import Test.Framework (defaultMain, Test) import Test.Framework.Providers.HUnit (testCase) import Prelude hiding (subtract) main :: IO () main = defaultMain $ errorHandlingTests ++ otherTests errorHandlingTests :: [Test] errorHandlingTests = [ testCase "invalid JSON" $ assertSubtractResponse (A.String "5") $ nullIdErrRsp (-32700) , testCase "invalid JSON-RPC" $ assertSubtractResponse (A.object ["id" .= A.Number 10]) $ nullIdErrRsp (-32600) , testCase "empty batch call" $ assertSubtractResponse A.emptyArray $ nullIdErrRsp (-32600) , testCase "invalid batch element" $ removeErrMsg <$> callSubtractMethods (array [A.Bool True]) @?= Just (array [nullIdErrRsp (-32600)]) , testCase "wrong request version" $ assertSubtractResponse (defaultRq `version` Just "1.0") $ nullIdErrRsp (-32600) , testCase "wrong id type" $ assertSubtractResponse (defaultRq `id'` (Just $ A.Bool True)) $ nullIdErrRsp (-32600) , testCase "method not found" $ assertSubtractResponse (defaultRq `method` "add") (defaultIdErrRsp (-32601)) , testCase "wrong method name capitalization" $ assertSubtractResponse (defaultRq `method` "Subtract") (defaultIdErrRsp (-32601)) , testCase "missing required named argument" $ assertInvalidParams $ defaultRq `params` Just (A.object ["a" .= A.Number 1, "y" .= A.Number 20]) , testCase "missing required unnamed argument" $ assertInvalidParams $ defaultRq `method` "flipped subtract" `params` Just (array [A.Number 0]) , testCase "wrong argument type" $ assertInvalidParams $ defaultRq `params` Just (A.object ["x" .= A.Number 1, "y" .= A.String "2"]) , testCase "extra unnamed arguments" $ assertInvalidParams $ defaultRq `params` Just (array $ map A.Number [1, 2, 3]) , let req = defaultRq `id'` Nothing `method` "12345" in testCase "invalid notification" $ callSubtractMethods req @?= Nothing ] otherTests :: [Test] otherTests = [ testCase "encode RPC error" $ A.toJSON (S.rpcError (-1) "error") @?= rpcErr Nothing (-1) "error" , let err = S.rpcErrorWithData 1 "my message" errData testError = rpcErr (Just $ A.toJSON errData) 1 "my message" errData = ('\x03BB', [True], ()) in testCase "encode RPC error with data" $ A.toJSON err @?= testError , testCase "batch request" testBatch , testCase "batch notifications" testBatchNotifications , testCase "allow missing version" testAllowMissingVersion , testCase "no arguments" $ assertGetTimeResponse Nothing , testCase "empty argument array" $ assertGetTimeResponse $ Just A.emptyArray , testCase "empty argument A.object" $ assertGetTimeResponse $ Just A.emptyObject , let req = defaultRq `params` Just args args = A.object ["x" .= A.Number 10, "y" .= A.Number 20, "z" .= A.String "extra"] rsp = defaultRsp `result` A.Number (-10) in testCase "allow extra named argument" $ assertSubtractResponse req rsp , let req = defaultRq `params` (Just $ A.object [("x1", A.Number 500), ("x", A.Number 1000)]) rsp = defaultRsp `result` A.Number 1000 in testCase "use default named argument" $ assertSubtractResponse req rsp , let req = defaultRq `params` (Just $ array [A.Number 4]) rsp = defaultRsp `result` A.Number 4 in testCase "use default unnamed argument" $ assertSubtractResponse req rsp , testCase "string request ID" $ assertEqualId $ A.String "ID 5" , testCase "null request ID" $ assertEqualId A.Null , testCase "parallelize tasks" TestParallelism.testParallelizingTasks ] assertSubtractResponse :: A.Value -> A.Value -> Assertion assertSubtractResponse rq expectedRsp = removeErrMsg <$> rsp @?= Just expectedRsp where rsp = callSubtractMethods rq assertEqualId :: A.Value -> Assertion assertEqualId i = assertSubtractResponse (defaultRq `id'` Just i) (defaultRsp `id'` Just i) assertInvalidParams :: A.Value -> Assertion assertInvalidParams req = assertSubtractResponse req (defaultIdErrRsp (-32602)) testBatch :: Assertion testBatch = sortBy (compare `on` rspToIdString) <$> response @?= Just expected where expected = [nullIdErrRsp (-32600), rsp i1 2, rsp i2 4] where rsp i x = defaultRsp `id'` Just i `result` A.Number x response = fromArray =<< (removeErrMsg <$> callSubtractMethods (array requests)) requests = [rq (Just i1) 10 8, rq (Just i2) 24 20, rq Nothing 15 1, defaultRq `version` Just (A.String "abc")] where rq i x y = defaultRq `id'` i `params` toArgs x y toArgs :: Int -> Int -> Maybe A.Value toArgs x y = Just $ A.object ["x" .= x, "y" .= y] i1 = A.Number 1 i2 = A.Number 2 fromArray (A.Array v) = Just $ V.toList v fromArray _ = Nothing testBatchNotifications :: Assertion testBatchNotifications = runState response 0 @?= (Nothing, 10) where response = S.call (S.toMethods [incrementStateMethod]) $ A.encode rq rq = replicate 10 $ request Nothing "increment" Nothing testAllowMissingVersion :: Assertion testAllowMissingVersion = callSubtractMethods requestNoVersion @?= (Just $ defaultRsp `result` A.Number 1) where requestNoVersion = defaultRq `version` Nothing `params` Just (A.object ["x" .= A.Number 1]) incrementStateMethod :: S.Method (State Int) incrementStateMethod = S.toMethod "increment" f () where f :: S.RpcResult (State Int) () f = lift $ modify (+1) assertGetTimeResponse :: Maybe A.Value -> Assertion assertGetTimeResponse args = passed @? "unexpected RPC response" where passed = (expected ==) <$> rsp expected = Just $ defaultRsp `result` A.Number 100 req = defaultRq `method` "get_time_seconds" `params` args rsp = callGetTimeMethod req callSubtractMethods :: A.Value -> Maybe A.Value callSubtractMethods req = let methods :: S.Methods Identity methods = S.toMethods [subtractMethod, flippedSubtractMethod] rsp = S.call methods $ A.encode req in A.decode =<< runIdentity rsp callGetTimeMethod :: A.Value -> IO (Maybe A.Value) callGetTimeMethod req = let methods :: S.Methods IO methods = S.toMethods [getTimeMethod] rsp = S.call methods $ A.encode req in (A.decode =<<) <$> rsp subtractMethod :: S.Method Identity subtractMethod = S.toMethod "subtract" subtract (S.Required "x" :+: S.Optional "y" 0 :+: ()) flippedSubtractMethod :: S.Method Identity flippedSubtractMethod = S.toMethod "flipped subtract" (flip subtract) ps where ps = S.Optional "y" (-1000) :+: S.Required "x" :+: () subtract :: Int -> Int -> S.RpcResult Identity Int subtract x y = return (x - y) getTimeMethod :: S.Method IO getTimeMethod = S.toMethod "get_time_seconds" getTestTime () where getTestTime :: S.RpcResult IO Integer getTestTime = liftIO $ return 100 removeErrMsg :: A.Value -> A.Value removeErrMsg (A.Object rsp) = A.Object $ H.adjust removeMsg "error" rsp where removeMsg (A.Object err) = A.Object $ H.insert "message" "" $ H.delete "data" err removeMsg v = v removeErrMsg (A.Array rsps) = A.Array $ removeErrMsg `V.map` rsps removeErrMsg v = v