{-
  Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
  
  This program is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  any later version.
  
  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
  
  You should have received a copy of the GNU General Public License
  along with this program.  If not, see <http://www.gnu.org/licenses/>.
-}

{-# LANGUAGE OverloadedStrings #-}

{-# LANGUAGE DeriveDataTypeable #-}
module DBus.Connection
        (   Connection
          , connectionAddress
          , connectionUUID

          , ConnectionError (..)

          , connect
          , connectFirst

          , connectionClose

          , send

          , receive

        ) where
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL

import qualified Control.Concurrent as C
import qualified DBus.Address as A
import qualified DBus.Message as M
import qualified DBus.UUID as UUID

import qualified Data.ByteString.Lazy as L
import Data.Word (Word32)

import qualified Network as N
import qualified Data.Map as Map

import qualified Network.Socket as NS

import qualified Text.Parsec as P
import Control.Monad (unless)
import Data.Binary.Get (runGet, getWord16host)
import Data.Binary.Put (runPut, putWord16be)

import qualified System.IO as I

import qualified Control.Exception as E
import Data.Typeable (Typeable)

import qualified DBus.Authentication as Auth

import qualified DBus.Wire as W


-- | An open, authenticated D-Bus connection. Values are constructed by
-- 'connect' (the constructor itself is not exported). 'send' and 'receive'
-- pattern-match the fields positionally, so their order must not change.
data Connection = Connection
        { connectionAddress    :: A.Address
          -- ^ The address this connection was opened to.
        , connectionTransport  :: Transport
          -- ^ Transport used for every read and write.
        , connectionSerialMVar :: C.MVar M.Serial
          -- ^ Serial to assign to the next outgoing message. It is held
          -- empty while a message is being sent, so it doubles as the
          -- send lock (see 'withSerial').
        , connectionReadMVar   :: C.MVar ()
          -- ^ Lock held by 'receive' so only one thread reads at a time.
        , connectionUUID       :: UUID.UUID
          -- ^ UUID returned by 'Auth.authenticate' when the connection
          -- was opened.
        }

-- Connections are rendered by address only: the text @<Connection @,
-- then the 'shows'-quoted result of 'A.strAddress', then @>@.
-- Parenthesised above application precedence, like any constructor.
instance Show Connection where
        showsPrec prec conn =
                showParen (prec > 10) $
                        showString "<Connection "
                        . shows (A.strAddress (connectionAddress conn))
                        . showString ">"

-- | A 'Transport' is anything which can send and receive bytestrings,
-- typically via a socket. Connections only ever talk to the peer
-- through these three operations.

data Transport = Transport
        { transportSend :: L.ByteString -> IO ()
          -- ^ Write the given bytes to the peer.
        , transportRecv :: Word32 -> IO L.ByteString
          -- ^ Read the requested number of bytes. The handle-based
          -- transport may return fewer at end of stream.
        , transportClose :: IO ()
          -- ^ Release the underlying resource (e.g. close the handle).
        }

-- | Open the transport selected by the address's method. Only @unix@ and
-- @tcp@ are recognised; any other method throws 'UnknownMethod'.
connectTransport :: A.Address -> IO Transport
connectTransport a = case A.addressMethod a of
        "unix" -> unix a
        "tcp"  -> tcp a
        _      -> E.throwIO (UnknownMethod a)

-- | Open a transport over a Unix domain socket. Exactly one of the @path@
-- or @abstract@ address parameters must be present, otherwise
-- 'BadParameters' is thrown. An @abstract@ name is prefixed with a NUL
-- byte before being handed to 'N.connectTo'.
unix :: A.Address -> IO Transport
unix a = do
        sockPath <- resolve (param "path") (param "abstract")
        h <- N.connectTo "localhost" (N.UnixSocket sockPath)
        handleTransport h
    where
        param key = Map.lookup key (A.addressParameters a)

        resolve (Just p) Nothing  = return (TL.unpack p)
        resolve Nothing  (Just p) = return ('\x00' : TL.unpack p)
        resolve (Just _) (Just _) = E.throwIO $ BadParameters a tooMany
        resolve Nothing  Nothing  = E.throwIO $ BadParameters a tooFew

        tooMany = "Only one of `path' or `abstract' may be specified for the\
                  \ `unix' transport."
        tooFew = "One of `path' or `abstract' must be specified for the\
                 \ `unix' transport."

-- | Open a transport over TCP.
--
-- * The @port@ parameter is required and must be a decimal in 1..65535.
-- * @host@ defaults to @localhost@.
-- * @family@ (@ipv4@ or @ipv6@) optionally restricts address resolution.
--
-- Invalid parameters throw 'BadParameters'. Each resolved address is tried
-- in turn; if none connects, 'NoWorkingAddress' is thrown.
tcp :: A.Address -> IO Transport
tcp a = openHandle >>= handleTransport where
        params = A.addressParameters a
        openHandle = do
                port <- getPort
                family <- getFamily
                addresses <- getAddresses family
                socket <- openSocket port addresses
                -- Don't leak the socket if wrapping it in a Handle fails.
                NS.socketToHandle socket I.ReadWriteMode
                        `E.onException` NS.sClose socket

        hostname = maybe "localhost" TL.unpack $ Map.lookup "host" params

        unknownFamily x = TL.concat ["Unknown socket family for TCP transport: ", x]
        getFamily = case Map.lookup "family" params of
                Just "ipv4" -> return NS.AF_INET
                Just "ipv6" -> return NS.AF_INET6
                Nothing     -> return NS.AF_UNSPEC
                Just x      -> E.throwIO $ BadParameters a $ unknownFamily x

        missingPort = "TCP transport requires the ``port'' parameter."
        badPort x = TL.concat ["Invalid socket port for TCP transport: ", x]
        getPort = case Map.lookup "port" params of
                Nothing -> E.throwIO $ BadParameters a missingPort
                Just x -> case P.parse parseWord16 "" (TL.unpack x) of
                        Right x' -> return $ NS.PortNum x'
                        Left  _  -> E.throwIO $ BadParameters a $ badPort x

        -- Parse a decimal port in 1..65535. 'NS.PortNum' expects its value
        -- in network byte order, so the port is written big-endian and read
        -- back in host order to obtain that representation.
        parseWord16 = do
                chars <- P.many1 P.digit
                P.eof
                -- 'read' is safe here: 'P.many1 P.digit' guarantees digits.
                let value = read chars :: Integer
                unless (value > 0 && value <= 65535) $
                        P.parserFail "bad port"
                let word = fromIntegral value
                return $ runGet getWord16host (runPut (putWord16be word))

        getAddresses family = do
                let hints = NS.defaultHints
                        { NS.addrFlags = [NS.AI_ADDRCONFIG]
                        , NS.addrFamily = family
                        , NS.addrSocketType = NS.Stream
                        }
                NS.getAddrInfo (Just hints) (Just hostname) Nothing

        -- getAddrInfo is called without a service, so the port in each
        -- resolved address is replaced with the configured one.
        setPort port (NS.SockAddrInet  _ x)     = NS.SockAddrInet port x
        setPort port (NS.SockAddrInet6 _ x y z) = NS.SockAddrInet6 port x y z
        setPort _    addr                       = addr

        -- Try each resolved address in order. Only I/O errors count as a
        -- failed candidate. Asynchronous exceptions (e.g. 'E.ThreadKilled')
        -- propagate instead of being mistaken for an unreachable address.
        openSocket _ [] = E.throwIO $ NoWorkingAddress [a]
        openSocket port (addr:addrs) = openSocket' port addr `E.catch` next
          where next :: E.IOException -> IO NS.Socket
                next _ = openSocket port addrs
        openSocket' port addr = do
                sock <- NS.socket (NS.addrFamily addr)
                                  (NS.addrSocketType addr)
                                  (NS.addrProtocol addr)
                -- Close the fresh socket if connecting fails, so falling
                -- back to the next address does not leak a descriptor.
                NS.connect sock (setPort port (NS.addrAddress addr))
                        `E.onException` NS.sClose sock
                return sock

-- | Wrap an already-open handle as a 'Transport'. The handle is first put
-- into unbuffered binary mode. Sending, receiving and closing then map
-- directly onto 'L.hPut', 'L.hGet' and 'I.hClose'.
handleTransport :: I.Handle -> IO Transport
handleTransport h = do
        I.hSetBuffering h I.NoBuffering
        I.hSetBinaryMode h True
        return Transport
                { transportSend  = L.hPut h
                , transportRecv  = \n -> L.hGet h (fromIntegral n)
                , transportClose = I.hClose h
                }

-- | Errors thrown while opening a connection.
data ConnectionError
        = InvalidAddress Text
          -- ^ Not raised by the code in this module section; presumably an
          -- address string that could not be parsed. TODO confirm.
        | BadParameters A.Address Text
          -- ^ The address's parameters are missing, conflicting or
          -- malformed; the text explains which.
        | UnknownMethod A.Address
          -- ^ The address's method is neither @unix@ nor @tcp@.
        | NoWorkingAddress [A.Address]
          -- ^ None of the listed addresses could be connected to.
        deriving (Show, Typeable)

instance E.Exception ConnectionError

-- | Open a connection to some address, using a given authentication
-- mechanism.
--
-- Problems with the address itself are reported by throwing a
-- 'ConnectionError'. Errors from the socket layer or from authentication
-- propagate to the caller. If authentication throws, the transport is
-- closed before the exception is re-thrown.

connect :: Auth.Mechanism -> A.Address -> IO Connection
connect mechanism a = do
        t <- connectTransport a
        -- NOTE(review): 'L.head' fails on an empty read (peer closed the
        -- stream mid-handshake). When that surfaces inside
        -- 'Auth.authenticate', the handler below still closes the transport.
        let getByte = L.head `fmap` transportRecv t 1
        -- Close the transport if authentication throws, so a failed attempt
        -- (e.g. one skipped over by 'connectFirst') does not leak the socket.
        uuid <- Auth.authenticate mechanism (transportSend t) getByte
                `E.onException` transportClose t
        readLock <- C.newMVar ()
        serialMVar <- C.newMVar M.firstSerial
        return $ Connection a t serialMVar readLock uuid

-- | Try each (mechanism, address) pair in order and return the first
-- connection that opens successfully. Any exception from an individual
-- attempt moves on to the next pair. If every attempt fails,
-- 'NoWorkingAddress' is thrown, listing all of the addresses tried.

connectFirst :: [(Auth.Mechanism, A.Address)] -> IO Connection
connectFirst orig = foldr attempt exhausted orig where
        exhausted = E.throwIO $ NoWorkingAddress [addr | (_, addr) <- orig]
        attempt (mech, addr) fallback =
                E.catch (connect mech addr) (onFailure fallback)
        onFailure :: IO Connection -> E.SomeException -> IO Connection
        onFailure fallback _ = fallback

-- | Close the connection's underlying transport. Afterwards the
-- 'Connection' is no longer valid and must not be used for sending or
-- receiving.

connectionClose :: Connection -> IO ()
connectionClose conn = transportClose (connectionTransport conn)

-- | Marshal and send a single message using the connection's next
-- 'M.Serial'.
--
-- Before the bytes are written, the callback runs with the serial the
-- message is about to be sent with. This lets a caller register a reply
-- handler without racing the reply.
--
-- Sends are serialized: the serial lock is held from allocation until the
-- bytes have been written, so concurrent senders block until the current
-- one finishes.
--
-- If marshalling fails, the callback is not run, nothing is written, and
-- the 'W.MarshalError' is returned. The serial is still consumed.

send :: M.Message a => Connection -> (M.Serial -> IO b) -> a
     -> IO (Either W.MarshalError b)
send (Connection _ transport serials _ _) io msg = withSerial serials sendWith
    where
        sendWith serial =
                either (return . Left) (deliver serial)
                       (W.marshalMessage W.LittleEndian serial msg)
        deliver serial bytes = do
                result <- io serial
                transportSend transport bytes
                return (Right result)

-- | Run an action with the current serial while holding the serial lock.
-- The 'C.MVar' stays empty while the action runs, so concurrent callers
-- block until it is refilled.
--
-- Asynchronous exceptions are blocked except while the action itself runs.
-- The MVar is refilled with the advanced serial ('M.nextSerial') whether
-- the action returns normally or throws. A serial is therefore consumed
-- even when the action fails.
withSerial :: C.MVar M.Serial -> (M.Serial -> IO a) -> IO a
withSerial m io = E.block $ do
        s <- C.takeMVar m
        let s' = M.nextSerial s
        -- If the action throws, refill the MVar before the exception
        -- propagates, so later callers are not deadlocked.
        x <- E.unblock (io s) `E.onException` C.putMVar m s'
        C.putMVar m s'
        return x

-- | Block until the next message arrives on the connection, then unmarshal
-- it.
--
-- The connection's read lock is held for the whole read. Concurrent
-- receivers are therefore serialized, each waiting until the previous one
-- has finished.

receive :: Connection -> IO (Either W.UnmarshalError M.ReceivedMessage)
receive conn = C.withMVar (connectionReadMVar conn) (const readNext) where
        readNext = W.unmarshalMessage (transportRecv (connectionTransport conn))