-- Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
-- 
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
-- 
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
-- 
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE TypeFamilies #-}
module DBus.Client (
	  module DBus.Bus
	, module DBus.Types
	, module DBus.Message
	  -- * Clients
	, Client
	, C.Connection
	, clientName
	, newClient
	, DBus
	, DBusException
	, runDBus
	, getClient
	, processMessage
	, send
	, send_
	, receive
	, mainLoop
	, call
	, callBlocking
	, callBlocking_
	-- * Handling signals
	, onSignal
	-- * Name reservation
	, NR.RequestNameFlag (..)
	, NR.RequestNameReply (..)
	, NR.ReleaseNameReply (..)
	, requestName
	, releaseName
	, requestName_
	, releaseName_
	  -- * Exporting local objects
	, Object (..)
	, Interface (..)
	, Member (..)
	, Method (..)
	, export
	, object
	, interface
	, method
	  -- ** Responding to method calls
	, MethodCtx (..)
	, replyReturn
	, replyError
	, Proxy (..)
	, callProxy
	, callProxyBlocking
	, callProxyBlocking_
	, onProxySignal
	) where
import DBus.Bus
import DBus.Types
import DBus.Message

import qualified DBus.Connection as C
import qualified DBus.Constants as Const
import qualified DBus.Introspection as I
import qualified DBus.MatchRule as MR
import qualified DBus.Message as M
import qualified DBus.NameReservation as NR
import qualified DBus.Types as T
import qualified DBus.Wire as W
import qualified Control.Concurrent.MVar as MV
import qualified Data.Map as Map
import Control.Monad (liftM, ap, forever)
import Control.Monad.IO.Class (liftIO)
import qualified Control.Monad.IO.Class as MIO
import qualified Control.Monad.Reader as R
import qualified Control.Applicative as A
import Data.Typeable (Typeable)
import qualified Control.Exception as Exc
import qualified Control.Monad.Error as E
import Data.Maybe (isJust)
import qualified Data.Set as Set
import Data.Monoid (mconcat)
-- | 'Client's are opaque handles to an open connection and other internal
-- state.
data Client = Client
	-- The connection used for sending and receiving messages.
	{ clientConnection :: C.Connection
	-- The bus name that was passed to 'newClient' with the connection.
	, clientName :: T.BusName
	-- Pending reply handlers, keyed by the serial of the call that was sent.
	-- Entries are added by 'call' and removed by 'onReply'.
	, clientCallbacks :: MV.MVar (Map.Map M.Serial MessageHandler)
	-- Locally exported objects, keyed by object path (see 'export').
	, clientObjects :: MV.MVar (Map.Map T.ObjectPath Object)
	-- Handlers invoked by 'processMessage' for every received signal.
	, clientSignalHandlers :: MV.MVar [MessageHandler]
	}
-- A computation run against a received message; used both for reply
-- callbacks and for signal handlers.
type MessageHandler = (M.ReceivedMessage -> DBus ())
-- | Build a 'Client' around an already-open connection and its bus name.
-- The argument is a pair so that the result of the computations in
-- "DBus.Bus" can be passed straight in, without unpacking:
--
-- @
-- client <- newClient =<< 'getSessionBus'
-- @
--
-- Create at most one client per connection; several clients sharing a
-- connection would compete to receive messages.
newClient :: (C.Connection, T.BusName) -> IO Client
newClient (conn, busName) = do
	replyHandlers <- MV.newMVar Map.empty
	-- The root object is registered from the start under the path "/".
	exported <- MV.newMVar (Map.singleton "/" rootObject)
	signalHandlers <- MV.newMVar []
	return (Client conn busName replyHandlers exported signalHandlers)
-- | Computations which have access to a 'Client' and may perform IO.
-- Implemented as a reader over 'Client' layered on top of 'IO'.
newtype DBus a = DBus { unDBus :: R.ReaderT Client IO a }

-- Sequencing is delegated to the wrapped reader computation.
instance Monad DBus where
	return x = DBus (return x)
	DBus m >>= f = DBus (m >>= \x -> unDBus (f x))

-- IO actions are lifted through the wrapped reader computation.
instance MIO.MonadIO DBus where
	liftIO io = DBus (MIO.liftIO io)

