module DBus.Transport
(
Transport(..)
, TransportOpen(..)
, TransportListen(..)
, TransportError
, transportError
, transportErrorMessage
, transportErrorAddress
, SocketTransport
, socketTransportOptionBacklog
, socketTransportCredentials
) where
import Control.Exception
import qualified Data.ByteString
import Data.ByteString (ByteString)
import qualified Data.Map as Map
import Data.Typeable (Typeable)
import Foreign.C (CUInt)
import Network.Socket hiding (recv)
import Network.Socket.ByteString (sendAll, recv)
import qualified System.Info
import qualified Data.Serialize.Builder as Builder
import DBus
-- | An error raised by a transport when opening, listening on, or using a
-- connection fails.
data TransportError = TransportError
    { transportErrorMessage :: String
      -- ^ Human-readable description of the failure.
    , transportErrorAddress :: Maybe Address
      -- ^ The address involved in the failure, when one is known.
    }
    deriving (Eq, Show, Typeable)

-- | Allows 'TransportError' to be thrown with 'throwIO' and caught as an
-- ordinary exception.
instance Exception TransportError
-- | Build a 'TransportError' from a message alone. The address is left as
-- 'Nothing'; callers in this module attach one with a record update.
transportError :: String -> TransportError
transportError msg = TransportError
    { transportErrorMessage = msg
    , transportErrorAddress = Nothing
    }
-- | A bidirectional byte stream that D-Bus messages can be sent over and
-- received from.
class Transport t where
    -- | Per-transport configuration, supplied when opening or listening.
    data TransportOptions t :: *

    -- | The options used when the caller has no specific requirements.
    transportDefaultOptions :: TransportOptions t

    -- | Send the given bytes over the transport.
    transportPut :: t -> ByteString -> IO ()

    -- | Receive bytes; the 'Int' is the number of bytes requested.
    transportGet :: t -> Int -> IO ByteString

    -- | Release the transport's underlying resources.
    transportClose :: t -> IO ()
-- | Transports that can actively connect to a remote 'Address'.
class Transport t => TransportOpen t where
    -- | Connect to the given address using the given options.
    transportOpen :: TransportOptions t -> Address -> IO t
-- | Transports that can listen on an 'Address' and accept incoming
-- connections.
class Transport t => TransportListen t where
    -- | A bound endpoint that accepts connections for this transport.
    data TransportListener t :: *

    -- | Start listening on the given address using the given options.
    transportListen :: TransportOptions t -> Address -> IO (TransportListener t)

    -- | Wait for and accept the next incoming connection.
    transportAccept :: TransportListener t -> IO t

    -- | Stop listening and release the listener's resources.
    transportListenerClose :: TransportListener t -> IO ()

    -- | The address this listener is bound to.
    transportListenerAddress :: TransportListener t -> Address

    -- | The UUID identifying this listener.
    transportListenerUUID :: TransportListener t -> UUID
-- | A transport backed by a connected Unix or TCP socket. The address is
-- the one the connection was opened with, or 'Nothing' for connections
-- obtained from 'transportAccept'.
data SocketTransport
    = SocketTransport (Maybe Address) Socket
instance Transport SocketTransport where
    -- Options for socket transports. The backlog is the queue length passed
    -- to 'Network.Socket.listen' by the listening functions.
    data TransportOptions SocketTransport = SocketTransportOptions
        { socketTransportOptionBacklog :: Int
        }

    transportDefaultOptions = SocketTransportOptions
        { socketTransportOptionBacklog = 30
        }

    -- All socket IO errors are rethrown as 'TransportError', tagged with
    -- the transport's address (if any).
    transportPut (SocketTransport addr sock) bytes =
        catchIOException addr $ sendAll sock bytes

    transportGet (SocketTransport addr sock) count =
        catchIOException addr $ recvLoop sock count

    transportClose (SocketTransport addr sock) =
        catchIOException addr $ sClose sock
-- | Read @n@ bytes from the socket, issuing as many 'recv' calls as needed
-- (each requesting at most 4096 bytes).
--
-- If the peer closes the connection before @n@ bytes have arrived, the
-- bytes received so far are returned, so the result may be shorter than
-- requested. A non-positive @n@ yields an empty string without touching
-- the socket.
recvLoop :: Socket -> Int -> IO ByteString
recvLoop s = loop Builder.empty where
    chunkSize = 4096

    loop acc n
        | n <= 0 = return (Builder.toByteString acc)
        | otherwise = do
            chunk <- recv s (min n chunkSize)
            let len = Data.ByteString.length chunk
                acc' = Builder.append acc (Builder.fromByteString chunk)
            -- 'recv' yields an empty string once the peer has closed the
            -- connection; calling it again would spin forever.
            if len == 0
                then return (Builder.toByteString acc)
                else loop acc' (n - len)
