module Spread.Client.Connection
(
Connection,privateGroup
,Conf(..),defaultConf
,AuthName
,mkAuthName,authname
,AuthMethod
,connect,disconnect,startReceive,stopReceive,getDupedChan
,join,leave,send
) where
import Network
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString as Bs
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import Foreign
import System.IO.Error
import Control.Concurrent hiding (Chan)
import Data.Bits
import System.IO
import Data.Binary.Put
import Spread.Client.Message
import Network
import Control.Applicative
import Control.Exception
import Control.Monad hiding (join)
import Control.Concurrent.Chan.Closeable as C
import Data.Maybe
import Spread.Constants
-- | A live session with a Spread daemon, as produced by 'connect'.
data Connection = C
    { privateGroup :: !PrivateGroup      -- ^ private group name read from the daemon at the end of 'connect'
    , chan         :: !(Chan W Message)  -- ^ write end the receiver thread pushes incoming messages into
    , getHandle    :: !Handle            -- ^ socket handle to the daemon, used for sending
    , thread       :: !ThreadId          -- ^ the receiver thread forked by 'connect'
    , sync         :: !(MVar ())         -- ^ full while receiving is enabled (see 'startReceive' / 'stopReceive')
    }
-- | Settings accepted by 'connect'.
data Conf = Conf
    { address         :: !(Maybe HostName)    -- ^ daemon host; 'Nothing' means @\"localhost\"@
    , port            :: !(Maybe PortNumber)  -- ^ daemon port; 'Nothing' means 'dEFAULT_SPREAD_PORT'
    , desiredName     :: !PrivateName         -- ^ private name sent in the connect message
    , priority        :: !Bool                -- ^ sets flag 0x10 in the connect message; 'connect' fails for daemons older than 3.8.0
    , groupMembership :: !Bool                -- ^ sets flag 0x01 in the connect message
    , authMethods     :: ![AuthMethod]        -- ^ authentication methods to use; an empty list means the NULL method
    }
-- | Default settings: localhost on the default Spread port, private name
-- \"user\", no priority, group membership enabled, and no explicit
-- authentication methods (so 'connect' falls back to the NULL method).
defaultConf :: Conf
defaultConf = Conf
    { address         = Nothing
    , port            = Nothing
    , desiredName     = mkPrivateName (B.pack "user")
    , priority        = False
    , groupMembership = True
    , authMethods     = []
    }
-- | Name of an authentication method as exchanged with the daemon.
-- Construct values with 'mkAuthName', which limits the length to
-- 'mAX_AUTH_NAME' bytes.
newtype AuthName = AN ByteString
    deriving (Eq)
-- | The \"NULL\" authentication method. Its action ignores its argument and
-- always reports success. 'connect' uses it when no methods are configured.
nullAuthMethod :: (AuthName, t -> IO Bool)
nullAuthMethod = (nullName, const (return True))
  where
    nullName = mkAuthName (B.pack "NULL")
-- | Build an 'AuthName', keeping at most 'mAX_AUTH_NAME' bytes of the input.
mkAuthName :: ByteString -> AuthName
mkAuthName raw = AN (B.take mAX_AUTH_NAME raw)
-- | The raw bytes of an 'AuthName'.
authname :: AuthName -> ByteString
authname name = case name of
    AN bytes -> bytes