-- Both instances are defined purely in terms of the 'Monad' instance.
instance Functor DBus where
	fmap f m = m >>= \x -> return (f x)

instance A.Applicative DBus where
	pure x = return x
	mf <*> mx = mf >>= \f -> mx >>= \x -> return (f x)
-- | Errors raised by 'DBus' computations. They are thrown as ordinary
-- exceptions (see the 'Exc.Exception' instance below).
data DBusException
	-- Thrown by 'send' when the outgoing message could not be marshaled.
	= MarshalFailed W.MarshalError
	-- Thrown by 'receive' when an incoming message could not be parsed.
	| UnmarshalFailed W.UnmarshalError
	-- Thrown by 'callBlocking_' when the reply to a call was an error.
	| MethodCallFailed M.Error
	-- Thrown when the reply to a name request could not be interpreted.
	| InvalidRequestNameReply M.MethodReturn
	-- Thrown when the reply to a name release could not be interpreted.
	| InvalidReleaseNameReply M.MethodReturn
	deriving (Show, Eq, Typeable)
instance Exc.Exception DBusException
-- Errors are implemented on top of IO exceptions: throwing uses
-- 'Exc.throwIO', and catching re-runs both the protected computation and
-- the handler against the current client.
instance E.MonadError DBus where
	type E.ErrorType DBus = DBusException
	throwError err = MIO.liftIO (Exc.throwIO err)
	catchError action handler = do
		client <- getClient
		liftIO (runDBus client action `Exc.catch` (runDBus client . handler))
-- | Execute a 'DBus' computation against the given client. Errors raised
-- while running are thrown as exceptions of type 'DBusException'.
--
-- To handle errors inside the computation itself, use the 'E.MonadError'
-- instance for 'DBus'.
runDBus :: Client -> DBus a -> IO a
runDBus client action = R.runReaderT (unDBus action) client

-- | Retrieve the 'Client' the current computation is running against.
getClient :: DBus Client
getClient = DBus (R.asks id)

-- Retrieve the connection held by the current client.
getConnection :: DBus C.Connection
getConnection = DBus (R.asks clientConnection)
-- | Dispatch a received message to whatever is registered for it. Reply
-- callbacks, signal handlers and exported methods all run in the current
-- thread.
--
-- Method returns and errors go to the callback stored for their serial;
-- signals are passed to every signal handler; method calls are routed to
-- the matching exported method, or answered with an \"unknown method\"
-- error. Messages of unknown type are ignored.
processMessage :: M.ReceivedMessage -> DBus ()
processMessage received = case received of
	M.ReceivedUnknown _ _ _ -> return ()
	M.ReceivedMethodReturn _ _ msg -> onReply (M.methodReturnSerial msg) received
	M.ReceivedError _ _ msg -> onReply (M.errorSerial msg) received
	M.ReceivedSignal _ _ _ -> do
		client <- getClient
		handlers <- liftIO (MV.readMVar (clientSignalHandlers client))
		mapM_ (\handle -> handle received) handlers
	M.ReceivedMethodCall _ _ msg -> do
		client <- getClient
		objects <- liftIO (MV.readMVar (clientObjects client))
		case findMethod objects msg of
			Nothing -> unknownMethod received
			Just (obj, m) -> onMethodCall obj m received
-- | A wrapper around 'C.send'. The first argument is handed to 'C.send'
-- and run against the current client with the message's serial; its
-- result is returned. A marshaling failure is thrown as 'MarshalFailed'.
send :: M.Message msg => (M.Serial -> DBus a) -> msg -> DBus a
send onSerial msg = do
	client <- getClient
	let withSerial serial = runDBus client (onSerial serial)
	result <- liftIO (C.send (clientConnection client) withSerial msg)
	either (E.throwError . MarshalFailed) return result

-- | A variant of 'send' which ignores the message serial. Handy for
-- messages which are not expected to receive a reply.
send_ :: M.Message msg => msg -> DBus ()
send_ msg = send (\_ -> return ()) msg
-- | A wrapper around 'C.receive'. A message which cannot be parsed is
-- thrown as 'UnmarshalFailed'.
receive :: DBus M.ReceivedMessage
receive = do
	conn <- getConnection
	liftIO (C.receive conn) >>= either (E.throwError . UnmarshalFailed) return