instance TransportOpen SocketTransport where
    -- Dispatch on the address method; only "unix" and "tcp" are supported,
    -- anything else is reported as a 'TransportError' carrying the address.
    transportOpen _ a
        | method == "unix" = openUnix a
        | method == "tcp" = openTcp a
        | otherwise = throwIO (transportError ("Unknown address method: " ++ show method))
            { transportErrorAddress = Just a
            }
      where
        method = addressMethod a
instance TransportListen SocketTransport where
    -- The listening socket, the address it is bound to, and the UUID that
    -- was generated for it.
    data TransportListener SocketTransport = SocketTransportListener Address UUID Socket

    -- A fresh UUID is generated first, then the listener is created by the
    -- method-specific helper, which also reports the concrete bound address.
    transportListen opts a = do
        uuid <- randomUUID
        (boundAddr, sock) <- listenVia (addressMethod a) uuid
        return (SocketTransportListener boundAddr uuid sock)
      where
        listenVia "unix" uuid = listenUnix uuid a opts
        listenVia "tcp" uuid = listenTcp uuid a opts
        listenVia method _ = throwIO (transportError ("Unknown address method: " ++ show method))
            { transportErrorAddress = Just a
            }

    -- Accepted connections carry no address of their own.
    transportAccept (SocketTransportListener addr _ sock) =
        catchIOException (Just addr) $
            fmap (SocketTransport Nothing . fst) (accept sock)

    transportListenerClose (SocketTransportListener addr _ sock) =
        catchIOException (Just addr) $ sClose sock

    transportListenerAddress (SocketTransportListener addr _ _) = addr

    transportListenerUUID (SocketTransportListener _ uuid _) = uuid
-- | Query the peer's credentials via 'getPeerCred' (presumably
-- @(pid, uid, gid)@; see that function's documentation). Socket failures
-- are rethrown as 'TransportError' tagged with the transport's address.
socketTransportCredentials :: SocketTransport -> IO (CUInt, CUInt, CUInt)
socketTransportCredentials (SocketTransport addr sock) =
    catchIOException addr $ getPeerCred sock
-- | Connect to a Unix-domain socket address.
--
-- Exactly one of the @path@ or @abstract@ parameters must be present.
-- @abstract@ names are prefixed with a NUL byte before connecting. A
-- missing or duplicated parameter, and any socket error, is thrown as a
-- 'TransportError' tagged with the address.
openUnix :: Address -> IO SocketTransport
openUnix transportAddr = go where
    params = addressParameters transportAddr
    param key = Map.lookup key params
    tooMany = "Only one of 'path' or 'abstract' may be specified for the\
        \ 'unix' transport."
    tooFew = "One of 'path' or 'abstract' must be specified for the\
        \ 'unix' transport."
    path = case (param "path", param "abstract") of
        (Just x, Nothing) -> Right x
        (Nothing, Just x) -> Right ('\x00' : x)
        (Nothing, Nothing) -> Left tooFew
        _ -> Left tooMany
    go = case path of
        Left err -> throwIO (transportError err)
            { transportErrorAddress = Just transportAddr
            }
        Right p -> catchIOException (Just transportAddr) $ do
            -- Close the socket if connecting fails rather than leaking it.
            sock <- bracketOnError
                (socket AF_UNIX Stream defaultProtocol)
                sClose
                (\s -> do
                    connect s (SockAddrUnix p)
                    return s)
            return (SocketTransport (Just transportAddr) sock)
-- | Connect to a TCP address.
--
-- Recognised parameters: @host@ (default @localhost@), @port@ (required)
-- and @family@ (@ipv4@ or @ipv6@; when absent either is allowed). The
-- resolved addresses are tried in order until one connects; if all fail,
-- the last connection error is reported. Every failure is thrown as a
-- 'TransportError' tagged with the address.
openTcp :: Address -> IO SocketTransport
openTcp transportAddr = either failWith connectTo settings where
    param key = Map.lookup key (addressParameters transportAddr)

    failWith :: String -> IO a
    failWith msg = throwIO (transportError msg)
        { transportErrorAddress = Just transportAddr
        }

    -- The port is validated before the family, so a bad port is the error
    -- reported when both are wrong.
    settings = do
        port <- portSetting
        fam <- familySetting
        return (port, fam)

    hostname = case param "host" of
        Just h -> h
        Nothing -> "localhost"

    familySetting = case param "family" of
        Nothing -> Right AF_UNSPEC
        Just "ipv4" -> Right AF_INET
        Just "ipv6" -> Right AF_INET6
        Just other -> Left ("Unknown socket family for TCP transport: " ++ show other)

    portSetting = case param "port" of
        Nothing -> Left "TCP transport requires the `port' parameter."
        Just raw -> maybe
            (Left ("Invalid socket port for TCP transport: " ++ show raw))
            Right
            (readPortNumber raw)

    resolve fam = getAddrInfo (Just defaultHints
        { addrFlags = [AI_ADDRCONFIG]
        , addrFamily = fam
        , addrSocketType = Stream
        }) (Just hostname) Nothing

    -- Try each candidate in turn; a socket that fails to connect is closed
    -- before moving on.
    connectFirst [] = failWith "openTcp: no addresses"
    connectFirst (info:rest) = do
        attempt <- Control.Exception.try (bracketOnError
            (socket (addrFamily info) (addrSocketType info) (addrProtocol info))
            sClose
            (\s -> connect s (addrAddress info) >> return s))
        case (attempt, rest) of
            (Right s, _) -> return s
            (Left err, []) -> failWith (show (err :: IOException))
            (Left _, _) -> connectFirst rest

    connectTo (port, fam) = catchIOException (Just transportAddr) $ do
        infos <- resolve fam
        sock <- connectFirst (map (setPort port) infos)
        return (SocketTransport (Just transportAddr) sock)
