module Network.DBus.Connection (
    Connection(),
    ConnectionAddress(..),
    Handler,
    MatchClause(..),
    MatchRule,
    connectToBus,
    getSessionBusAddress,
    getSystemBusAddress,
    parseAddress,
    sendMessage,
    sendAndWait,
    addHandler,
    removeHandler,
    uniqueName
) where

import Control.Concurrent
import Control.Exception (onException)
import Control.Monad (forever, forM_, when)
import Data.List (intersperse)
import Data.Maybe (fromJust)
import Data.Typeable (cast)
import Data.Word
import Network.Socket
import Numeric (readHex, showHex)
import System.Environment (getEnvironment)
import System.IO
import System.Posix.User (getEffectiveUserID)

import Network.DBus.Message
import Network.DBus.Value

-- | One condition of a 'MatchRule'.  'matchString' renders a clause as a
-- @key='value'@ pair for the bus, and 'match' tests it locally against a
-- received 'Message'.
data MatchClause
    = MatchType MessageType         -- ^ message type must equal this
    | MatchSender DString           -- ^ sender field must be present and equal this
    | MatchInterface DString        -- ^ interface field must be present and equal this
    | MatchMember DString           -- ^ member field must be present and equal this
    | MatchPath ObjectPath          -- ^ path field must be present and equal this
    | MatchDestination DString      -- ^ destination field must be present and equal this
    | MatchArg Int DString          -- ^ body element at this 0-based index must equal this string (as a 'Variant')
    deriving Show

-- | A conjunction of clauses: a message matches when every clause matches
-- (so an empty rule matches every message).
type MatchRule = [MatchClause]

-- | Render a match rule in the textual form sent to the bus's @AddMatch@ and
-- @RemoveMatch@ methods: comma-separated @key='value'@ pairs.
--
-- No escaping happens inside a single-quoted span, so a literal single quote
-- in a value is emitted as @'\\''@ (close the quote, backslash-escaped quote,
-- reopen the quote), as the D-Bus match-rule syntax requires.  Values without
-- single quotes are rendered unchanged.
matchString :: MatchRule -> String
matchString = concat . intersperse "," . map toStr
    where toStr (MatchType t)        = "type=" ++ quote (typeToStr t)
          toStr (MatchSender s)      = "sender=" ++ quote (getString s)
          toStr (MatchInterface s)   = "interface=" ++ quote (getString s)
          toStr (MatchMember s)      = "member=" ++ quote (getString s)
          toStr (MatchPath o)        = "path=" ++ quote (getPath o)
          toStr (MatchDestination s) = "destination=" ++ quote (getString s)
          toStr (MatchArg n s)       = "arg" ++ show n ++ "=" ++ quote (getString s)

          -- Wrap a value in single quotes, escaping embedded quotes.
          quote v = "'" ++ concatMap escape v ++ "'"

          escape '\'' = "'\\''"
          escape c    = [c]

          typeToStr Signal       = "signal"
          typeToStr MethodCall   = "method_call"
          typeToStr MethodReturn = "method_return"
          typeToStr Error        = "error"

-- | Test a message against a rule locally.  Every clause must hold.  An
-- optional header field that is absent never satisfies a clause on it, and
-- an argument clause fails when the body has no element at that index.
match :: MatchRule -> Message -> Bool
match rule msg = all satisfied rule
    where satisfied (MatchType t)        = mType msg == t
          satisfied (MatchSender s)      = mSender msg == Just s
          satisfied (MatchInterface s)   = mInterface msg == Just s
          satisfied (MatchMember s)      = mMember msg == Just s
          satisfied (MatchPath o)        = mPath msg == Just o
          satisfied (MatchDestination s) = mDestination msg == Just s
          satisfied (MatchArg n s)       =
              case argAt n (mBody msg) of
                  Just v  -> Variant s == v
                  Nothing -> False

          -- 0-based lookup; negative indices find nothing.
          argAt i xs
              | i < 0     = Nothing
              | otherwise = case drop i xs of
                                (x : _) -> Just x
                                []      -> Nothing