-- | Receive and process messages forever.
--
-- This is commonly run in a separate thread, ie
--
-- > client <- newClient =<< getSessionBus
-- > forkIO $ runDBus client mainLoop
mainLoop :: DBus ()
mainLoop = forever $ do
	msg <- receive
	processMessage msg
-- | Perform an asynchronous method call. Exactly one of the two supplied
-- computations is run, depending on whether the destination answers with
-- an error or with a method return.
call :: M.MethodCall
     -> (M.Error -> DBus ())
     -> (M.MethodReturn -> DBus ())
     -> DBus ()
call msg onError onReturn = send register msg where
	-- Store the reply handler under the outgoing serial, where 'onReply'
	-- looks it up.
	register serial = do
		client <- getClient
		liftIO $ MV.modifyMVar_ (clientCallbacks client)
			(return . Map.insert serial dispatch)

	-- Route the reply to the matching continuation; anything that is
	-- neither an error nor a method return is ignored.
	dispatch received = case received of
		M.ReceivedError _ _ err -> onError err
		M.ReceivedMethodReturn _ _ ret -> onReturn ret
		_ -> return ()
-- Run (and forget) the callback registered for the given serial, if any.
-- The lookup and the removal happen inside a single 'MV.modifyMVar'.
onReply :: M.Serial -> M.ReceivedMessage -> DBus ()
onReply serial msg = do
	client <- getClient
	-- Deleting an absent key leaves the map unchanged, so the removal
	-- does not need to be conditional on the lookup.
	found <- liftIO $ MV.modifyMVar (clientCallbacks client) $ \pending ->
		return (Map.delete serial pending, Map.lookup serial pending)
	maybe (return ()) ($ msg) found
-- | Send a method call and block until its reply has been delivered. The
-- reply is delivered by 'processMessage', so use this when the
-- receive/process loop is running in a separate thread.
callBlocking :: M.MethodCall -> DBus (Either M.Error M.MethodReturn)
callBlocking msg = do
	result <- liftIO MV.newEmptyMVar
	let deliver = liftIO . MV.putMVar result
	call msg (deliver . Left) (deliver . Right)
	liftIO (MV.takeMVar result)

-- | A variant of 'callBlocking' which throws 'MethodCallFailed' if the
-- remote client replies with an 'M.Error'.
callBlocking_ :: M.MethodCall -> DBus M.MethodReturn
callBlocking_ msg =
	callBlocking msg >>= either (E.throwError . MethodCallFailed) return
-- | Run a computation every time this client receives a signal matching
-- the given rule. The rule's type is forced to signals, the match is
-- registered with a blocking call, and only then is the handler installed.
onSignal :: MR.MatchRule
	 -> (T.BusName -> M.Signal -> DBus ())
	 -> DBus ()
onSignal rule h = register where
	signalRule = rule { MR.matchType = Just MR.Signal }

	-- Only signals which carry a sender and satisfy the rule reach the
	-- user's handler; everything else is dropped.
	dispatch received = case received of
		M.ReceivedSignal _ (Just sender) signal
			| MR.matches signalRule received -> h sender signal
		_ -> return ()

	register = do
		callBlocking_ (MR.addMatch signalRule)
		client <- getClient
		liftIO $ MV.modifyMVar_ (clientSignalHandlers client)
			(return . (dispatch :))
-- | Asynchronously send the message built by 'NR.requestName'. An error
-- reply is passed to the first continuation and a decoded reply to the
-- second; a return which cannot be decoded raises
-- 'InvalidRequestNameReply'.
requestName :: T.BusName
            -> [NR.RequestNameFlag]
            -> (M.Error -> DBus ())
            -> (NR.RequestNameReply -> DBus ())
            -> DBus ()
requestName name flags onError callback =
	call (NR.requestName name flags) onError onReturn where
	onReturn reply = maybe
		(E.throwError (InvalidRequestNameReply reply))
		callback
		(NR.mkRequestNameReply reply)
