module Spread.Client.Connection
    (
     Connection,privateGroup
    ,Conf(..),defaultConf
    ,AuthName
    ,mkAuthName,authname
    ,AuthMethod
    ,connect,disconnect,startReceive,stopReceive,getDupedChan
    ,join,leave,send
    ) where
import Network
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString as Bs
import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as L
import Foreign
import System.IO.Error
import Control.Concurrent hiding (Chan)
import Data.Bits
import System.IO
import Data.Binary.Put
import Spread.Client.Message
import Network
import Control.Applicative
import Control.Exception
import Control.Monad hiding (join)
import Control.Concurrent.Chan.Closeable as C
import Data.Maybe
import Spread.Constants
-- | Abstract type representing a connection with a spread server.
-- Values are created by 'connect'.
data Connection = C
  { privateGroup :: !PrivateGroup      -- ^ private name of this connection, useful for p2p messages.
  , chan         :: !(Chan W Message)  -- ^ write end of the channel the receiver thread feeds; readers are forked from it by 'getDupedChan'.
  , getHandle    :: !Handle            -- ^ socket handle to the daemon, shared by senders and the receiver thread.
  , thread       :: !ThreadId          -- ^ id of the receiver thread forked by 'connect'.
  , sync         :: !(MVar ())         -- ^ full while receiving is enabled; toggled by 'startReceive' / 'stopReceive'.
  }
-- | Configuration passed to 'connect'.
data Conf = Conf
  { address         :: !(Maybe HostName)    -- ^ Server address, using localhost if 'Nothing'.
  , port            :: !(Maybe PortNumber)  -- ^ Server port, uses the default spread port if 'Nothing'.
  , desiredName     :: !PrivateName         -- ^ Name requested from the daemon; it will become part of the 'PrivateGroup' of the 'Connection'.
  , priority        :: !Bool                -- ^ Is this a priority connection? 'connect' rejects daemons older than 3.8.0 when set.
  , groupMembership :: !Bool                -- ^ Should it receive Membership messages?
  , authMethods     :: ![AuthMethod]        -- ^ Authentication methods to use when connecting; an empty list means the \"NULL\" method.
  }
-- | Default configuration: connect to localhost on the default spread port
-- with private name @user@, as a non-priority connection that receives
-- membership messages, using the \"NULL\" authentication method (which
-- 'connect' substitutes for the empty 'authMethods' list).
defaultConf :: Conf
defaultConf = Conf
  { address         = Nothing
  , port            = Nothing
  , desiredName     = mkPrivateName (B.pack "user")
  , priority        = False
  , groupMembership = True
  , authMethods     = []
  }
-- | Name of an authentication method, as exchanged with the daemon.
-- Build one with 'mkAuthName'; read it back with 'authname'.
newtype AuthName = AN ByteString
  deriving (Eq)
-- | The \"NULL\" authentication method: its action ignores its argument and
-- always reports success. 'connect' uses it when no methods are configured.
nullAuthMethod :: (AuthName, t -> IO Bool)
nullAuthMethod = (nullName, const (return True))
  where nullName = mkAuthName (B.pack "NULL")
-- | Builds an 'AuthName'. The 'ByteString' will be truncated to the maximum
-- allowed size ('mAX_AUTH_NAME' bytes).
mkAuthName :: ByteString -> AuthName
mkAuthName raw = AN (B.take mAX_AUTH_NAME raw)
-- | The raw bytes of an 'AuthName'.
authname :: AuthName -> ByteString
authname name = case name of
  AN bytes -> bytes
-- | An authentication method: the name announced to the daemon, paired with
-- an action that 'connect' runs on the connection's 'Handle'.
-- The action should return True if the authentication succeeded.
type AuthMethod = (AuthName, Handle -> IO Bool)