-- | Callback run by the connection's receive loop for each incoming message
-- accepted by its (optional) rule.
type Handler = Message -> IO ()

-- | An open bus connection and the shared state used by its receive thread.
data Connection = Connection
    { cSerial        :: MVar Word32
      -- ^ serial number to stamp on the next outgoing message
    , cSock          :: Socket
      -- ^ the connected socket
    , cHandle        :: Handle
      -- ^ handle over the socket, used for all message reads and writes
    , cUniqueName    :: DString
      -- ^ unique name taken from the reply to @Hello()@
    , cThread        :: ThreadId
      -- ^ thread running the receive loop
    , cPendingCalls  :: MVar [(Word32, MVar Message)]
      -- ^ serials of sent calls still awaiting a reply, with the reply slot
    , cHandlerSerial :: MVar Int
      -- ^ id to hand out to the next installed handler
    , cHandlers      :: MVar [(Int, Maybe MatchRule, Handler)]
      -- ^ installed handlers with their id and optional rule
    }

-- | Location of a bus, as produced by 'parseAddress'.
data ConnectionAddress
    = Unix
        { addrPath :: String        -- ^ filesystem path of the socket
        , addrGuid :: Maybe String  -- ^ value of the @guid@ key, if given
        }
    | UnixAbstract
        { addrPath :: String        -- ^ abstract-namespace name, without the leading NUL
        , addrGuid :: Maybe String  -- ^ value of the @guid@ key, if given
        }
    deriving Show

-- | The unique bus name assigned to this connection by @Hello()@.
uniqueName :: Connection -> DString
uniqueName conn = cUniqueName conn

-- | Read and parse @DBUS_SESSION_BUS_ADDRESS@.  Yields 'Nothing' when the
-- variable is unset or its value cannot be parsed.
getSessionBusAddress :: IO (Maybe ConnectionAddress)
getSessionBusAddress = fmap fromEnv getEnvironment
    where fromEnv env = case lookup "DBUS_SESSION_BUS_ADDRESS" env of
                            Just addr -> parseAddress addr
                            Nothing   -> Nothing

-- | Read and parse @DBUS_SYSTEM_BUS_ADDRESS@.  When the variable is unset,
-- fall back to the socket at @/var/run/dbus/system_bus_socket@.  When it is
-- set but unparseable, yield 'Nothing' (no fallback).
getSystemBusAddress :: IO (Maybe ConnectionAddress)
getSystemBusAddress = fmap fromEnv getEnvironment
    where fromEnv env = maybe (Just fallback) parseAddress
                              (lookup "DBUS_SYSTEM_BUS_ADDRESS" env)
          fallback = Unix "/var/run/dbus/system_bus_socket" Nothing

-- | Split a list on every occurrence of a separator.  An empty input gives
-- no fields, and a single trailing separator does not produce a trailing
-- empty field (@splitOn ',' "a,"@ is @["a"]@).  Interior and leading empty
-- fields are kept.
splitOn :: (Eq a) => a -> [a] -> [[a]]
splitOn sep = go
    where go [] = []
          go xs = case span (/= sep) xs of
                      (chunk, [])       -> [chunk]
                      (chunk, _ : rest) -> chunk : go rest

-- Example:
--   unix:abstract=/tmp/dbus-Gxt8Av4CSA,guid=7515a79962a02df7ca39e3b049982f9e

