{-
  Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
  
  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE OverloadedStrings #-}

module DBus.Bus
        ( getBus
        , getFirstBus
        , getSystemBus
        , getSessionBus
        , getStarterBus
        ) where
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL


import qualified Control.Exception as E
import Control.Monad (when)
import Data.Maybe (fromJust, isNothing)
import qualified Data.Set as Set
import System.Environment (getEnv)

import qualified DBus.Address as A
import qualified DBus.Authentication as Auth
import qualified DBus.Connection as C
import DBus.Constants (dbusName, dbusPath, dbusInterface)
import qualified DBus.Message as M
import qualified DBus.Types as T
import DBus.Util (fromRight)

-- Send @Hello@ over an already-open connection and pair the connection with
-- the bus name that 'sendHello' extracts from the reply.
busForConnection :: C.Connection -> IO (C.Connection, T.BusName)
busForConnection conn = do
        name <- sendHello conn
        return (conn, name)

-- | Like 'C.connect', but once the connection is open it also sends a
-- @Hello@ message to the central bus. Returns the connection together with
-- the bus name carried in the reply.
getBus :: Auth.Mechanism -> A.Address -> IO (C.Connection, T.BusName)
getBus mechanism address = C.connect mechanism address >>= busForConnection

-- | Like 'C.connectFirst', but once a connection is open it also sends a
-- @Hello@ message to the central bus. Returns the connection together with
-- the bus name carried in the reply.
getFirstBus :: [(Auth.Mechanism, A.Address)] -> IO (C.Connection, T.BusName)
getFirstBus candidates = C.connectFirst candidates >>= busForConnection

-- | Connect to the bus specified in the environment variable
-- @DBUS_SYSTEM_BUS_ADDRESS@, or to
-- @unix:path=\/var\/run\/dbus\/system_bus_socket@ if @DBUS_SYSTEM_BUS_ADDRESS@
-- is not set.
getSystemBus :: IO (C.Connection, T.BusName)
getSystemBus = getBus' $ fromEnv `E.catch` noEnv where
        defaultAddr = "unix:path=/var/run/dbus/system_bus_socket"
        fromEnv = getEnv "DBUS_SYSTEM_BUS_ADDRESS"
        -- 'getEnv' reports an unset variable with an 'E.IOException'.
        -- Catch only that type. Catching 'E.SomeException' here would also
        -- swallow asynchronous exceptions (e.g. 'E.ThreadKilled') and turn
        -- them into a silent fallback to the default address.
        noEnv :: E.IOException -> IO String
        noEnv _ = return defaultAddr

-- | Connect to the bus named by the environment variable
-- @DBUS_SESSION_BUS_ADDRESS@. The variable must be set. If it is not, the
-- exception raised by 'getEnv' propagates to the caller.
getSessionBus :: IO (C.Connection, T.BusName)
getSessionBus = getBus' sessionAddress where
        sessionAddress = getEnv "DBUS_SESSION_BUS_ADDRESS"

-- | Connect to the bus named by the environment variable
-- @DBUS_STARTER_ADDRESS@. The variable must be set. If it is not, the
-- exception raised by 'getEnv' propagates to the caller.
getStarterBus :: IO (C.Connection, T.BusName)
getStarterBus = getBus' starterAddress where
        starterAddress = getEnv "DBUS_STARTER_ADDRESS"

-- Run the given action to obtain an address string, parse it with
-- 'A.mkAddresses', and connect using the real-user-ID mechanism:
--
-- * a single address goes through 'getBus';
-- * several addresses go through 'getFirstBus';
-- * an unparseable string throws 'C.InvalidAddress'.
getBus' :: IO String -> IO (C.Connection, T.BusName)
getBus' readAddr = readAddr >>= connectTo . TL.pack where
        connectTo text = case A.mkAddresses text of
                Nothing     -> E.throwIO $ C.InvalidAddress text
                Just [only] -> getBus Auth.realUserID only
                Just many   -> getFirstBus (map ((,) Auth.realUserID) many)

-- | The @Hello@ method call sent to the central bus by 'sendHello'. It is
-- built from 'dbusPath', 'dbusInterface' and 'dbusName', with an empty flag
-- set and no body arguments.
hello :: M.MethodCall
hello = M.MethodCall dbusPath "Hello" iface dest noFlags noArgs where
        iface   = Just dbusInterface
        dest    = Just dbusName
        noFlags = Set.empty
        noArgs  = []

-- | Send the 'hello' call on the connection, wait for the matching method
-- return, and decode the first value of its body as a 'T.BusName'.
--
-- Throws 'E.AssertionFailed' (\"Invalid response to Hello()\") when the
-- reply body is empty or its first value is not convertible to a
-- 'T.BusName'.
sendHello :: C.Connection -> IO T.BusName
sendHello c = do
        -- NOTE(review): 'fromRight' (from DBus.Util) unwraps the result of
        -- 'C.send'. Its behaviour on 'Left' is defined there; confirm that it
        -- fails loudly.
        serial <- fromRight `fmap` C.send c return hello
        reply <- waitForReply c serial
        let name = case M.methodReturnBody reply of
                (x:_) -> T.fromVariant x
                _     -> Nothing
        -- Eliminate the Maybe totally instead of checking 'isNothing' and
        -- then calling the partial 'fromJust'.
        maybe (E.throwIO $ E.AssertionFailed "Invalid response to Hello()")
              return
              name

-- | Receive messages from the connection until a method return whose
-- 'M.methodReturnSerial' equals the given serial arrives, and return it.
-- Every other received message is discarded.
--
-- Throws 'E.AssertionFailed' if 'C.receive' yields a 'Left'.
--
-- NOTE(review): only 'M.ReceivedMethodReturn' is matched. Any other kind of
-- reply to this serial (e.g. an error reply, if the library delivers one as
-- a different constructor) is skipped, so the loop would keep waiting.
-- Confirm that this is intended.
waitForReply :: C.Connection -> M.Serial -> IO M.MethodReturn
waitForReply c serial = loop where
        loop = C.receive c >>= either failed inspect
        failed _ = E.throwIO $ E.AssertionFailed "Invalid response to Hello()"
        inspect (M.ReceivedMethodReturn _ _ reply)
                | M.methodReturnSerial reply == serial = return reply
        inspect _ = loop