-- | Bind and listen on a Unix-domain socket. Returns the concrete address
-- (which includes a @guid@ parameter built from @uuid@) and the listening
-- socket.
--
-- Exactly one of @abstract@, @path@ or @tmpdir@ must be given. With
-- @tmpdir@, the name @<tmpdir>/haskell-dbus-<uuid>@ is generated; on Linux
-- it is used as an abstract name, elsewhere as a filesystem path. Abstract
-- names are bound with a leading NUL byte. Errors are thrown as
-- 'TransportError'.
listenUnix :: UUID -> Address -> TransportOptions SocketTransport -> IO (Address, Socket)
listenUnix uuid origAddr opts = go target where
    param key = Map.lookup key (addressParameters origAddr)
    guid = formatUUID uuid
    unixAddr extra = address_ "unix" (extra ++ [("guid", guid)])
    tooMany = "Only one of 'abstract', 'path', or 'tmpdir' may be\
        \ specified for the 'unix' transport."
    tooFew = "One of 'abstract', 'path', or 'tmpdir' must be specified\
        \ for the 'unix' transport."

    -- The address to advertise, paired with the raw name to bind.
    target = case (param "abstract", param "path", param "tmpdir") of
        (Just path, Nothing, Nothing) ->
            Right (unixAddr [("abstract", path)], '\x00' : path)
        (Nothing, Just path, Nothing) ->
            Right (unixAddr [("path", path)], path)
        (Nothing, Nothing, Just dir) ->
            let fileName = dir ++ "/haskell-dbus-" ++ guid
            in Right $ if System.Info.os == "linux"
                then (unixAddr [("abstract", fileName)], '\x00' : fileName)
                else (unixAddr [("path", fileName)], fileName)
        (Nothing, Nothing, Nothing) -> Left tooFew
        _ -> Left tooMany

    go (Left err) = throwIO (transportError err)
        { transportErrorAddress = Just origAddr
        }
    -- Close the socket if binding or listening fails rather than leaking it.
    go (Right (addr, p)) = catchIOException (Just addr) $ bracketOnError
        (socket AF_UNIX Stream defaultProtocol)
        sClose
        (\sock -> do
            bindSocket sock (SockAddrUnix p)
            Network.Socket.listen sock (socketTransportOptionBacklog opts)
            return (addr, sock))