-- | Parse a D-Bus server address such as
-- @unix:abstract=/tmp/dbus-Gxt8Av4CSA,guid=7515a79962a02df7ca39e3b049982f9e@.
--
-- An address string may list several alternatives separated by @;@.  The
-- first alternative that parses to a supported transport is returned.
-- Within an alternative, values are percent-decoded (@%2f@ becomes @/@).  A
-- @%@ not followed by two hex digits is kept literally.
--
-- Only the @unix@ transport with an @abstract@ or @path@ key is supported.
-- @abstract@ wins when both are present.  Returns 'Nothing' when no
-- alternative is usable.
parseAddress :: String -> Maybe ConnectionAddress
parseAddress str =
    case [addr | Just addr <- map parseOne (splitOn ';' str)] of
        (addr : _) -> Just addr
        []         -> Nothing
    where
    -- Parse one "transport:key=value,key=value" entry.
    parseOne entry = do
        (transport, pairsStr) <-
            case break (== ':') entry of
                (t, ':' : rest) -> Just (t, rest)
                _               -> Nothing
        pairs <- mapM parsePair (splitOn ',' pairsStr)
        let guid = lookup "guid" pairs
        case transport of
            "unix" ->
                case lookup "abstract" pairs of
                    Just p  -> Just (UnixAbstract p guid)
                    Nothing -> fmap (\p -> Unix p guid) (lookup "path" pairs)
            _ -> Nothing

    -- Split "key=value" and decode the value's escapes.
    parsePair kv = case break (== '=') kv of
        (key, '=' : value) -> Just (key, unescape value)
        _                  -> Nothing

    -- Decode %XX escapes byte by byte.
    unescape ('%' : a : b : rest) =
        case readHex [a, b] of
            [(n, "")] -> toEnum n : unescape rest
            _         -> '%' : unescape (a : b : rest)
    unescape (c : rest) = c : unescape rest
    unescape []         = []

-- | The @org.freedesktop.DBus.Hello@ call that registers a new connection
-- with the bus.  Its serial is filled in by the sender.
hello :: Message
hello = methodCall bus (dstr "Hello") bus busPath []
    where dstr s  = fromJust (mkDString s)
          bus     = dstr "org.freedesktop.DBus"
          busPath = fromJust (mkObjectPath "/org/freedesktop/DBus")

-- | Stamp the message with the connection's next serial, write and flush it,
-- and return the serial used.
--
-- The serial counter is held locked for the whole write, so concurrent
-- senders neither interleave messages nor reuse serials.  If the write
-- throws, 'modifyMVar' leaves the counter unchanged.  Serial 0 is not a valid
-- D-Bus serial, so the counter wraps from 'maxBound' back to 1.
sendMessage :: Connection -> Message -> IO Word32
sendMessage conn msg = modifyMVar (cSerial conn) $ \s -> do
    writeMessage (cHandle conn) (msg { mSerial = s })
    hFlush (cHandle conn)
    return (nextSerial s, s)
    where nextSerial n
              | n == maxBound = 1
              | otherwise     = n + 1

-- | Send a message and return an 'MVar' that the receive loop fills with the
-- reply whose reply-serial matches.
--
-- The pending-call table stays locked across the send, so the receive loop
-- cannot process the reply before its waiter is registered.  'modifyMVar'
-- puts the table back if sending throws (including on async exceptions).
-- A manual take/put would leave it empty and deadlock every later call and
-- the receive loop.
sendExpectingReply :: Connection -> Message -> IO (MVar Message)
sendExpectingReply conn msg = modifyMVar (cPendingCalls conn) $ \pending -> do
    reply <- newEmptyMVar
    serial <- sendMessage conn msg
    return ((serial, reply) : pending, reply)

-- | Send a message and block until its reply arrives.  The reply is
-- delivered by the receive loop, so this must not be called from inside a
-- handler.
sendAndWait :: Connection -> Message -> IO Message
sendAndWait conn msg = takeMVar =<< sendExpectingReply conn msg

-- | Read the next message from the connection's handle.
receiveMessage :: Connection -> IO Message
receiveMessage = readMessage . cHandle

-- | Hex-encode each character's code point with lowercase digits, padding
-- code points below 16 to two digits (used for the SASL EXTERNAL identity).
hexEncode :: String -> String
hexEncode = concatMap (pad . flip showHex "" . fromEnum)
    where pad [d] = ['0', d]
          pad ds  = ds

