{-# Language DeriveDataTypeable #-}

-------------------------------------------------------------------
-- |
-- Module    : Network.MessagePackRpc.Client
-- Copyright : (c) Hideyuki Tanaka, 2010
-- License   : BSD3
--
-- Maintainer:  tanaka.hideyuki@gmail.com
-- Stability :  experimental
-- Portability: portable
--
-- This module is client library of MessagePack-RPC.
-- The specification of MessagePack-RPC is at <http://redmine.msgpack.org/projects/msgpack/wiki/RPCProtocolSpec>.
--
-- A simple example:
--
-- >import Network.MessagePackRpc.Client
-- >
-- >add :: RpcMethod (Int -> Int -> IO Int)
-- >add = method "add"
-- >
-- >main = do
-- >  conn <- connect "127.0.0.1" 1234
-- >  print =<< add conn 123 456
--
--------------------------------------------------------------------

module Network.MessagePackRpc.Client (
  -- * RPC connection
  Connection,
  connect,
  disconnect,
  
  -- * RPC error
  RpcError(..),
  
  -- * Call RPC method
  RpcMethod,
  call,
  method,
  ) where

import Control.Concurrent.MVar
import Control.Exception
import Control.Monad
import qualified Data.ByteString.Lazy as BL
import qualified Data.Conduit as C
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Attoparsec as CA
import Data.Functor
import Data.MessagePack
import Data.Typeable
import Network
import System.IO
import System.Random

-- | RPC connection type
data Connection
  = Connection
    { connHandle :: MVar Handle }

-- | Connect to RPC server
connect :: String -- ^ Host name
           -> Int -- ^ Port number
           -> IO Connection -- ^ Connection
connect addr port = withSocketsDo $ do
  h <- connectTo addr (PortNumber $ fromIntegral port)
  mh <- newMVar h
  return $ Connection
    { connHandle = mh
    }

-- | Disconnect a connection
disconnect :: Connection -> IO ()
disconnect Connection { connHandle = mh } =
  hClose =<< takeMVar mh

-- | RPC error type
data RpcError
  = ServerError Object -- ^ Server error
  | ResultTypeError String -- ^ Result type mismatch
  | ProtocolError String -- ^ Protocol error
  deriving (Eq, Ord, Typeable)

instance Exception RpcError

instance Show RpcError where
  show (ServerError err) =
    "server error: " ++ show err
  show (ResultTypeError err) =
    "result type error: " ++ err
  show (ProtocolError err) =
    "protocol error: " ++ err

class RpcType r where
  rpcc :: Connection -> String -> [Object] -> r

fromObject' :: OBJECT o => Object -> o
fromObject' o =
  case tryFromObject o of
    Left err -> throw $ ResultTypeError err
    Right r -> r

instance OBJECT o => RpcType (IO o) where
  rpcc c m args = return . fromObject' =<< rpcCall c m (reverse args)

instance (OBJECT o, RpcType r) => RpcType (o -> r) where
  rpcc c m args arg = rpcc c m (toObject arg:args)

rpcCall :: Connection -> String -> [Object] -> IO Object
rpcCall Connection{ connHandle = mh } m args = withMVar mh $ \h -> do
  msgid <- (`mod`2^(30::Int)) <$> randomIO :: IO Int
  BL.hPutStr h $ pack (0 ::Int, msgid, m, args)
  hFlush h
  C.runResourceT $ CB.sourceHandle h C.$$ do
    (rtype, rmsgid, rerror, rresult) <- CA.sinkParser get
    when (rtype /= (1 :: Int)) $
      throw $ ProtocolError $ "response type is not 1 (got " ++ show rtype ++ ")"
    when (rmsgid /= msgid) $
      throw $ ProtocolError $ "message id mismatch: expect " ++ show msgid ++ ", but got " ++ show rmsgid
    case tryFromObject rerror of
      Left _ ->
        throw $ ServerError rerror
      Right () ->
        return rresult

-- | Call an RPC Method
call :: RpcType a =>
        Connection -- ^ Connection
        -> String -- ^ Method name
        -> a
call c m = rpcc c m []

-- | Create an RPC Method (call c m == method m c)
method :: RpcType a => String -> RpcMethod a
method c m = call m c

type RpcMethod a = Connection -> a