-- | Bind and listen on a TCP socket. Returns the concrete address and the
-- listening socket.
--
-- Recognised parameters:
--
-- * @port@: defaults to 0, letting the OS pick a port.
-- * @family@: @ipv4@, @ipv6@, or absent for either.
-- * @bind@: the interface name to bind; @*@ means the passive wildcard.
-- * @host@: used as the bind name when @bind@ is absent; defaults to
--   @localhost@.
--
-- Each resolved address is tried in turn, using a socket of that address's
-- own family, and the first one that binds and listens is used. The
-- returned address carries the actual port, the listener's @guid@, and any
-- @host@ or @family@ from the original address. Errors are thrown as
-- 'TransportError'.
listenTcp :: UUID -> Address -> TransportOptions SocketTransport -> IO (Address, Socket)
listenTcp uuid origAddr opts = go where
    params = addressParameters origAddr
    param key = Map.lookup key params
    unknownFamily x = "Unknown socket family for TCP transport: " ++ show x
    getFamily = case param "family" of
        Just "ipv4" -> Right AF_INET
        Just "ipv6" -> Right AF_INET6
        Nothing -> Right AF_UNSPEC
        Just x -> Left (unknownFamily x)
    badPort x = "Invalid socket port for TCP transport: " ++ show x
    getPort = case param "port" of
        Nothing -> Right 0
        Just x -> case readPortNumber x of
            Just port -> Right port
            Nothing -> Left (badPort x)
    paramBind = case param "bind" of
        Just "*" -> Nothing
        Just x -> Just x
        Nothing -> case param "host" of
            Just x -> Just x
            Nothing -> Just "localhost"
    getAddresses family_ = getAddrInfo (Just (defaultHints
        { addrFlags = [AI_ADDRCONFIG, AI_PASSIVE]
        , addrFamily = family_
        , addrSocketType = Stream
        })) paramBind Nothing

    throwWithAddr :: String -> IO a
    throwWithAddr msg = throwIO (transportError msg)
        { transportErrorAddress = Just origAddr
        }

    -- Open, bind and listen on a socket for one resolved address. The
    -- socket takes the address's own family: a socket cannot be created
    -- with AF_UNSPEC, which is the family used when none is specified. On
    -- any failure the socket is closed.
    listenOn addr = bracketOnError
        (socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr))
        sClose
        (\sock -> do
            setSocketOption sock ReuseAddr 1
            bindSocket sock (addrAddress addr)
            Network.Socket.listen sock (socketTransportOptionBacklog opts)
            return sock)

    -- Try each resolved address in order; report the last failure if none
    -- of them work.
    listenFirst [] = throwWithAddr "listenTcp: no addresses"
    listenFirst (addr:addrs) = do
        tried <- Control.Exception.try (listenOn addr)
        case tried of
            Right sock -> return sock
            Left err -> case addrs of
                [] -> throwWithAddr (show (err :: IOException))
                _ -> listenFirst addrs

    -- 'show' on a 'PortNumber' renders the ordinary decimal port. The raw
    -- 'PortNum' payload is in network byte order and must not be shown
    -- directly.
    sockAddr port = address_ "tcp" p where
        p = baseParams ++ hostParam ++ familyParam
        baseParams =
            [ ("port", show port)
            , ("guid", formatUUID uuid)
            ]
        hostParam = case param "host" of
            Just x -> [("host", x)]
            Nothing -> []
        familyParam = case param "family" of
            Just x -> [("family", x)]
            Nothing -> []

    go = case getPort of
        Left err -> throwWithAddr err
        Right port -> case getFamily of
            Left err -> throwWithAddr err
            Right family_ -> catchIOException (Just origAddr) $ do
                sockAddrs <- getAddresses family_
                sock <- listenFirst (map (setPort port) sockAddrs)
                sockPort <- socketPort sock `onException` sClose sock
                return (sockAddr sockPort, sock)
-- | Run an action, converting any 'IOException' it throws into a
-- 'TransportError'. The new error's message is the shown exception and its
-- address is the one supplied. Exceptions of other types propagate
-- unchanged.
catchIOException :: Maybe Address -> IO a -> IO a
catchIOException addr io = Control.Exception.handle rethrow io where
    rethrow :: IOException -> IO b
    rethrow err = throwIO (transportError (show err))
        { transportErrorAddress = addr
        }
-- | Build an 'Address' from a method name and a list of parameters. This
-- module only calls it with the methods @unix@ and @tcp@ and with
-- parameter keys it chooses itself.
--
-- If 'address' rejects the input, this calls 'error' with a message naming
-- the method and parameters, rather than failing with an opaque
-- irrefutable-pattern error.
address_ :: String -> [(String, String)] -> Address
address_ method params = case address method (Map.fromList params) of
    Just addr -> addr
    Nothing -> error ("DBus.Transport.address_: invalid address (method "
        ++ show method ++ ", parameters " ++ show params ++ ")")
-- | Replace the port of an IPv4 or IPv6 'AddrInfo'. Other kinds of address
-- are left unchanged.
setPort :: PortNumber -> AddrInfo -> AddrInfo
setPort port info = info { addrAddress = withPort (addrAddress info) } where
    withPort (SockAddrInet _ host) = SockAddrInet port host
    withPort (SockAddrInet6 _ flow host scope) = SockAddrInet6 port flow host scope
    withPort other = other
-- | Parse a decimal TCP port in the range 1-65535.
--
-- Returns 'Nothing' in three cases: the empty string; anything containing
-- a non-digit character, including signs and whitespace; and values
-- outside the range.
readPortNumber :: String -> Maybe PortNumber
readPortNumber s
    | null s = Nothing
    | not (all isDigitChar s) = Nothing
    | word > 0 && word <= 65535 = Just (fromInteger word)
    | otherwise = Nothing
  where
    isDigitChar c = c >= '0' && c <= '9'
    -- Only evaluated once s is known to be a non-empty run of ASCII digits,
    -- so 'read' cannot fail here.
    word = read s :: Integer