-- | Client side of the SASL handshake.  Send a NUL byte, then
-- @AUTH EXTERNAL@ with the hex-encoded decimal effective uid.  On an
-- @OK <guid>@ response, send @BEGIN@.  Any other response fails in 'IO'.
authenticate :: Handle -> IO ()
authenticate h = do
    hPutChar h '\0'
    -- XXX: only EXTERNAL with the effective uid is tried; there is no
    -- fallback mechanism.
    euid <- getEffectiveUserID
    hPutStr h ("AUTH EXTERNAL " ++ hexEncode (show euid) ++ "\r\n")
    hFlush h
    reply <- hGetLine h
    if accepted (words reply)
        then hPutStr h "BEGIN\r\n" >> hFlush h
        else fail "authentication failed"
    where accepted ["OK", _] = True
          accepted _         = False

-- | Find the first element for which the function yields 'Just'.  Return
-- that result together with the list minus that one element (order of the
-- rest preserved), or 'Nothing' if no element qualifies.
lookupRemoveBy :: (a -> Maybe b) -> [a] -> Maybe (b, [a])
lookupRemoveBy f = go []
    where go _ [] = Nothing
          go seen (x : xs) =
              case f x of
                  Just y  -> Just (y, reverse seen ++ xs)
                  Nothing -> go (x : seen) xs

-- | Install a handler and return its id (for 'removeHandler').
--
-- With 'Nothing' the handler sees every received message.  With a rule it
-- sees only messages the rule matches, and an @AddMatch@ call carrying the
-- rendered rule is sent to the bus.  That call's reply is not awaited and
-- its errors are not checked.  New handlers go to the front of the list.
addHandler :: Connection -> Maybe MatchRule -> Handler -> IO Int
addHandler conn rule f = do
    hid <- modifyMVar (cHandlerSerial conn) (\next -> return (next + 1, next))
    modifyMVar_ (cHandlers conn) (return . ((hid, rule, f) :))
    maybe (return ()) subscribe rule
    return hid

    where -- XXX: catch errors? wait for reply?
          subscribe r =
              sendMessage conn (addMatch (mkDString0 (matchString r))) >> return ()

          addMatch :: DString -> Message
          addMatch = methodCall
              (fromJust $ mkDString "org.freedesktop.DBus")
              (fromJust $ mkDString "AddMatch")
              (fromJust $ mkDString "org.freedesktop.DBus")
              (fromJust $ mkObjectPath "/org/freedesktop/DBus")
              . (:[]) . Variant