-- | Connects to the specified server, will use a \"NULL\" authentication method if the 'authMethods' list is empty.
-- A spread server will refuse the connection if another with the same PrivateName is still active.
--
-- Handshake performed on the new socket:
--
--   1. send the connect message ('desiredName', 'priority', 'groupMembership');
--   2. read the daemon's space-separated list of permitted authentication
--      methods and check that every method we use appears in it;
--   3. send our method names and run each method's action on the handle;
--   4. read the accept byte, the daemon version and our private group name.
--
-- Protocol failures raise an 'IOError' (a short read from the daemon raises an
-- EOF error). On any exception the handle is closed ('bracketOnError').
-- On success a receiver thread is forked. Receiving starts disabled, so
-- messages flow only after 'startReceive'. The read end of its channel is
-- returned together with the 'Connection'.
connect :: Conf -> IO (Chan R Message,Connection)
connect c =
  bracketOnError (withSocketsDo $ connectTo addr (PortNumber port')) hClose $ \h -> do
    writePut (mkConnectMsg (desiredName c) (priority c) (groupMembership c)) h
    authlistlen <- hGetByte h
    checkLen authlistlen
    authlist <- map AN . B.split ' ' <$> hGetExactly h authlistlen
    checkAuthNames authnames authlist
    writePut (putAuthNames authnames) h
    results <- mapM ($ h) authmethods
    checkAuths results
    checkAccepted =<< hGetByte h
    -- Daemon version (major, minor, patch), then the length of our private
    -- group name: one byte each. They are read individually so a short read
    -- raises an EOF error instead of a pattern-match failure.
    major <- hGetByte h
    minor <- hGetByte h
    patch <- hGetByte h
    glen  <- hGetByte h
    checkVersion major minor patch
    mprvg <- mkPrivateGroup <$> hGetExactly h glen
    makeConnection mprvg h
  where
    addr = fromMaybe "localhost" (address c)
    port' = fromMaybe dEFAULT_SPREAD_PORT (port c)
    -- Fall back to the NULL method when the caller configured none.
    (authnames, authmethods) = unzip $ case authMethods c of
      [] -> [nullAuthMethod]
      l  -> l
    -- Like 'Bs.hGet', but reaching EOF before @n@ bytes is an error instead
    -- of silently yielding a truncated string.
    hGetExactly h n = do
      bs <- Bs.hGet h n
      if Bs.length bs == n
        then return bs
        else ioError $ mkIOError eofErrorType "connect" (Just h) Nothing
    checkLen l = when (l > mAX_AUTH_NAME * mAX_AUTH_METHODS) $
      fail $ "connect: illegal value in authlistlen " ++ show l
    checkAuthNames nms list = unless (all (`elem` list) nms) $
      fail "connect: chosen authentication method is not permitted by daemon"
    checkAuths r = unless (and r) $ fail "connect: authentication of connection failed"
    checkAccepted v = unless (v == aCCEPT_SESSION) $ fail "session rejected"
    -- Versions are compared as major*10000 + minor*100 + patch.
    checkVersion maj mnr pt = do
      let val = maj * 10000 + mnr * 100 + pt
      when (val < 30100) $ fail "old spread version, not supported"
      when (val < 30800 && priority c) $ fail "old spread version, priority not supported."
    -- The sync MVar starts empty, so the receiver thread blocks until
    -- 'startReceive' fills it.
    makeConnection prvg h = do
      (r, w) <- C.newChan
      m <- newEmptyMVar
      tid <- forkIO $ receiver w m h prvg
      return (r, C prvg w h tid m)

-- | Body of the background thread forked by 'connect'.
--
-- Each iteration waits until the 'MVar' is full (receiving enabled, see
-- 'startReceive' / 'stopReceive'). It then reads one message from the handle
-- and writes it to the channel. The loop only ends through an exception.
-- When it does, the handle is closed and the channel is closed with
-- 'C.closeChan'.
receiver :: Chan W Message -> MVar () -> Handle -> PrivateGroup -> IO ()
receiver c m h prvg = handle ignoreAll (loop `finally` cleanup)
  where
    loop = do
      readMVar m                    -- blocks while receiving is stopped
      receive_internal h prvg >>= C.writeChan c
      loop
    cleanup = hClose h >> C.closeChan c
    -- Any exception deliberately ends the thread quietly instead of being
    -- reported as an uncaught exception of the forked thread. The explicit
    -- 'SomeException' annotation is required: without it the exception
    -- type of 'handle' is ambiguous.
    ignoreAll :: SomeException -> IO ()
    ignoreAll _ = return ()

-- | Reads a single byte from the handle and returns it as an 'Int' (0..255).
-- Raises an EOF 'IOError' tagged \"hGetByte\" when no byte is available.
hGetByte :: Handle -> IO Int
hGetByte h = alloca readInto
  where
    readInto :: Ptr Word8 -> IO Int
    readInto buf = do
      got <- hGetBuf h buf 1
      case got of
        1 -> fromIntegral <$> peek buf
        _ -> ioError (mkIOError eofErrorType "hGetByte" (Just h) Nothing)


-- | Serialises the authentication method names as exactly 'mAX_AUTH_METHODS'
-- slots. Each slot is written with @putPadded mAX_AUTH_NAME@. Names beyond
-- 'mAX_AUTH_METHODS' are dropped, and missing slots are filled with empty
-- names.
putAuthNames :: [AuthName] -> Put
putAuthNames xs = mapM_ putSlot slots
  where
    slots = take mAX_AUTH_METHODS (xs ++ repeat (AN Bs.empty))
    putSlot (AN s) = putPadded mAX_AUTH_NAME s

-- | The initial connect message: the three client version bytes, a flags byte
-- (0x01 = receive membership messages, 0x10 = priority connection), then the
-- length of the private name as one byte followed by its bytes.
mkConnectMsg :: PrivateName -> Bool -> Bool -> Put
mkConnectMsg p priority gmemb = do
  putWord8 sP_MAJOR_VERSION
  putWord8 sP_MINOR_VERSION
  putWord8 sP_PATCH_VERSION
  putWord8 flags
  putWord8 (fromIntegral (B.length name))
  putByteString name
  where
    flags = (if gmemb then 0x01 else 0) .|. (if priority then 0x10 else 0)
    name = privateName p

-- | Runs the serialiser, writes the resulting bytes to the handle and flushes
-- it so the daemon sees the message immediately.
writePut :: Put -> Handle -> IO ()
writePut p h = L.hPut h (runPut p) >> hFlush h
-- | Sends a disconnection ('Kill') message for this connection's private
-- group to the server, which will close the connection.
disconnect :: Connection -> IO ()
disconnect conn = sendInternal (Kill (privateGroup conn)) conn

-- | Returns a new read end of the connection's message channel: messages
-- received from now on will be available on the returned 'Chan'.
getDupedChan :: Connection -> IO (Chan R Message)
getDupedChan = C.forkChan . chan

-- | Start fetching messages from the network.
-- Returns True if receiving was stopped before this call (so this call
-- enabled it), False if it was already running.
startReceive :: Connection -> IO Bool
startReceive conn = tryPutMVar gate ()
  where gate = sync conn

-- | Stop fetching messages from the network. At most one more message can be
-- read, since the receiver may already be past its check.
-- Returns True if receiving was running before this call.
stopReceive :: Connection -> IO Bool
stopReceive conn = isJust <$> tryTakeMVar (sync conn)
-- | Joins a group by sending a 'Joining' message; the server will send a 'Reg'.
join :: Group -> Connection -> IO ()
join g conn = sendInternal (Joining g) conn
-- | Leaves a group by sending a 'Leaving' message; the server will send a 'SelfLeave'.
leave :: Group -> Connection -> IO ()
leave g conn = sendInternal (Leaving g) conn

-- | Send a regular message on the connection.
send :: OutMsg -> Connection -> IO ()
send msg conn = sendInternal msg conn

-- Internal: hands @m@ to 'multicast_internal' together with the connection's
-- private group and handle, discarding whatever it returns.
sendInternal m c = do
  _ <- multicast_internal (privateGroup c) m (getHandle c)
  return ()