module Spread.Client.Connection
  ( Connection
  , privateGroup
  , Conf(..)
  , defaultConf
  , AuthName
  , mkAuthName
  , authname
  , AuthMethod
  , connect
  , disconnect
  , startReceive
  , stopReceive
  , getDupedChan
  , join
  , leave
  , send
  ) where

import Network
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString as Bs
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import Foreign
import System.IO.Error
import Control.Concurrent hiding (Chan)
import Data.Bits
import System.IO
import Data.Binary.Put
import Spread.Client.Message
import Control.Applicative
import Control.Exception
import Control.Monad hiding (join)
import Control.Concurrent.Chan.Closeable as C
import Data.Maybe
import Spread.Constants

-- | Abstract type representing a connection with a spread server.
data Connection = C
  { privateGroup :: !PrivateGroup
    -- ^ private name of this connection, useful for p2p messages.
  , chan         :: !(Chan W Message)
    -- ^ Write end of the closeable channel that the receiver thread fills;
    -- 'getDupedChan' forks new read ends from it.
  , getHandle    :: !Handle
    -- ^ Handle of the socket connected to the daemon; all outgoing messages
    -- are written to it.
  , thread       :: !ThreadId
    -- ^ Id of the receiver thread forked by 'connect'.
  , sync         :: !(MVar ())
    -- ^ Receive gate: the receiver thread waits for this MVar to be full
    -- before each read.  It starts empty; 'startReceive' fills it and
    -- 'stopReceive' empties it.
  }

-- | Configuration passed to 'connect'
data Conf = Conf
  { address         :: !(Maybe HostName)
    -- ^ Server address, using localhost if 'Nothing'.
  , port            :: !(Maybe PortNumber)
    -- ^ Server port, uses the default spread port if 'Nothing'.
  , desiredName     :: !PrivateName
    -- ^ It will become part of the 'PrivateGroup' of the 'Connection'
  , priority        :: !Bool
    -- ^ Is this a priority connection?
  , groupMembership :: !Bool
    -- ^ Should it receive Membership messages?
  , authMethods     :: ![AuthMethod]
    -- ^ Authentication methods to use when connecting.
  }

-- | Default configuration.  It connects to localhost on the default spread
-- port as the private name @\"user\"@, without priority, and receives
-- membership messages.  It lists no authentication methods, so 'connect'
-- falls back to \"NULL\" authentication.
--
-- > defaultConf = Conf Nothing Nothing (mkPrivateName (B.pack "user")) False True []
defaultConf :: Conf
defaultConf = Conf
  { address         = Nothing
  , port            = Nothing
  , desiredName     = mkPrivateName (B.pack "user")
  , priority        = False
  , groupMembership = True
  , authMethods     = []
  }

-- | Name of an authentication method.
newtype AuthName = AN ByteString deriving Eq

-- | The built-in \"NULL\" authentication method.  It does not touch the
-- handle and always reports success.  'connect' uses it when 'authMethods'
-- is empty.
nullAuthMethod :: (AuthName, t -> IO Bool)
nullAuthMethod = (mkAuthName (B.pack "NULL"), \_ -> return True)

-- | The 'ByteString' will be truncated to the maximum allowed size.
mkAuthName :: ByteString -> AuthName
mkAuthName = AN . B.take mAX_AUTH_NAME

-- | The raw bytes of an 'AuthName'.
authname :: AuthName -> ByteString
authname (AN s) = s

-- | The action should return True if the authentication succeeded.
-- It is given the handle of the connection being established, so it can
-- exchange data with the daemon.
type AuthMethod = (AuthName,(Handle -> IO Bool))

