{-
  Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
  
  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE OverloadedStrings #-}

{-# LANGUAGE DeriveDataTypeable #-}
module DBus.Connection
        (   Connection

          , ConnectionError (..)
          , W.MarshalError (..)
          , W.UnmarshalError (..)

          , connect

          , send

          , receive

        ) where
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified Control.Concurrent as C
import qualified DBus.Address as A
import qualified DBus.Message.Internal as M

import qualified Data.ByteString.Lazy as L
import Data.Word (Word32)

import qualified Network as N
import qualified Data.Map as Map

import qualified System.IO as I

import qualified Control.Exception as E
import Data.Typeable (Typeable)

import Data.Text.Lazy.Encoding (decodeUtf8, encodeUtf8)

import System.Posix.User (getRealUserID)
import Data.Char (ord)
import Text.Printf (printf)

import Data.List (isPrefixOf)

import qualified DBus.Wire as W


-- | An open bus connection.
--
-- Holds the address it was opened with, the underlying 'Transport', and an
-- 'C.MVar' holding the serial to give the next outgoing message.
data Connection = Connection A.Address Transport (C.MVar M.Serial)

-- | Only the address is rendered. The transport and serial state are
-- omitted. The text is wrapped in parentheses at precedence above 10.
instance Show Connection where
        showsPrec d (Connection addr _ _) = showParen (d > 10) . showString $
                "<connection " ++ show (A.strAddress addr) ++ ">"

-- | A byte-level channel to the bus, represented as a pair of I/O actions.
data Transport = Transport
        { transportSend :: L.ByteString -> IO ()
          -- ^ Write the given bytes to the channel.
        , transportRecv :: Word32 -> IO L.ByteString
          -- ^ Request the given number of bytes from the channel.
          -- NOTE(review): whether a short read is possible depends on the
          -- implementation (the handle-based one uses 'L.hGet').
        }

-- | Open a transport for the address by dispatching on its method name.
--
-- Only the @unix@ method is implemented. Any other method throws
-- 'UnknownMethod'.
connectTransport :: A.Address -> IO Transport
connectTransport a = case A.addressMethod a of
        "unix" -> unix a
        _      -> unknownTransport a

-- | Open a transport over a Unix-domain socket.
--
-- The socket is named by exactly one of two address parameters:
-- @path@, used as-is, or @abstract@, which is prefixed with a NUL byte.
--
-- Throws 'BadParameters' when both parameters or neither are present.
unix :: A.Address -> IO Transport
unix a = do
        socketName <- chooseName (Map.lookup "path" params)
                                 (Map.lookup "abstract" params)
        handleTransport $ N.connectTo "localhost" $ N.UnixSocket socketName
    where
        params = A.addressParameters a

        -- Exactly one of the two parameters must be supplied.
        chooseName (Just x) Nothing  = return $ TL.unpack x
        chooseName Nothing  (Just x) = return $ '\x00' : TL.unpack x
        chooseName (Just _) (Just _) = E.throwIO $ BadParameters a tooMany
        chooseName Nothing  Nothing  = E.throwIO $ BadParameters a tooFew

        tooMany = "Only one of `path' or `abstract' may be specified for the\
                  \ `unix' method."
        tooFew = "One of `path' or `abstract' must be specified for the\
                 \ `unix' transport."

-- | Run the given action to obtain a handle, and wrap that handle as a
-- 'Transport'.
--
-- The handle is switched to unbuffered, binary mode before use. Sends are
-- 'L.hPut'. Receives are 'L.hGet' with the requested count converted to
-- 'Int'.
handleTransport :: IO I.Handle -> IO Transport
handleTransport open = do
        handle <- open
        I.hSetBuffering handle I.NoBuffering
        I.hSetBinaryMode handle True
        return Transport
                { transportSend = L.hPut handle
                , transportRecv = \count -> L.hGet handle (fromIntegral count)
                }

-- | Fallback for address methods with no transport implementation.
-- It always throws 'UnknownMethod' carrying the offending address.
unknownTransport :: A.Address -> IO Transport
unknownTransport a = E.throwIO (UnknownMethod a)

-- | Errors thrown (via 'E.throwIO') while establishing a 'Connection'.
data ConnectionError
        = InvalidAddress Text
          -- ^ NOTE(review): not raised in this module's visible code;
          -- presumably an address string that failed to parse. TODO confirm.
        | BadParameters A.Address Text
          -- ^ The address's parameters cannot be used; the 'Text' explains
          -- why (e.g. both or neither of @path@/@abstract@ for @unix@).
        | UnknownMethod A.Address
          -- ^ The address's method has no transport implementation.
        | NoWorkingAddress [A.Address]
          -- ^ NOTE(review): not raised in this module's visible code;
          -- presumably none of the listed addresses could be connected to.
        deriving (Show, Typeable)

-- | Lets 'ConnectionError' values be thrown and caught as exceptions.
instance E.Exception ConnectionError

-- | Open a connection to the given address.
--
-- Steps, in order:
--
-- 1. Open the transport for the address.
-- 2. Run the authentication handshake over it, encoding outgoing text as
--    UTF-8 and decoding incoming bytes as UTF-8.
-- 3. Create the serial counter, starting at 'M.firstSerial'.
--
-- Any exception from the transport or handshake propagates to the caller.
-- NOTE(review): 'Transport' has no close action, so the underlying handle is
-- not released if authentication fails.
connect :: A.Address -> IO Connection
connect a = do
        transport <- connectTransport a
        authenticate (sendText transport) (recvText transport)
        fmap (Connection a transport) (C.newMVar M.firstSerial)
    where
        sendText t = transportSend t . encodeUtf8 . TL.pack
        recvText t = fmap (TL.unpack . decodeUtf8) . transportRecv t

-- | Perform the client side of the D-Bus @EXTERNAL@ authentication.
--
-- The exchange is:
--
-- 1. Send a single NUL byte.
-- 2. Send @AUTH EXTERNAL@ followed by the caller's real user ID (its decimal
--    text, hex-encoded character by character).
-- 3. Read one line from the server.
-- 4. If that line starts with @OK@, reply @BEGIN@.
--
-- Any other reply is reported by throwing an 'E.ErrorCall' whose message
-- contains the server's response. Nothing is printed to stdout.
--
-- Parameters: @put@ writes a string to the server; @get@ reads the given
-- number of units from it.
authenticate :: (String -> IO ()) -> (Word32 -> IO String)
                -> IO ()
authenticate put get = do
        -- The handshake opens with a single NUL byte before any command.
        put "\x00"

        uid <- getRealUserID
        -- EXTERNAL's initial response is the decimal UID text, hex-encoded.
        let authToken = concatMap (printf "%02X" . ord) (show uid)
        put $ "AUTH EXTERNAL " ++ authToken ++ "\r\n"

        response <- readUntil '\n' get
        if "OK" `isPrefixOf` response
                then put "BEGIN\r\n"
                -- Raise in IO (same ErrorCall type `error` produces) and
                -- carry the reply in the message instead of printing it.
                else E.throwIO $ E.ErrorCall $
                        "Server rejected authentication token: "
                        ++ show response

-- | Read one character at a time until the terminator @c@ is seen.
--
-- Each step asks the reader @f@ for a single unit (@f 1@). The result is
-- everything read, including the terminator.
--
-- Each read must yield exactly one character. Anything else, such as an
-- empty string at end of input, is a pattern-match failure and goes to the
-- monad's @fail@.
readUntil :: Monad m => Char -> (Word32 -> m String) -> m String
readUntil c f = go [] where
        -- Characters are consed onto a reversed accumulator, so each step is
        -- O(1) instead of re-copying the prefix with (++); the accumulator is
        -- reversed once at the end.
        go acc = do
                [x] <- f 1
                let acc' = x : acc
                if x == c
                        then return (reverse acc')
                        else go acc'

-- | Marshal a message in little-endian byte order and write it to the
-- connection.
--
-- The message is stamped with the connection's next serial. The work runs
-- inside 'withSerial', so the serial lock is held throughout.
--
-- On success, @io@ runs with that serial before the bytes are written, and
-- its result is returned in 'Right'. If marshalling fails, nothing is
-- written, @io@ is not run, and the 'W.MarshalError' is returned in 'Left'.
-- The serial is consumed either way.
send :: M.Message a => Connection -> (M.Serial -> IO b) -> a
     -> IO (Either W.MarshalError b)
send (Connection _ t mvar) io msg = withSerial mvar sendNumbered where
        sendNumbered serial = either (return . Left) (deliver serial)
                (W.marshalMessage W.LittleEndian serial msg)
        deliver serial bytes = do
                result <- io serial
                transportSend t bytes
                return (Right result)

-- | Run an action with the current serial while holding the serial lock.
--
-- Afterwards, the 'M.nextSerial' value is stored back into the 'C.MVar'.
-- Asynchronous exceptions are blocked except while the action itself runs.
-- If the action throws, the advanced serial is still stored before the
-- exception propagates.
withSerial :: C.MVar M.Serial -> (M.Serial -> IO a) -> IO a
withSerial m io = E.block $ do
        current <- C.takeMVar m
        let advance = C.putMVar m (M.nextSerial current)
        result <- E.unblock (io current) `E.onException` advance
        advance
        return result

-- | Read the next message from the connection.
--
-- The transport's receive action is handed to 'W.unmarshalMessage'.
-- Unmarshalling failures are returned in 'Left'.
receive :: Connection -> IO (Either W.UnmarshalError M.ReceivedMessage)
receive conn = case conn of
        Connection _ transport _ -> W.unmarshalMessage (transportRecv transport)