-- | Asynchronously send the message built by 'NR.releaseName'. An error
-- reply is passed to the first continuation and a decoded reply to the
-- second; a return which cannot be decoded raises
-- 'InvalidReleaseNameReply'.
releaseName :: T.BusName
            -> (M.Error -> DBus ())
            -> (NR.ReleaseNameReply -> DBus ())
            -> DBus ()
releaseName name onError callback =
	call (NR.releaseName name) onError onReturn where
	onReturn reply = maybe
		(E.throwError (InvalidReleaseNameReply reply))
		callback
		(NR.mkReleaseNameReply reply)
-- | Blocking variant of 'requestName'. Throws 'MethodCallFailed' on an
-- error reply and 'InvalidRequestNameReply' when the return cannot be
-- decoded.
requestName_ :: T.BusName -> [NR.RequestNameFlag] -> DBus NR.RequestNameReply
requestName_ name flags = do
	reply <- callBlocking_ (NR.requestName name flags)
	maybe (E.throwError (InvalidRequestNameReply reply)) return
		(NR.mkRequestNameReply reply)
-- | Blocking variant of 'releaseName'. Throws 'MethodCallFailed' on an
-- error reply and 'InvalidReleaseNameReply' when the return cannot be
-- decoded.
releaseName_ :: T.BusName -> DBus NR.ReleaseNameReply
releaseName_ name = do
	reply <- callBlocking_ (NR.releaseName name)
	maybe (E.throwError (InvalidReleaseNameReply reply)) return
		(NR.mkReleaseNameReply reply)
-- | An exported object: its interfaces, keyed by interface name.
newtype Object = Object (Map.Map T.InterfaceName Interface)
-- | An interface: its members, keyed by member name.
newtype Interface = Interface (Map.Map T.MemberName Member)
-- | A member of an interface is either a callable method or a signal
-- described by its signature.
data Member
	= MemberMethod Method
	| MemberSignal T.Signature
-- | A method: input signature, output signature, and the implementation
-- which is run when a matching call is received.
data Method = Method T.Signature T.Signature (MethodCtx -> DBus ())
-- | Export a set of interfaces on the bus under the given path. When a
-- method call arrives whose path, interface and member name match, the
-- corresponding member is invoked. Exporting to a path that is already in
-- use replaces the previous object.
--
-- Exported objects automatically implement the
-- @org.freedesktop.DBus.Introspectable@ interface.
export :: T.ObjectPath -> Object -> DBus ()
export path obj = do
	client <- getClient
	let exported = addIntrospectable path obj
	liftIO $ MV.modifyMVar_ (clientObjects client)
		(return . Map.insert path exported)
-- | Build an 'Object' from a list of named interfaces.
object :: [(T.InterfaceName, Interface)] -> Object
object ifaces = Object (Map.fromList ifaces)

-- | Build an 'Interface' from a list of named members.
interface :: [(T.MemberName, Member)] -> Interface
interface members = Interface (Map.fromList members)

-- | Build a method 'Member' from its signatures and implementation.
method :: T.Signature -- ^ Input signature
       -> T.Signature -- ^ Output signature
       -> (MethodCtx -> DBus ()) -- ^ Implementation
       -> Member
method inSig outSig cb = MemberMethod (Method inSig outSig cb)
-- | Everything a method implementation is told about the call being
-- handled. Values of this type are built by the dispatcher when a method
-- call is received.
data MethodCtx = MethodCtx
	-- The exported object the call was routed to.
	{ methodCtxObject :: Object
	-- The method being invoked.
	, methodCtxMethod :: Method
	-- Serial of the received call; replies refer back to it.
	, methodCtxSerial :: M.Serial
	-- Sender of the call, if known; replies are addressed to it.
	, methodCtxSender :: Maybe T.BusName
	-- Flags of the received call.
	, methodCtxFlags  :: Set.Set M.Flag
	-- Arguments of the received call.
	, methodCtxBody   :: [T.Variant]
	}