-- | Uninstall the handler with the given id.  Unknown ids are ignored.
--
-- If the removed handler had a rule, a @RemoveMatch@ call carrying the
-- rendered rule is sent to the bus.  The reply is not awaited and errors
-- are not checked.  The argument is wrapped as a 'DString' via 'mkDString0',
-- just as 'addHandler' builds its @AddMatch@ argument, so both calls send
-- the rule with the same encoding.
removeHandler :: Connection -> Int -> IO ()
removeHandler conn hid = do
    removed <- modifyMVar (cHandlers conn) $ \handlers ->
        return $ case lookupRemoveBy ruleOf handlers of
            Nothing                -> (handlers, Nothing)
            Just (rule, remaining) -> (remaining, rule)
    case removed of
        Nothing -> return ()
        -- XXX: catch errors?
        -- XXX: wait for reply?
        Just rule -> do
            _ <- sendMessage conn $ removeMatch (mkDString0 $ matchString rule)
            return ()

    where ruleOf (hid', rule, _)
              | hid' == hid = Just rule
              | otherwise   = Nothing

          removeMatch :: DString -> Message
          removeMatch = methodCall
              (fromJust $ mkDString "org.freedesktop.DBus")
              (fromJust $ mkDString "RemoveMatch")
              (fromJust $ mkDString "org.freedesktop.DBus")
              (fromJust $ mkObjectPath "/org/freedesktop/DBus")
              . (:[]) . Variant

-- XXX: perhaps:
-- withHandler :: Connection -> Maybe MatchRule -> Handler -> IO a -> IO a

-- | Association-list version of 'lookupRemoveBy': find the first pair with
-- the given key and return its value plus the list without that pair.
lookupRemove :: Eq a => a -> [(a, b)] -> Maybe (b, [(a, b)])
lookupRemove key = lookupRemoveBy pick
    where pick (k, v)
              | k == key  = Just v
              | otherwise = Nothing

-- | Read messages forever and dispatch each one.
--
-- First, if the message carries a reply serial that matches a pending
-- call, it is handed to that call's waiter and the entry is dropped.  Then
-- every handler whose rule accepts the message runs on this thread, in
-- list order (reply messages included).  Nothing here catches exceptions,
-- so an exception from reading or from a handler ends the loop.
receiveLoop :: Connection -> IO ()
receiveLoop conn = forever $ do
    msg <- receiveMessage conn
    maybe (return ()) (deliverReply msg) (mReplySerial msg)
    installed <- readMVar (cHandlers conn)
    -- XXX: this is O(n) in the number of handlers
    forM_ installed $ \(_, rule, handler) ->
        when (maybe True (`match` msg) rule) (handler msg)

    where deliverReply msg serial =
              modifyMVar_ (cPendingCalls conn) $ \pending ->
                  case lookupRemove serial pending of
                      -- The waiter's MVar is fresh and filled only here,
                      -- so this put does not block.
                      Just (waiter, rest) -> putMVar waiter msg >> return rest
                      Nothing             -> return pending

-- | Open a connection to the bus at the given address.
--
-- Steps:
--
-- 1. Connect a Unix socket.
-- 2. Authenticate with SASL EXTERNAL.
-- 3. Send @Hello()@ with serial 1 and take the unique name from its reply.
-- 4. Fork the thread running 'receiveLoop'.
--
-- Fails unless the @Hello()@ reply is a method return whose body is a
-- single 'DString'.  An error reply is rejected rather than having its text
-- taken as the name.  If authentication or the @Hello()@ exchange throws,
-- the handle (and with it the socket) is closed before the exception
-- propagates.
connectToBus :: ConnectionAddress -> IO Connection
connectToBus dbusAddr = do
    sock <- socket AF_UNIX Stream 0
    let addr = sockAddr dbusAddr
    Network.Socket.connect sock addr
    handle <- socketToHandle sock ReadWriteMode

    name <- handshake handle `onException` hClose handle

    -- Serial 1 was used by Hello(), so regular traffic starts at 2.
    serial <- newMVar 2
    pendingCalls <- newMVar []
    handlerSerial <- newMVar 0
    handlers <- newMVar []
    -- The copy given to receiveLoop never reads cThread; the real id is
    -- filled in on the returned value below.
    let conn = Connection {
        cSerial = serial,
        cSock = sock,
        cHandle = handle,
        cUniqueName = name,
        cThread = undefined,
        cPendingCalls = pendingCalls,
        cHandlerSerial = handlerSerial,
        cHandlers = handlers
        }
    threadId <- forkIO $ receiveLoop conn
    return $ conn { cThread = threadId }

    where

    sockAddr :: ConnectionAddress -> SockAddr
    sockAddr Unix         { addrPath = path } = SockAddrUnix $ path
    sockAddr UnixAbstract { addrPath = path } = SockAddrUnix $ '\0' : path

    -- Authenticate, send Hello() and return the unique name from its reply.
    handshake :: Handle -> IO DString
    handshake h = do
        authenticate h
        writeMessage h $ hello { mSerial = 1 }
        hFlush h
        reply <- readMessage h
        case mBody reply of
            -- Check the type too: an error reply typically also carries a
            -- single string argument.
            [Variant v] | mType reply == MethodReturn ->
                maybe (fail "Hello() call failed during connection") return (cast v)
            _ -> fail "Hello() call failed during connection"