-- | Connects to the specified server, will use a \"NULL\" authentication method if the 'authMethods' list is empty.
-- A spread server will refuse the connection if another with the same PrivateName is still active.
--
-- Handshake performed over the new socket:
--
--  1. send the connect message (client version, flags, desired private name);
--  2. read the space-separated list of authentication methods the daemon
--     permits, and fail if any chosen method is not in it;
--  3. send the chosen method names, then run every 'AuthMethod' on the
--     handle; all of them must succeed;
--  4. read the acceptance byte, which must be 'aCCEPT_SESSION';
--  5. read the daemon version (at least 3.1.0, and at least 3.8.0 when
--     'priority' is set) and the assigned private group name.
--
-- Failures are raised with 'fail' \/ 'ioError'.  The socket handle is closed
-- if any exception escapes the handshake.  On success a receiver thread is
-- forked.  It does not read from the network until 'startReceive' is called.
connect :: Conf -> IO (Chan R Message,Connection)
connect c =
  bracketOnError (withSocketsDo $ connectTo addr (PortNumber port')) hClose $ \h -> do
    writePut (mkConnectMsg (desiredName c) (priority c) (groupMembership c)) h
    authlistlen <- hGetByte h
    checkLen authlistlen
    authlist <- map AN . B.split ' ' <$> hGetExactly h authlistlen
    checkAuthNames authnames authlist
    writePut (putAuthNames authnames) h
    results <- mapM ($ h) authmethods
    checkAuths results
    checkAccepted =<< hGetByte h
    -- Four single bytes: daemon major, minor and patch version, then the
    -- length of our private group name.  Reading them one by one keeps the
    -- handshake total: a premature EOF raises an IOError from 'hGetByte'
    -- instead of a pattern-match failure.
    major <- hGetByte h
    minor' <- hGetByte h
    patch <- hGetByte h
    glen <- hGetByte h
    checkVersion major minor' patch
    prvg <- mkPrivateGroup <$> hGetExactly h glen
    makeConnection prvg h
  where
    addr  = fromMaybe "localhost" (address c)
    port' = fromMaybe dEFAULT_SPREAD_PORT (port c)
    (authnames, authmethods) =
      unzip $ if null (authMethods c) then [nullAuthMethod] else authMethods c

    checkLen l =
      when (l > mAX_AUTH_NAME * mAX_AUTH_METHODS) $
        fail $ "connect: illegal value in authlistlen " ++ show l
    checkAuthNames nms list =
      unless (all (`elem` list) nms) $
        fail "connect: chosen authentication method is not permitted by daemon"
    checkAuths r =
      unless (and r) $ fail "connect: authentication of connection failed"
    checkAccepted v =
      unless (v == aCCEPT_SESSION) $ fail "session rejected"
    -- Versions are compared as major*10000 + minor*100 + patch, so 30100 is
    -- 3.1.0 and 30800 is 3.8.0.
    checkVersion maj mnr pat = do
      let val = maj * 10000 + mnr * 100 + pat
      when (val < 30100) $ fail "old spread version, not supported"
      when (val < 30800 && priority c) $
        fail "old spread version, priority not supported."
    makeConnection prvg hdl = do
      (r, w) <- C.newChan
      -- Starts empty: the receiver stays blocked until 'startReceive'.
      m <- newEmptyMVar
      tid <- forkIO $ receiver w m hdl prvg
      return (r, C prvg w hdl tid m)

-- | Body of the receiver thread.  It repeatedly waits on the gate MVar,
-- reads one message from the daemon and writes it to the channel.  When the
-- loop ends for any reason (EOF, I/O error, thread killed), the handle and
-- the channel are closed and the exception is discarded.  Readers learn
-- about the end through the closed channel.
receiver :: Chan W Message -> MVar () -> Handle -> PrivateGroup -> IO ()
receiver c m h prvg =
  handle ignore $ loop `finally` (hClose h >> C.closeChan c)
  where
    loop = readMVar m >> receive_internal h prvg >>= C.writeChan c >> loop
    -- The handler type must be fixed explicitly.  A bare @\\_ -> ...@ is
    -- ambiguous under base >= 4's extensible exceptions.
    ignore :: SomeException -> IO ()
    ignore _ = return ()

-- | Read a single byte as an 'Int', raising an EOF 'IOError' if the stream
-- ends first.
hGetByte :: Handle -> IO Int
hGetByte h = alloca $ \p -> do
  n <- hGetBuf h p 1
  if n /= 1
    then ioError $ mkIOError eofErrorType "hGetByte" (Just h) Nothing
    else fromIntegral <$> (peek p :: IO Word8)

-- | Read exactly @n@ bytes.  'Bs.hGet' returns fewer bytes at EOF, so a short
-- result is turned into an EOF 'IOError', in the same way as 'hGetByte'.
hGetExactly :: Handle -> Int -> IO ByteString
hGetExactly h n = do
  s <- Bs.hGet h n
  if Bs.length s == n
    then return s
    else ioError $ mkIOError eofErrorType "hGetExactly" (Just h) Nothing

-- | Serialise exactly 'mAX_AUTH_METHODS' names, each written with
-- 'putPadded' at width 'mAX_AUTH_NAME'.  Missing slots are filled with empty
-- names, and names beyond 'mAX_AUTH_METHODS' are dropped.
putAuthNames :: [AuthName] -> Put
putAuthNames xs =
  mapM_ (\(AN s) -> putPadded mAX_AUTH_NAME s)
    . take mAX_AUTH_METHODS
    $ xs ++ repeat (AN Bs.empty)

-- | The initial connect message, made of:
--
--  * the client protocol version: three bytes;
--  * a flags byte: 0x01 = receive membership messages, 0x10 = priority;
--  * the length of the private name, as one byte;
--  * the private name itself.
mkConnectMsg :: PrivateName -> Bool -> Bool -> Put
mkConnectMsg p prio gmemb = do
  mapM_ putWord8 [sP_MAJOR_VERSION, sP_MINOR_VERSION, sP_PATCH_VERSION]
  putWord8 (setIf gmemb 0x01 . setIf prio 0x10 $ 0)
  putWord8 . fromIntegral . B.length $ pn
  putByteString pn
  where
    setIf b n = if b then (.|. n) else id
    pn = privateName p

-- | Serialise a 'Put' to the handle and flush, so the daemon sees it
-- immediately.
writePut :: Put -> Handle -> IO ()
writePut p h = L.hPut h (runPut p) >> hFlush h

-- | Sends a disconnection message to the server, which will close the connection.
disconnect :: Connection -> IO ()
disconnect conn = sendInternal (Kill (privateGroup conn)) conn

-- | Forks a new read end of the connection's channel.  Only messages
-- written after this call will be available on the returned 'Chan'.
getDupedChan :: Connection -> IO (Chan R Message)
getDupedChan = C.forkChan . chan

-- | Start fetching messages from the network by opening the receiver's gate.
-- Returns True if receiving was previously stopped, False if it was already
-- running.
startReceive :: Connection -> IO Bool
startReceive conn = tryPutMVar (sync conn) ()

-- | Stop fetching messages from the network by closing the receiver's gate.
-- The receiver may already be past the gate, so at most one more message can
-- still be read.  Returns True if receiving was previously running.
stopReceive :: Connection -> IO Bool
stopReceive = fmap isJust . tryTakeMVar . sync

-- | Joins a group, the server will send a 'Reg'.
join :: Group -> Connection -> IO ()
join = sendInternal . Joining

-- | Leaves a group, the server will send a 'SelfLeave'.
leave :: Group -> Connection -> IO ()
leave = sendInternal . Leaving

-- | Send a regular message.
send :: OutMsg -> Connection -> IO ()
send = sendInternal

-- Shared path for all outgoing traffic.  It sends the message on the
-- connection's handle, tagged with its private group, and discards the
-- result of 'multicast_internal'.
sendInternal msg conn = do
  _ <- multicast_internal (privateGroup conn) msg (getHandle conn)
  return ()