-- | Send a successful return reply for a method call. The body is
-- checked against the method's output signature first; on a mismatch an
-- error reply is sent to the caller instead.
replyReturn :: MethodCtx -> [T.Variant] -> DBus ()
replyReturn call' body
	| listSig body == Just outSig = send_ reply
	| otherwise = replyError call' Const.errorFailed
		[T.toVariant ("Method return didn't match signature." :: String)]
	where
	Method _ outSig _ = methodCtxMethod call'
	reply = M.MethodReturn
		(methodCtxSerial call')
		(methodCtxSender call')
		body
-- | Send an error reply for a method call, using the given error name and
-- body. The reply refers to the call's serial and goes to its sender.
replyError :: MethodCtx -> T.ErrorName -> [T.Variant] -> DBus ()
replyError call' name body = send_ errorMsg where
	errorMsg = M.Error name (methodCtxSerial call') (methodCtxSender call') body
-- Answer a received method call with an \"unknown method\" error carrying
-- an empty body.
--
-- NOTE: the pattern below is partial; this must only be applied to a
-- 'M.ReceivedMethodCall'.
unknownMethod :: M.ReceivedMessage -> DBus ()
unknownMethod msg = send_ (M.Error Const.errorUnknownMethod serial sender []) where
	M.ReceivedMethodCall serial sender _ = msg
-- Locate the exported method a call is addressed to. The result is
-- 'Nothing' when the path is not exported, the call names no interface,
-- the interface or member is unknown, or the member is not a method.
findMethod :: Map.Map T.ObjectPath Object -> M.MethodCall -> Maybe (Object, Method)
findMethod objects call' = do
	obj@(Object ifaces) <- Map.lookup (M.methodCallPath call') objects
	Interface members <- M.methodCallInterface call' >>= flip Map.lookup ifaces
	member <- Map.lookup (M.methodCallMember call') members
	case member of
		MemberMethod m -> Just (obj, m)
		MemberSignal _ -> Nothing
-- Invoke an exported method for a received call. The arguments are
-- checked against the method's input signature; on a mismatch the caller
-- gets an \"invalid args\" error and the implementation is not run.
--
-- NOTE: the 'M.ReceivedMethodCall' pattern below is partial; this must
-- only be applied to received method calls.
onMethodCall :: Object -> Method -> M.ReceivedMessage -> DBus ()
onMethodCall obj method' received
	| listSig body == Just inSig = cb ctx
	| otherwise = replyError ctx Const.errorInvalidArgs []
	where
	M.ReceivedMethodCall serial sender msg = received
	Method inSig _ cb = method'
	body = M.methodCallBody msg
	ctx = MethodCtx obj method' serial sender (M.methodCallFlags msg) body
-- Add an implementation of the Introspectable interface to an object,
-- unless the object already provides one.
--
-- The generated @Introspect@ method takes no arguments and returns a
-- single string: the introspection XML for the object it was called on.
-- If that XML cannot be generated, the caller receives an error reply
-- rather than the handler failing with a pattern-match exception.
addIntrospectable :: T.ObjectPath -> Object -> Object
addIntrospectable path (Object ifaces) = Object ifaces' where
	name = Const.interfaceIntrospectable
	-- 'Map.insertWith' applies its function as @f new old@; returning the
	-- old value keeps an interface the caller supplied themselves.
	ifaces' = Map.insertWith (\_ existing -> existing) name iface ifaces
	iface = interface [("Introspect", impl)]
	impl = method "" "s" $ \call' ->
		case I.toXML (introspect path (methodCtxObject call')) of
			Just xml -> replyReturn call' [T.toVariant xml]
			Nothing -> replyError call' Const.errorFailed
				[T.toVariant ("Unable to generate introspection XML." :: String)]
-- Describe an exported object as an introspection tree: one entry per
-- interface, listing its methods and signals. No child objects are
-- reported, and all parameters are left unnamed.
introspect :: T.ObjectPath -> Object -> I.Object
introspect path (Object ifaceMap) = I.Object path interfaces [] where
	interfaces = [describeIface name iface | (name, iface) <- Map.toList ifaceMap]

	describeIface :: T.InterfaceName -> Interface -> I.Interface
	describeIface name (Interface memberMap) = I.Interface name methods signals [] where
		members = Map.toList memberMap
		methods =
			[ I.Method n (params inSig) (params outSig)
			| (n, MemberMethod (Method inSig outSig _)) <- members
			]
		signals =
			[ I.Signal n (params sig)
			| (n, MemberSignal sig) <- members
			]

	-- One unnamed parameter per type in the signature.
	params = map (I.Parameter "" . T.mkSignature_ . T.typeCode) . T.signatureTypes
-- The object registered at \"/\" for every new client. It implements only
-- the Introspectable interface; its @Introspect@ method reports that
-- interface plus every other currently exported path as a child object.
--
-- If the XML cannot be generated, the caller receives an error reply
-- rather than the handler failing with a pattern-match exception.
rootObject :: Object
rootObject = object [(ifaceName, interface [(memberName, impl)])] where
	ifaceName = Const.interfaceIntrospectable
	memberName =  "Introspect"

	methodXML = I.Method memberName [] [I.Parameter "xml" "s"]
	ifaceXML = I.Interface ifaceName [methodXML] [] []

	impl = method "" "s" $ \call' -> do
		mvar <- fmap clientObjects getClient
		paths <- liftIO $ fmap Map.keys $ MV.readMVar mvar

		-- The root itself is not listed among its own children.
		let children = [I.Object p [] [] | p <- paths, p /= "/"]
		case I.toXML (I.Object "/" [ifaceXML] children) of
			Just xml -> replyReturn call' [T.toVariant xml]
			Nothing -> replyError call' Const.errorFailed
				[T.toVariant ("Unable to generate introspection XML." :: String)]
-- | The addressing information shared by calls to, and signals from, one
-- interface of a remote object.
data Proxy = Proxy
	-- Bus name used as the destination of calls and the sender of signals.
	{ proxyName :: T.BusName
	-- Path of the remote object.
	, proxyObjectPath :: T.ObjectPath
	-- Interface that members are looked up on.
	, proxyInterface :: T.InterfaceName
	}
	deriving (Show, Eq)
-- | As 'call', except that the message is built from the proxy's
-- information.
callProxy :: Proxy -> T.MemberName -> [M.Flag] -> [T.Variant]
          -> (M.Error -> DBus ())
          -> (M.MethodReturn -> DBus ())
          -> DBus ()
callProxy proxy name flags body onError onReturn =
	call (buildMethodCall proxy name flags body) onError onReturn
-- | As 'callBlocking', except that the message is built from the proxy's
-- information.
callProxyBlocking :: Proxy -> T.MemberName -> [M.Flag] -> [T.Variant]
                  -> DBus (Either M.Error M.MethodReturn)
callProxyBlocking proxy name flags body = callBlocking msg where
	msg = buildMethodCall proxy name flags body
-- | As 'callBlocking_', except that the message is built from the
-- proxy's information.
callProxyBlocking_ :: Proxy -> T.MemberName -> [M.Flag] -> [T.Variant]
                   -> DBus M.MethodReturn
callProxyBlocking_ proxy name flags body = callBlocking_ msg where
	msg = buildMethodCall proxy name flags body
-- | As 'onSignal', except that the match rule is built from the proxy's
-- information and the given member name. The sender is not passed on to
-- the handler.
onProxySignal :: Proxy -> T.MemberName -> (M.Signal -> DBus ())
             -> DBus ()
onProxySignal proxy member handler = onSignal rule (\_ signal -> handler signal) where
	rule = MR.MatchRule
		{ MR.matchType = Nothing
		, MR.matchSender = Just (proxyName proxy)
		, MR.matchInterface = Just (proxyInterface proxy)
		, MR.matchMember = Just member
		, MR.matchPath = Just (proxyObjectPath proxy)
		, MR.matchDestination = Nothing
		, MR.matchParameters = []
		}
-- Build a method call addressed to the proxy's object, interface and bus
-- name, with the given member name, flags and body.
buildMethodCall :: Proxy -> T.MemberName -> [M.Flag] -> [T.Variant]
                -> M.MethodCall
buildMethodCall proxy name flags body = M.MethodCall
	(proxyObjectPath proxy)
	name
	(Just (proxyInterface proxy))
	(Just (proxyName proxy))
	(Set.fromList flags)
	body
-- Compute the signature of a message body by concatenating the type code
-- of each variant. The result is whatever 'T.mkSignature' makes of the
-- combined code, which may be 'Nothing'.
listSig :: [T.Variant] -> Maybe T.Signature
listSig variants = T.mkSignature (mconcat codes) where
	codes = [T.typeCode (T.variantType v) | v <- variants]