{-
  Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
  
  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE OverloadedStrings #-}

{-# LANGUAGE DeriveDataTypeable #-}
module DBus.Connection
        (   Connection

          , ConnectionError (..)

          , connect
          , connectFirst

          , send

          , receive

        ) where
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified Control.Concurrent as C
import qualified DBus.Address as A
import qualified DBus.Message.Internal as M

import qualified Data.ByteString.Lazy as L
import Data.Word (Word32)

import qualified Network as N
import qualified Data.Map as Map

import qualified Network.Socket as NS

import qualified Text.Parsec as P
import Control.Monad (unless)
import Data.Binary.Get (runGet, getWord16host)
import Data.Binary.Put (runPut, putWord16be)

import qualified System.IO as I

import qualified Control.Exception as E
import Data.Typeable (Typeable)

import Data.Text.Lazy.Encoding (decodeUtf8, encodeUtf8)

import System.Posix.User (getRealUserID)
import Data.Char (ord)
import Text.Printf (printf)

import Data.List (isPrefixOf)

import qualified DBus.Wire as W


-- | An open, authenticated connection, as returned by 'connect'.
--
-- Positional fields:
--
-- 1. the address that was connected to;
-- 2. the 'Transport' used for all reads and writes;
-- 3. an MVar holding the next message serial, taken and advanced by
--    'withSerial' on every 'send';
-- 4. a unit MVar used as a read lock, held by 'receive' while it reads a
--    message from the transport.
data Connection = Connection A.Address Transport (C.MVar M.Serial) (C.MVar ())

-- | Renders a connection as @<connection ADDR>@, where ADDR is 'show'
-- applied to the connection's 'A.strAddress'. The whole form is wrapped in
-- parentheses when shown at a precedence above 10.
instance Show Connection where
        showsPrec d (Connection addr _ _ _) =
                showParen (d > 10) $
                        showString "<connection "
                        . showString (show (A.strAddress addr))
                        . showString ">"

-- | The raw byte-level I/O actions a 'Connection' talks through. The
-- 'unix' and 'tcp' transports both build one from a handle via
-- 'handleTransport'.
data Transport = Transport
        { transportSend :: L.ByteString -> IO ()
          -- ^ Write the given bytes to the peer.
        , transportRecv :: Word32 -> IO L.ByteString
          -- ^ Read the requested number of bytes from the peer. The
          -- handle-based implementation uses 'L.hGet', which can return
          -- fewer bytes than requested at end of input.
        }

-- | Open a 'Transport' for an address, chosen by the address's method.
--
-- The @unix@ method is handled by 'unix' and @tcp@ by 'tcp'. Any other
-- method throws 'UnknownMethod' carrying the address.
connectTransport :: A.Address -> IO Transport
connectTransport a = case A.addressMethod a of
        "unix" -> unix a
        "tcp"  -> tcp a
        _      -> E.throwIO (UnknownMethod a)

-- | Open a transport over a Unix domain socket for a @unix@ address.
--
-- Exactly one of the @path@ or @abstract@ parameters must be present.
-- Supplying both or neither throws 'BadParameters'. An @abstract@ name is
-- turned into a socket path by prefixing it with a NUL character. The
-- connected handle is wrapped with 'handleTransport'.
unix :: A.Address -> IO Transport
unix a = do
        socketPath <- choosePath
        h <- N.connectTo "localhost" (N.UnixSocket socketPath)
        handleTransport h
    where
        params = A.addressParameters a

        choosePath = case (Map.lookup "path" params, Map.lookup "abstract" params) of
                (Just p, Nothing)  -> return (TL.unpack p)
                (Nothing, Just p)  -> return ('\x00' : TL.unpack p)
                (Nothing, Nothing) -> E.throwIO $ BadParameters a tooFew
                (Just _, Just _)   -> E.throwIO $ BadParameters a tooMany

        tooMany = "Only one of `path' or `abstract' may be specified for the\
                  \ `unix' transport."
        tooFew = "One of `path' or `abstract' must be specified for the\
                 \ `unix' transport."

-- | Open a TCP transport for a @tcp@ address.
--
-- Parameters read from 'A.addressParameters':
--
-- * @host@: the hostname to resolve; defaults to @localhost@.
-- * @port@: required; a decimal integer in the range 1-65535.
-- * @family@: optional; @ipv4@ or @ipv6@ restricts resolution to that
--   family, and when absent any family is allowed.
--
-- Each resolved address is tried in turn, and the first that connects is
-- wrapped with 'handleTransport'. Throws 'BadParameters' for a missing or
-- malformed parameter, and 'NoWorkingAddress' if no resolved address
-- accepts a connection.
tcp :: A.Address -> IO Transport
tcp a = openHandle >>= handleTransport where
        params = A.addressParameters a
        openHandle = do
                port <- getPort
                family <- getFamily
                addresses <- getAddresses family
                socket <- openSocket port addresses
                NS.socketToHandle socket I.ReadWriteMode

        hostname = maybe "localhost" TL.unpack $ Map.lookup "host" params

        unknownFamily x = TL.concat ["Unknown socket family for TCP transport: ", x]
        getFamily = case Map.lookup "family" params of
                Just "ipv4" -> return NS.AF_INET
                Just "ipv6" -> return NS.AF_INET6
                Nothing     -> return NS.AF_UNSPEC
                Just x      -> E.throwIO $ BadParameters a $ unknownFamily x

        missingPort = "TCP transport requires the ``port'' parameter."
        badPort x = TL.concat ["Invalid socket port for TCP transport: ", x]
        getPort = case Map.lookup "port" params of
                Nothing -> E.throwIO $ BadParameters a missingPort
                Just x -> case P.parse parseWord16 "" (TL.unpack x) of
                        Right x' -> return $ NS.PortNum x'
                        Left  _  -> E.throwIO $ BadParameters a $ badPort x

        -- Parse a decimal port in 1..65535. The big-endian put followed by
        -- a host-order get stores the port in network byte order inside a
        -- host Word16, presumably what 'NS.PortNum' expects in the network
        -- version this targets (confirm if upgrading).
        parseWord16 = do
                chars <- P.many1 P.digit
                P.eof
                let value = read chars :: Integer
                unless (value > 0 && value <= 65535) $
                        P.parserFail "bad port"
                let word = fromIntegral value
                return $ runGet getWord16host (runPut (putWord16be word))

        getAddresses family = do
                let hints = NS.defaultHints
                        { NS.addrFlags = [NS.AI_ADDRCONFIG]
                        , NS.addrFamily = family
                        , NS.addrSocketType = NS.Stream
                        }
                NS.getAddrInfo (Just hints) (Just hostname) Nothing

        setPort port (NS.SockAddrInet  _ x)     = NS.SockAddrInet port x
        setPort port (NS.SockAddrInet6 _ x y z) = NS.SockAddrInet6 port x y z
        setPort _    addr                       = addr

        -- Try each resolved address in order. Any exception moves on to the
        -- next address, and 'NoWorkingAddress' is thrown once all have failed.
        openSocket _ [] = E.throwIO $ NoWorkingAddress [a]
        openSocket port (addr:addrs) = E.catch (openSocket' port addr) $
                \(E.SomeException _) -> openSocket port addrs

        -- Create a socket for one address and connect it. If connecting
        -- fails, the socket is closed before the exception propagates, so
        -- falling back to the next address does not leak a descriptor.
        openSocket' port addr = do
                sock <- NS.socket (NS.addrFamily addr)
                                  (NS.addrSocketType addr)
                                  (NS.addrProtocol addr)
                (NS.connect sock . setPort port . NS.addrAddress $ addr)
                        `E.onException` NS.sClose sock
                return sock

-- | Turn an open handle into a 'Transport'.
--
-- The handle is switched to unbuffered, binary mode. Sends write the whole
-- lazy ByteString with 'L.hPut', and receives read the requested byte
-- count with 'L.hGet'.
handleTransport :: I.Handle -> IO Transport
handleTransport h = do
        I.hSetBuffering h I.NoBuffering
        I.hSetBinaryMode h True
        let recvBytes n = L.hGet h (fromIntegral n)
        return Transport
                { transportSend = L.hPut h
                , transportRecv = recvBytes
                }

-- | Errors thrown with 'E.throwIO' while establishing a connection.
data ConnectionError
        = InvalidAddress Text
          -- ^ Not thrown by any code shown in this module; presumably an
          -- address string that failed to parse (confirm against callers).
        | BadParameters A.Address Text
          -- ^ The address's parameters were missing, conflicting or
          -- malformed. The 'Text' is a human-readable description.
        | UnknownMethod A.Address
          -- ^ The address's method is neither @unix@ nor @tcp@.
        | NoWorkingAddress [A.Address]
          -- ^ No connection could be made to any of the listed addresses.
        deriving (Show, Typeable)

instance E.Exception ConnectionError

-- | Connect to a single address and authenticate.
--
-- Opens a transport with 'connectTransport' and runs 'authenticate' over
-- it, converting strings to and from UTF-8 bytes. It then returns a
-- 'Connection' whose serial counter starts at 'M.firstSerial'. Exceptions
-- from either step propagate to the caller.
connect :: A.Address -> IO Connection
connect a = do
        transport <- connectTransport a
        authenticate (sendString transport) (recvString transport)
        serials <- C.newMVar M.firstSerial
        readLock <- C.newMVar ()
        return (Connection a transport serials readLock)
    where
        sendString t = transportSend t . encodeUtf8 . TL.pack
        recvString t n = fmap (TL.unpack . decodeUtf8) (transportRecv t n)

-- | Try each address in order and return the first successful 'connect'.
--
-- An exception from a failed attempt is discarded and the next address is
-- tried. If every attempt fails, or the list is empty, 'NoWorkingAddress'
-- is thrown with the full list of addresses that was given.
connectFirst :: [A.Address] -> IO Connection
connectFirst orig = foldr attempt giveUp orig where
        giveUp = E.throwIO (NoWorkingAddress orig)
        attempt addr fallback =
                connect addr `E.catch` (\(E.SomeException _) -> fallback)

-- | Run the client side of the @AUTH EXTERNAL@ handshake.
--
-- @put@ writes a string to the peer, and @get n@ reads @n@ bytes back as a
-- string. The steps are:
--
-- 1. Send a single NUL byte.
-- 2. Send @AUTH EXTERNAL@ with the process's real user ID. The ID's decimal
--    text is hex-encoded character by character.
-- 3. Read one reply line.
--
-- If the reply begins with @OK@, the handshake finishes by sending
-- @BEGIN@. Otherwise an 'E.ErrorCall' is thrown whose message includes the
-- server's reply.
authenticate :: (String -> IO ()) -> (Word32 -> IO String)
                -> IO ()
authenticate put get = do
        -- Leading NUL byte sent before any command (presumably the D-Bus
        -- credentials byte; see the D-Bus authentication protocol).
        put "\x00"

        uid <- getRealUserID
        let authToken = concatMap (printf "%02X" . ord) (show uid)
        put $ "AUTH EXTERNAL " ++ authToken ++ "\r\n"

        response <- readUntil '\n' get
        -- Report a rejection to the caller through the exception, not by
        -- printing to the application's stdout.
        unless ("OK" `isPrefixOf` response) $
                E.throwIO $ E.ErrorCall $
                        "Server rejected authentication token: "
                        ++ show response
        put "BEGIN\r\n"

-- | Repeatedly request one byte's worth of text from the reader until the
-- terminator character has been read. Returns everything read, including
-- the terminator.
--
-- Characters are accumulated in reverse and flipped once at the end, so a
-- line of length n costs O(n) rather than O(n^2).
--
-- The loop has no partial pattern, so it needs only 'Monad' (not
-- 'MonadFail'):
--
-- * If the reader returns several characters at once, all of them up to and
--   including the terminator are kept; anything after the terminator in the
--   same chunk is discarded.
-- * If the reader returns an empty string (end of input), the characters
--   read so far are returned without a terminator.
readUntil :: Monad m => Char -> (Word32 -> m String) -> m String
readUntil c f = go [] where
        go acc = do
                chunk <- f 1
                case break (== c) chunk of
                        (before, _:_) -> return (reverse acc ++ before ++ [c])
                        (before, [])
                                | null chunk -> return (reverse acc)
                                | otherwise  -> go (reverse before ++ acc)

-- | Marshal a message with a freshly allocated serial and write it to the
-- connection.
--
-- After marshalling succeeds, the callback is run with the serial before
-- the bytes are written, and its result is returned in 'Right'. If
-- marshalling fails, neither the callback nor the transport is touched and
-- the 'W.MarshalError' is returned in 'Left'. The serial is consumed in
-- both cases.
--
-- The whole operation runs inside 'withSerial', which holds the serial MVar
-- throughout, so concurrent sends are serialised.
send :: M.Message a => Connection -> (M.Serial -> IO b) -> a
     -> IO (Either W.MarshalError b)
send (Connection _ t mvar _) io msg = withSerial mvar marshalAndSend where
        marshalAndSend serial =
                either (return . Left) (deliver serial)
                        (W.marshalMessage W.LittleEndian serial msg)
        deliver serial bytes = do
                result <- io serial
                transportSend t bytes
                return (Right result)

-- | Run an action with the next serial number taken from the MVar.
--
-- The MVar stays empty while the action runs, so concurrent callers are
-- serialised and each sees a distinct serial. The counter is advanced with
-- 'M.nextSerial' whether the action returns normally or throws.
--
-- Asynchronous exceptions are masked between taking and refilling the MVar.
-- The action itself runs under 'restore', i.e. with the caller's own masking
-- state, rather than being unconditionally unmasked. A caller that
-- deliberately masked exceptions therefore stays masked.
withSerial :: C.MVar M.Serial -> (M.Serial -> IO a) -> IO a
withSerial m io = E.mask $ \restore -> do
        s <- C.takeMVar m
        let s' = M.nextSerial s
        x <- restore (io s) `E.onException` C.putMVar m s'
        C.putMVar m s'
        return x

-- | Read and unmarshal the next message from the connection.
--
-- The connection's read lock is held for the whole read, so only one
-- receiver consumes bytes from the transport at a time. Unmarshalling
-- failures are returned in 'Left'.
receive :: Connection -> IO (Either W.UnmarshalError M.ReceivedMessage)
receive (Connection _ t _ lock) =
        C.withMVar lock (const (W.unmarshalMessage (transportRecv t)))