-- | An authentication method: its name, and an action that 'connect' runs on
-- the daemon socket handle. 'connect' fails unless every action returns 'True'.
type AuthMethod = (AuthName, Handle -> IO Bool)
-- | Open a session with a Spread daemon and run the connect handshake.
--
-- The handshake steps, in order:
--
-- * Send the connect message built from the 'Conf'.
-- * Read the daemon's space-separated list of permitted authentication
--   methods and check that every requested method is on it.
-- * Send the requested method names and run each method's action on the
--   socket handle.
-- * Read the accept byte, the daemon's version bytes and the assigned
--   private group name.
--
-- On success a receiver thread is forked. The read end of its message
-- channel is returned together with the 'Connection'. The receiver starts
-- paused: its gate MVar is created empty, so call 'startReceive' before
-- expecting messages.
--
-- If 'authMethods' is empty, the NULL method is used. Any failure during the
-- handshake closes the socket and is raised as an 'IOError'. This includes a
-- daemon reply that ends before the expected number of bytes.
connect :: Conf -> IO (Chan R Message,Connection)
connect c =
    bracketOnError (withSocketsDo $ connectTo addr (PortNumber port')) hClose $ \h -> do
        writePut (mkConnectMsg (desiredName c) (priority c) (groupMembership c)) h
        authlistlen <- hGetByte h
        checkLen authlistlen
        authlist <- map AN . B.split ' ' <$> hGetExact h authlistlen
        checkAuthNames authnames authlist
        writePut (putAuthNames authnames) h
        results <- mapM ($ h) authmethods
        checkAuths results
        checkAccepted =<< hGetByte h
        -- hGetExact yields exactly 4 bytes, so this pattern always matches.
        [major, minor, patch, glen] <- map fromIntegral . Bs.unpack <$> hGetExact h 4
        checkVersion major minor patch
        prvg <- mkPrivateGroup <$> hGetExact h glen
        makeConnection prvg h
  where
    addr  = fromMaybe "localhost" (address c)
    port' = fromMaybe dEFAULT_SPREAD_PORT (port c)
    (authnames, authmethods) =
        unzip $ if null (authMethods c) then [nullAuthMethod] else authMethods c

    -- Bs.hGet returns a short string when the handle reaches EOF. Raise an
    -- EOF error rather than parsing a truncated daemon reply.
    hGetExact hdl n = do
        bytes <- Bs.hGet hdl n
        when (Bs.length bytes /= n) $
            ioError $ mkIOError eofErrorType "connect" (Just hdl) Nothing
        return bytes

    checkLen l = when (l > mAX_AUTH_NAME * mAX_AUTH_METHODS) $
        fail $ "connect: illegal value in authlistlen " ++ show l
    checkAuthNames nms list = unless (all (`elem` list) nms) $
        fail "connect: chosen authentication method is not permitted by daemon"
    checkAuths r = unless (and r) $
        fail "connect: authentication of connection failed"
    checkAccepted v = unless (v == aCCEPT_SESSION) $
        fail "session rejected"
    -- The daemon version is compared as major*10000 + minor*100 + patch.
    checkVersion maj mnr pat = do
        let val = maj * 10000 + mnr * 100 + pat
        when (val < 30100) $ fail "old spread version, not supported"
        when (val < 30800 && priority c) $ fail "old spread version, priority not supported."
    -- Create the message channel and the (initially empty) receive gate,
    -- then fork the receiver thread that owns the read side of the socket.
    makeConnection prvg hdl = do
        (r, w) <- C.newChan
        m <- newEmptyMVar
        tid <- forkIO $ receiver w m hdl prvg
        return (r, C prvg w hdl tid m)
-- | Body of the per-connection receiver thread.
--
-- Before every read the thread waits on the gate MVar, so it stays blocked
-- while receiving is stopped. It then reads one message with
-- 'receive_internal' and writes it to the channel. When the loop ends for
-- any reason, the socket handle is closed and 'C.closeChan' is called on the
-- channel.
receiver :: Chan W Message -> MVar () -> Handle -> PrivateGroup -> IO ()
receiver c m h prvg =
    handle ignoreAll (loop `finally` (hClose h >> C.closeChan c))
  where
    loop = do
        readMVar m
        msg <- receive_internal h prvg
        C.writeChan c msg
        loop
    -- The exception that ended the loop (EOF, closed socket, thread kill, ...)
    -- is deliberately discarded: the thread just terminates after cleanup.
    -- The exception type must be named explicitly. An untyped
    -- @\_ -> return ()@ handler leaves it ambiguous under the extensible
    -- exceptions of base >= 4 and does not compile.
    ignoreAll :: SomeException -> IO ()
    ignoreAll _ = return ()
-- | Read a single byte from the handle and return it as an 'Int'.
-- Throws an EOF 'IOError' (location @\"hGetByte\"@) if no byte could be read.
hGetByte :: Handle -> IO Int
hGetByte h = alloca readOne
  where
    readOne buf = do
        got <- hGetBuf h buf 1
        case got of
            1 -> fromIntegral <$> (peek buf :: IO Word8)
            _ -> ioError (mkIOError eofErrorType "hGetByte" (Just h) Nothing)
-- | Serialise exactly 'mAX_AUTH_METHODS' authentication names. Each is
-- written with @putPadded mAX_AUTH_NAME@ (presumably a fixed-width,
-- padded field). Names beyond the limit are dropped; unused slots are
-- filled with empty names.
putAuthNames :: [AuthName] -> Put
putAuthNames xs =
    forM_ slots $ \(AN name) -> putPadded mAX_AUTH_NAME name
  where
    slots = take mAX_AUTH_METHODS (xs ++ repeat (AN Bs.empty))
-- | Build the connect message sent at the start of the handshake:
-- the three @sP_*@ version bytes, a flags byte (0x01 = group membership,
-- 0x10 = priority), the private name's length as one byte, then the name.
--
-- NOTE(review): the length byte wraps for names longer than 255 bytes;
-- presumably 'mkPrivateName' keeps names shorter. Confirm.
mkConnectMsg :: PrivateName -> Bool -> Bool -> Put
mkConnectMsg p wantPriority wantMembership = do
    putWord8 sP_MAJOR_VERSION
    putWord8 sP_MINOR_VERSION
    putWord8 sP_PATCH_VERSION
    putWord8 flags
    putWord8 (fromIntegral (B.length name))
    putByteString name
  where
    name  = privateName p
    flags = (if wantMembership then 0x01 else 0) .|. (if wantPriority then 0x10 else 0)
-- | Run a 'Put' serialiser and write the resulting bytes to the handle,
-- then flush so nothing is left sitting in the handle's buffer.
writePut :: Put -> Handle -> IO ()
writePut p h = L.hPut h (runPut p) >> hFlush h
-- | Send a 'Kill' message for this connection's private group (presumably
-- asking the daemon to end the session).
--
-- NOTE(review): this neither closes the socket nor stops the receiver
-- thread. Only the receiver closes the handle, when its read loop ends. If
-- receiving is stopped at this point (the default after 'connect'), the
-- receiver stays blocked on its gate MVar. Confirm this is intended.
disconnect :: Connection -> IO ()
disconnect conn = sendInternal (Kill (privateGroup conn)) conn
-- | Obtain an additional read end on the connection's message channel by
-- applying 'C.forkChan' to the write end the receiver thread feeds.
getDupedChan :: Connection -> IO (Chan R Message)
getDupedChan = C.forkChan . chan
-- | Let the receiver thread deliver messages by filling its gate MVar.
-- Returns 'False' if receiving was already enabled.
startReceive :: Connection -> IO Bool
startReceive = flip tryPutMVar () . sync
-- | Pause the receiver thread by emptying its gate MVar. This takes effect
-- before the receiver's next read; a read already in progress still
-- delivers its message. Returns 'False' if receiving was already stopped.
stopReceive :: Connection -> IO Bool
stopReceive conn = isJust <$> tryTakeMVar (sync conn)
-- | Send a 'Joining' message for the given group over the connection.
join :: Group -> Connection -> IO ()
join grp conn = sendInternal (Joining grp) conn
-- | Send a 'Leaving' message for the given group over the connection.
leave :: Group -> Connection -> IO ()
leave grp conn = sendInternal (Leaving grp) conn
-- | Send an outgoing message over the connection.
send :: OutMsg -> Connection -> IO ()
send msg conn = sendInternal msg conn
-- | Pass a message, together with the connection's private group and socket
-- handle, to 'multicast_internal', discarding whatever it returns.
--
-- NOTE(review): there is no type signature. 'join', 'leave' and 'disconnect'
-- pass Joining/Leaving/Kill values, while 'send' passes an 'OutMsg'. Confirm
-- the message type(s) involved before adding one.
sendInternal m c = do
    _ <- multicast_internal (privateGroup c) m (getHandle c)
    return ()