-- Copyright (C) 2009-2010 John Millikin <jmillikin@gmail.com>
-- 
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
-- 
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
-- 
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveDataTypeable #-}
module DBus.Connection (
	  Connection
	, connectionAddress
	, connectionUUID
	, ConnectionError (..)
	, connect
	, connectFirst
	, connectionClose
	, send
	, receive
	) where
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL
import qualified Control.Concurrent as C
import qualified DBus.Address as A
import qualified DBus.Message as M
import qualified DBus.UUID as UUID
import qualified Data.ByteString.Lazy as L
import Data.Word (Word32)
import qualified Network as N
import qualified Data.Map as Map
import qualified Network.Socket as NS
import qualified Text.Parsec as P
import Control.Monad (unless)
import Data.Binary.Get (runGet, getWord16host)
import Data.Binary.Put (runPut, putWord16be)
import qualified System.IO as I
import qualified Control.Exception as E
import Data.Typeable (Typeable)
import qualified DBus.Authentication as Auth
import qualified DBus.Wire as W
-- | An open connection to a D-Bus endpoint, as returned by 'connect' or
-- 'connectFirst'.
--
-- NOTE: the field order is significant; 'send' and 'receive' match on the
-- constructor positionally.
data Connection = Connection
	{ connectionAddress    :: A.Address
	  -- ^ The address this connection was opened to.
	, connectionTransport  :: Transport
	  -- ^ Byte-level transport used to send and receive messages.
	, connectionSerialMVar :: C.MVar M.Serial
	  -- ^ Holds the serial to assign to the next outgoing message; it starts
	  -- at 'M.firstSerial' and is held for the duration of each 'send'.
	, connectionReadMVar   :: C.MVar ()
	  -- ^ Lock taken by 'receive' so only one thread reads at a time.
	, connectionUUID       :: UUID.UUID
	  -- ^ UUID returned by the authentication handshake (presumably the
	  -- server's GUID — confirm against "DBus.Authentication").
	}
-- Renders as @<Connection ...>@, where the address text from 'A.strAddress'
-- is rendered with 'shows'.
instance Show Connection where
	showsPrec d con = showParen (d > 10) $
		showString "<Connection " . shows (A.strAddress (connectionAddress con)) . showChar '>'
-- | A 'Transport' is anything which can send and receive bytestrings,
-- typically via a socket.
--
-- Within this module, transports are produced by 'handleTransport', which
-- wraps an 'I.Handle'.
data Transport = Transport
	{ transportSend :: L.ByteString -> IO ()
	  -- ^ Write the given bytes to the peer.
	, transportRecv :: Word32 -> IO L.ByteString
	  -- ^ Read the requested number of bytes. NOTE(review): the handle-based
	  -- implementation uses 'L.hGet', which may return fewer bytes (possibly
	  -- none) at end of input; callers should not assume the full count.
	, transportClose :: IO ()
	  -- ^ Release the underlying resource.
	}
-- | Open a 'Transport' for an address by dispatching on its method name.
-- Only @unix@ and @tcp@ are supported; anything else raises 'UnknownMethod'.
connectTransport :: A.Address -> IO Transport
connectTransport a = case A.addressMethod a of
	"unix" -> unix a
	"tcp"  -> tcp a
	_      -> E.throwIO (UnknownMethod a)
-- | Open a Unix-domain socket transport. Exactly one of the @path@ or
-- @abstract@ parameters must be given; violating that raises
-- 'BadParameters'. An @abstract@ name is passed on with a leading NUL
-- character.
unix :: A.Address -> IO Transport
unix a = do
		sockPath <- socketPath (param "path") (param "abstract")
		h <- N.connectTo "localhost" (N.UnixSocket sockPath)
		handleTransport h
	where
		param key = Map.lookup key (A.addressParameters a)

		socketPath (Just p) Nothing  = return (TL.unpack p)
		socketPath Nothing  (Just p) = return ('\x00' : TL.unpack p)
		socketPath (Just _) (Just _) = E.throwIO $ BadParameters a tooMany
		socketPath Nothing  Nothing  = E.throwIO $ BadParameters a tooFew

		tooMany = "Only one of `path' or `abstract' may be specified for the\
		          \ `unix' transport."
		tooFew = "One of `path' or `abstract' must be specified for the\
		         \ `unix' transport."
-- | Open a TCP transport for an address whose method is @tcp@.
--
-- Parameters read from the address:
--
-- * @host@: host name to resolve (defaults to @localhost@).
--
-- * @port@: required decimal port number in the range 1-65535.
--
-- * @family@: @ipv4@ or @ipv6@; when absent, any family is accepted.
--
-- Missing or malformed parameters raise 'BadParameters'. Each resolved
-- address is tried in order; if none accepts a connection,
-- 'NoWorkingAddress' is raised. Errors from name resolution itself are not
-- converted and propagate as thrown by 'NS.getAddrInfo'.
tcp :: A.Address -> IO Transport
tcp a = openHandle >>= handleTransport where
	params = A.addressParameters a

	-- Resolve the address and wrap the connected socket in a Handle. If the
	-- conversion to a Handle fails, the socket is closed rather than leaked.
	openHandle = do
		port <- getPort
		family <- getFamily
		addresses <- getAddresses family
		sock <- openSocket port addresses
		NS.socketToHandle sock I.ReadWriteMode `E.onException` NS.sClose sock

	hostname = maybe "localhost" TL.unpack $ Map.lookup "host" params

	unknownFamily x = TL.concat ["Unknown socket family for TCP transport: ", x]
	getFamily = case Map.lookup "family" params of
		Just "ipv4" -> return NS.AF_INET
		Just "ipv6" -> return NS.AF_INET6
		Nothing     -> return NS.AF_UNSPEC
		Just x      -> E.throwIO $ BadParameters a $ unknownFamily x

	missingPort = "TCP transport requires the ``port'' parameter."
	badPort x = TL.concat ["Invalid socket port for TCP transport: ", x]
	getPort = case Map.lookup "port" params of
		Nothing -> E.throwIO $ BadParameters a missingPort
		Just x -> case P.parse parseWord16 "" (TL.unpack x) of
			Right x' -> return $ NS.PortNum x'
			Left  _  -> E.throwIO $ BadParameters a $ badPort x

	-- Parse a decimal port in 1..65535 and byte-swap it from big-endian
	-- into a host-order Word16 (presumably because 'NS.PortNum' expects
	-- network byte order — confirm against the network version in use).
	parseWord16 = do
		chars <- P.many1 P.digit
		P.eof
		let value = read chars :: Integer
		unless (value > 0 && value <= 65535) $
			P.parserFail "bad port"
		let word = fromIntegral value
		return $ runGet getWord16host (runPut (putWord16be word))

	getAddresses family = do
		let hints = NS.defaultHints
			{ NS.addrFlags = [NS.AI_ADDRCONFIG]
			, NS.addrFamily = family
			, NS.addrSocketType = NS.Stream
			}
		NS.getAddrInfo (Just hints) (Just hostname) Nothing

	-- Replace the port of an IPv4/IPv6 socket address; other address kinds
	-- are returned unchanged.
	setPort port (NS.SockAddrInet  _ x)     = NS.SockAddrInet port x
	setPort port (NS.SockAddrInet6 _ x y z) = NS.SockAddrInet6 port x y z
	setPort _    addr                       = addr

	-- Try each resolved address in order and return the first connected
	-- socket. Only I/O errors move on to the next candidate, so asynchronous
	-- exceptions (killThread, timeout, ...) are no longer swallowed here.
	openSocket _ [] = E.throwIO $ NoWorkingAddress [a]
	openSocket port (addr:addrs) = openSocket' port addr `E.catch` skip where
		skip :: E.IOException -> IO NS.Socket
		skip _ = openSocket port addrs

	-- Create a socket for one resolved address and connect it. If creating
	-- the connection fails, the socket is closed before the error propagates.
	openSocket' port addr = E.bracketOnError (newSocket addr) NS.sClose $ \sock -> do
		NS.connect sock . setPort port . NS.addrAddress $ addr
		return sock
	newSocket addr = NS.socket (NS.addrFamily addr)
	                           (NS.addrSocketType addr)
	                           (NS.addrProtocol addr)
-- | Build a 'Transport' around an open handle. The handle is first switched
-- to unbuffered binary mode; sending, receiving and closing then map onto
-- 'L.hPut', 'L.hGet' and 'I.hClose'.
handleTransport :: I.Handle -> IO Transport
handleTransport h = do
	I.hSetBuffering h I.NoBuffering
	I.hSetBinaryMode h True
	return Transport
		{ transportSend  = L.hPut h
		, transportRecv  = \count -> L.hGet h (fromIntegral count)
		, transportClose = I.hClose h
		}
-- | Errors raised while opening a 'Connection'. Thrown as exceptions via
-- the 'E.Exception' instance below.
data ConnectionError
	= InvalidAddress Text
	-- ^ NOTE(review): not raised by any code in this module; presumably
	-- for address text that could not be parsed — confirm with callers.
	| BadParameters A.Address Text
	-- ^ The address has missing, conflicting or malformed parameters for
	-- its transport; the 'Text' describes the problem.
	| UnknownMethod A.Address
	-- ^ The address uses a transport method other than @unix@ or @tcp@.
	| NoWorkingAddress [A.Address]
	-- ^ None of the listed addresses could be connected to.
	deriving (Show, Typeable)

instance E.Exception ConnectionError
-- | Open a connection to some address, using a given authentication
-- mechanism.
--
-- Problems with the address itself (unknown method, bad parameters, or no
-- reachable TCP address) are reported as a 'ConnectionError'. Other
-- failures, such as I/O errors from the socket layer or errors raised by
-- the authentication mechanism, propagate with their own exception types.
-- If authentication fails, the transport is closed before the exception is
-- re-thrown, so the socket is not leaked.
connect :: Auth.Mechanism -> A.Address -> IO Connection
connect mechanism a = do
	t <- connectTransport a
	-- Read one byte for the authenticator. An empty result (e.g. the peer
	-- closed the stream) is reported explicitly instead of failing inside
	-- the partial 'L.head'.
	let getByte = do
		bytes <- transportRecv t 1
		if L.null bytes
			then ioError $ userError "DBus.Connection.connect: connection closed during authentication"
			else return $ L.head bytes
	uuid <- Auth.authenticate mechanism (transportSend t) getByte
		`E.onException` transportClose t
	readLock <- C.newMVar ()
	serialMVar <- C.newMVar M.firstSerial
	return $ Connection a t serialMVar readLock uuid
-- | Try each (mechanism, address) pair in order and return the first
-- connection that opens successfully. Any exception from an attempt moves
-- on to the next pair; if every attempt fails (or the list is empty),
-- 'NoWorkingAddress' is thrown listing all of the addresses.
--
-- NOTE(review): catching 'E.SomeException' also swallows asynchronous
-- exceptions raised during an attempt.
connectFirst :: [(Auth.Mechanism, A.Address)] -> IO Connection
connectFirst orig = foldr attempt exhausted orig where
	exhausted = E.throwIO $ NoWorkingAddress (map snd orig)
	attempt (mech, a) fallback = connect mech a `E.catch`
		\(E.SomeException _) -> fallback
-- | Close the connection's underlying transport. Afterwards the
-- 'Connection' is no longer valid and must not be used again.
connectionClose :: Connection -> IO ()
connectionClose con = transportClose (connectionTransport con)
-- | Marshal and send one message, assigning it the connection's next
-- 'M.Serial'.
--
-- The callback receives the serial the message is about to be sent with.
-- It runs before any bytes are written, so a caller can register a reply
-- handler without racing the reply; its result is returned as 'Right'. If
-- marshalling fails, the callback is not run, nothing is written, and the
-- 'W.MarshalError' is returned as 'Left' (the serial is still consumed).
--
-- Sends are serialised: the serial 'C.MVar' is held for the whole
-- operation, so a thread sending in parallel waits for the other to finish.
send :: M.Message a => Connection -> (M.Serial -> IO b) -> a
     -> IO (Either W.MarshalError b)
send (Connection _ transport serialVar _ _) io msg = withSerial serialVar go where
	go serial = either (return . Left) (deliver serial)
		(W.marshalMessage W.LittleEndian serial msg)
	deliver serial bytes = do
		result <- io serial
		transportSend transport bytes
		return (Right result)
-- | Run an action with the serial currently stored in the 'C.MVar', holding
-- the 'C.MVar' for the whole action so concurrent callers take turns.
--
-- The stored serial is advanced with 'M.nextSerial' whether the action
-- returns normally or throws. Asynchronous exceptions are masked except
-- while the action runs, so the 'C.MVar' is always refilled. Unlike the
-- deprecated 'E.block'/'E.unblock' pair, 'E.mask' restores the caller's
-- own masking state, so a caller that is already masked stays masked.
withSerial :: C.MVar M.Serial -> (M.Serial -> IO a) -> IO a
withSerial m io = E.mask $ \restore -> do
	s <- C.takeMVar m
	let s' = M.nextSerial s
	x <- restore (io s) `E.onException` C.putMVar m s'
	C.putMVar m s'
	return x
-- | Block until the next message arrives on the connection and return it,
-- or the 'W.UnmarshalError' explaining why it could not be decoded.
--
-- Reads are serialised by the connection's read lock: while one thread is
-- receiving, any other thread calling 'receive' waits for it to finish.
receive :: Connection -> IO (Either W.UnmarshalError M.ReceivedMessage)
receive (Connection _ transport _ readLock _) = C.withMVar readLock (const next) where
	next = W.unmarshalMessage (transportRecv transport)