-- Name registration, service lookup and leader election for
-- distributed-process nodes, coordinated through Zookeeper.
module Control.Distributed.Process.Zookeeper
   (
   -- Starting a node and its Zookeeper controller process
     bootstrap
   , bootstrapWith
   , zkController
   , zkControllerWith
   -- Registering and looking up named processes
   , registerZK
   , getCapable
   , nsendCapable
   , registerCandidate
   , whereisGlobal
   -- Discovering the other controller nodes
   , getPeers
   , nsendPeers
   -- Configuration
   , Config(..)
   , defaultConfig
   -- Logging helpers and start-up synchronisation
   , nolog
   , sayTrace
   , waitController
   ) where
import           Database.Zookeeper                       (AclList (..),
                                                           CreateFlag (..),
                                                           Event (..), Watcher,
                                                           ZKError (..),
                                                           Zookeeper)
import qualified Database.Zookeeper                       as ZK
import           Control.Applicative                      ((<$>))
import           Control.Concurrent                       (myThreadId,
                                                           newEmptyMVar,
                                                           putMVar, putMVar,
                                                           takeMVar, takeMVar,
                                                           threadDelay)
import Control.DeepSeq (deepseq, NFData(..))
import           Control.Exception                        (bracket, throwIO,
                                                           throwTo)
import           Control.Monad                            (forM, join, void)
import           Control.Monad.Except                     (ExceptT (..), lift)
import           Control.Monad.IO.Class                   (MonadIO)
import           Control.Monad.Trans.Except               (runExceptT, throwE)
import           Data.Binary                              (Binary, decode,
                                                           encode)
import           Data.ByteString                          (ByteString)
import qualified Data.ByteString                          as BS
import qualified Data.ByteString.Lazy                     as BL
import           Data.Foldable                            (forM_)
import           Data.List                                (isPrefixOf, sort)
import           Data.Map.Strict                          (Map)
import qualified Data.Map.Strict                          as Map
import           Data.Maybe                               (fromMaybe)
import           Data.Monoid                              (mempty)
import           Data.Typeable                            (Typeable)
import           GHC.Generics                             (Generic)
import           Control.Distributed.Process              hiding (bracket,
                                                           proxy)
import           Control.Distributed.Process.Management
import           Control.Distributed.Process.Node         (newLocalNode,
                                                           runProcess)
import           Control.Distributed.Process.Serializable
import           Network                                  (HostName)
import           Network.Socket                           (ServiceName)
import           Network.Transport                        (closeTransport)
import           Network.Transport.TCP                    (createTransport,
                                                           defaultTCPParameters)
-- | Messages understood by the controller process (see 'server').
-- Constructors carrying a 'SendPort' expect an answer on that port.
data Command = Register !String !ProcessId !(SendPort (Either String ()))
             -- Register a pid under a service name; replies with the outcome.
             | GlobalCandidate !String !ProcessId !(SendPort (Either String ProcessId))
             -- Stage a pid as candidate for a global name; replies with the
             -- pid that currently holds the name (or an error).
             | CheckCandidate !String
             -- Re-run the election for a name we hold a candidate for; sent by
             -- a Zookeeper watch when a watched election node is deleted.
             | ClearCache !String
             -- Drop the cached entry for the given cache key.
             | GetRegistered !String !(SendPort [ProcessId])
             -- List the pids stored in the children of a Zookeeper node.
             | GetGlobal !String !(SendPort (Maybe ProcessId))
             -- Look up the pid currently elected for a global name.
             | Exit
             -- Stop the controller loop.
    deriving (Show, Typeable, Generic)
instance Binary Command
instance NFData Command
-- | Message sent to a staged candidate process once it has won the election;
-- the candidate blocks on @expect@ for this before running its body
-- (see 'registerCandidate' and 'mayElect').
data Elect = Elect deriving (Typeable, Generic)
instance Binary Elect
instance NFData Elect
-- | Internal state threaded through the controller loop in 'server'.
data State = State
    {
      -- Cache of pids keyed by Zookeeper node path (service nodes, and
      -- @globalsNode </> name@ for global names).
      nodeCache  :: Map String [ProcessId]
      -- For each monitored pid: the service names and the global election
      -- nodes that must be deleted when that pid dies (see @reap@).
    , monPids    :: Map ProcessId ([String],[String])
      -- Global name -> (our election id, our staged candidate pid).
    , candidates :: Map String (String, ProcessId)
      -- Pid of the controller itself; watchers send commands back to it.
    , spid       :: ProcessId
      -- Runs a 'Process' action from plain 'IO' (built by 'spawnLinkedProxy').
    , proxy      :: Process () -> IO ()
      -- Handle of the open Zookeeper session.
    , conn       :: Zookeeper
    }
-- | Debug rendering used by the trace log: only the node cache and the
-- monitored pids are shown; the remaining fields are omitted.
instance Show State where
    show st = show ("nodes", nodeCache st, "monPids", monPids st)
-- | Settings for the Zookeeper controller.
data Config = Config
    {
      -- | Only locally registered names starting with this prefix are
      -- automatically registered in Zookeeper (see 'watchRegistration').
      -- With the empty prefix every name matches.
      registerPrefix :: String
      -- | Sink for trace-level messages (controller state after each
      -- command, reaped nodes, ...).
    , logTrace       :: String -> Process ()
      -- | Sink for error messages.
    , logError       :: String -> Process ()
      -- | Log level handed to 'ZK.setDebugLevel' when the controller starts.
    , zLogLevel      :: ZK.ZLogLevel
      -- | ACL applied to every node the controller creates.
    , acl            :: ZK.AclList
      -- | Optional scheme and credentials passed to 'ZK.addAuth' once the
      -- session is connected.
    , credentials    :: Maybe (ZK.Scheme, ByteString)
    }
-- | A logger that writes nothing. The message is still deep-evaluated so
-- that no unevaluated log string is retained.
nolog :: String -> Process ()
nolog msg = deepseq msg (return ())
-- | Trace logger that forwards the message to 'say' with a fixed prefix.
sayTrace :: String -> Process ()
sayTrace msg = say ("[C.D.P.Zookeeper: TRACE] - " ++ msg)
-- | Default settings: empty prefix (every name matches), tracing disabled,
-- errors reported through 'say', warning-level Zookeeper logging, open
-- ACLs and no credentials.
defaultConfig :: Config
defaultConfig = Config
    { registerPrefix = ""
    , logTrace       = nolog
    , logError       = \msg -> say ("[C.D.P.Zookeeper: ERROR] - " ++ msg)
    , zLogLevel      = ZK.ZLogWarn
    , acl            = OpenAclUnsafe
    , credentials    = Nothing
    }
-- | Ask the local controller to register @rpid@ under the service @name@.
-- Returns @Left@ with a description when the registration failed.
registerZK :: String -> ProcessId -> Process (Either String ())
registerZK name rpid = callZK (Register name rpid)
-- | Nodes of all controllers currently registered under 'controllersNode'.
getPeers :: Process [NodeId]
getPeers =
 do controllers <- callZK (GetRegistered controllersNode)
    return (map processNodeId controllers)
-- | All pids registered under the service @name@.
getCapable :: String -> Process [ProcessId]
getCapable name = callZK . GetRegistered $ servicesNode </> name
-- | Send @msg@ to the process locally registered as @service@ on every
-- peer node returned by 'getPeers'.
nsendPeers :: Serializable a => String -> a -> Process ()
nsendPeers service msg =
 do peers <- getPeers
    forM_ peers $ \peer -> nsendRemote peer service msg
-- | Send @msg@ to every pid registered under @service@ (see 'getCapable').
nsendCapable :: Serializable a => String -> a -> Process ()
nsendCapable service msg =
 do pids <- getCapable service
    forM_ pids $ \pid -> send pid msg
-- | Stage @proc@ as a candidate for the global @name@. The spawned process
-- blocks until it receives 'Elect' and only then runs @proc@. The result
-- is the pid reported by the controller for that name, or an error text.
registerCandidate :: String -> Process () -> Process (Either String ProcessId)
registerCandidate name proc =
 do staged <- spawnLocal (awaitElection >> proc)
    callZK (GlobalCandidate name staged)
  where
    awaitElection = expect :: Process Elect
-- | Look up the pid currently elected for the global @name@, if any.
whereisGlobal :: String -> Process (Maybe ProcessId)
whereisGlobal name = callZK (GetGlobal name)
-- | Run the Zookeeper controller with 'defaultConfig'. The argument is the
-- server string handed on to 'zkControllerWith'.
zkController :: String
             -> Process ()
zkController keepers = zkControllerWith defaultConfig keepers
-- | Run the Zookeeper controller in the calling process: open a Zookeeper
-- session, wait for the session watcher to finish authentication and
-- system-node setup, then enter 'server'. Dies with the failure reason
-- when that initialisation fails.
zkControllerWith :: Config
                 -> String -- ^ passed to 'ZK.withZookeeper' as the server list
                 -> Process ()
zkControllerWith config@Config{..} keepers =
 do -- 'run' lets the IO callback of withZookeeper execute Process actions.
    run <- spawnLinkedProxy
    -- Filled exactly once per connect event by the session watcher below.
    waitInit <- liftIO newEmptyMVar
    liftIO $ ZK.setDebugLevel zLogLevel
    -- Remember this thread so that session expiry can be thrown to it.
    mthread <- liftIO myThreadId
    -- NOTE(review): 1000 is the timeout argument of withZookeeper
    -- (presumably milliseconds) - confirm against the hzk documentation.
    liftIO $ ZK.withZookeeper keepers 1000 (Just $ inits mthread waitInit) Nothing $ \rzh ->
     do init' <- takeMVar waitInit
        case init' of
            Left reason -> run $ do logError $ "Init failed: " ++ show reason
                                    die reason
            Right ()    -> run (server rzh config)
  where
    -- Session watcher. On connect: authenticate (if configured), make sure
    -- the system nodes exist, and report the outcome through waitInit.
    -- NOTE(review): this also runs on every later ConnectedState event;
    -- nothing takes from waitInit after start-up, so a second such event
    -- after the first would block in putMVar - confirm whether reconnects
    -- can deliver it.
    inits _ waitInit rzh SessionEvent ZK.ConnectedState _ =
     do esetup <- runExceptT $ eauth >> dosetup
        case esetup of
            Right _ -> putMVar waitInit (Right ())
            Left reason -> putMVar waitInit (Left reason)
        where
          eauth = lift auth >>= hoistEither
          -- addAuth reports asynchronously through a callback; block on an
          -- MVar until that callback has delivered the result.
          auth =
             case credentials of
                  Nothing -> return (Right ())
                  Just (scheme, creds) ->
                   do waitAuth <- newEmptyMVar
                      ZK.addAuth rzh scheme creds $ \ res ->
                         case res of
                             Left reason ->
                                 let msg = "Authentication to Zookeeper failed: " ++ show reason
                                 in putMVar waitAuth (Left msg)
                             Right () -> putMVar waitAuth (Right ())
                      takeMVar waitAuth
          -- Create the node hierarchy used by this module; nodes that
          -- already exist are accepted (see 'createAssert').
          -- NOTE(review): the "blarg" node looks like a debugging leftover;
          -- confirm before removing, since it changes what is written to
          -- Zookeeper.
          dosetup =
           do esetup <- runExceptT $
               do createAssert rzh "/" Nothing acl []
                  createAssert rzh rootNode Nothing acl []
                  createAssert rzh servicesNode Nothing acl []
                  createAssert rzh controllersNode Nothing acl []
                  createAssert rzh globalsNode Nothing acl []
                  createAssert rzh (globalsNode </> "blarg") Nothing acl []
              case esetup of
                  Right _ -> return ()
                  Left reason ->
                      throwE $ "FATAL: could not create system nodes in Zookeeper: "
                             ++ show reason
    inits zkthread _ _ SessionEvent ZK.ExpiredSessionState _ =
        -- An expired session cannot be recovered here: raise an exception in
        -- the thread that opened the session.
        throwTo zkthread $ userError "Zookeeper session expired."
    -- Every other event is ignored.
    inits _ _ _ _ _ _ = return ()
-- | Main loop of the controller process.
--
-- Registers itself locally under 'controller', starts the registration
-- listener ('watchRegistration'), publishes its own pid as an ephemeral
-- node below 'controllersNode' and then serves 'Command's and monitor
-- notifications until 'Exit' is received.
--
-- Every command that carries a reply port is answered even when the
-- Zookeeper operation fails; callers block on that port (see 'callZK'), so
-- a missing answer would hang them.
server :: Zookeeper -> Config -> Process ()
server rzh config@Config{..} =
 do pid <- getSelfPid
    register controller pid
    watchRegistration config
    proxy' <- spawnLinkedProxy
    regself <- create rzh (controllersNode </> pretty pid)
                          (pidBytes pid) acl [Ephemeral]
    case regself of
        Left reason ->
            let msg = "Could not register self with Zookeeper: " ++ show reason
            in logError msg >> die msg
        Right _ -> return ()
    let loop st =
          let recvCmd = match $ \command -> case command of
                                  Exit -> return ()
                                  _ ->
                                    do eresult <- runExceptT $ handle st command
                                       case eresult of
                                           Right st' ->
                                              do logTrace $ "State of: " ++ show st' ++ " - after - "
                                                           ++ show command
                                                 loop st'
                                           -- A failed command is logged and the previous
                                           -- state is kept.
                                           Left reason ->
                                              do logError $ "Error handling: " ++ show command
                                                            ++ " : " ++ show reason
                                                 loop st
              recvMon = match $ \(ProcessMonitorNotification _ dead _) ->
                                  reap st dead >>= loop
          in logTrace (show st) >> receiveWait [recvCmd, recvMon]
    void $ loop (State mempty mempty mempty pid proxy' rzh)
  where
    -- Delete every Zookeeper node recorded for a process that has died and
    -- forget the process.
    reap st@State{..} pid =
     do let (services, globals) = fromMaybe ([],[]) (Map.lookup pid monPids)
        forM_ services $ \name ->
            deleteNode (servicesNode </> name </> pretty pid)
        forM_ globals $ \name ->
            deleteNode (globalsNode </> name)
        return st{monPids = Map.delete pid monPids}
      where
        deleteNode node =
         do result <- liftIO $ ZK.delete rzh node Nothing
            case result of
                  Left reason -> logError  $ show reason
                                          ++ " - failed to delete "
                                          ++ node
                  _ -> logTrace $ "Reaped " ++ node

    -- Register a pid under a service name and start monitoring it.
    handle st@State{..} (Register name rpid reply) =
     do let node = servicesNode </> name </> pretty rpid
        -- Creating the parent and the ephemeral node is run as one unit so
        -- that a failure of either step ends up in the Left branch below
        -- and is reported to the caller instead of leaving it waiting.
        result <- lift . runExceptT $
         do createAssert rzh (servicesNode </> name) Nothing acl []
            create rzh node (pidBytes rpid) acl [Ephemeral] >>= hoistEither
        case result of
            Right _ -> lift $
             do logTrace $ "Registered " ++ node
                nfSendChan reply (Right ())
                void $ monitor rpid
                let apService (a',b') (a, b) = (a' ++ a, b' ++ b)
                return st{monPids = Map.insertWith apService rpid ([name],[]) monPids}
            Left reason -> lift $
             do logError $ "Failed to register name: " ++ node ++ " - " ++ show reason
                nfSendChan reply (Left $ show reason)
                return st
    handle st@State{..} (ClearCache node) =
        return st{nodeCache = Map.delete node nodeCache}
    -- Answer from the cache when possible; otherwise read the children from
    -- Zookeeper and leave a watch that clears the cache entry again.
    handle st@State{..} (GetRegistered node reply) =
     do epids <- case Map.lookup node nodeCache of
                        Just pids -> return (Right pids)
                        Nothing -> getChildPids rzh node (Just $ watchCache st node)
        lift $ case epids of
            Right pids ->
             do nfSendChan reply pids
                return st{nodeCache = Map.insert node pids nodeCache}
            Left reason ->
             do logError $  "Retrieval failed for node: " ++ node ++ " - " ++ show reason
                nfSendChan reply []
                return st{nodeCache = Map.delete node nodeCache}
    handle st (GlobalCandidate n c r) = handleGlobalCandidate config st n c r
    handle st@State{..} (GetGlobal name reply) =
        case Map.lookup gname nodeCache of
            Just (pid : _) -> lift $
               do nfSendChan reply (Just pid)
                  return st
            _              ->
               do eelected <- lift $ runExceptT getElected
                  case eelected of
                      Right elected ->
                       do lift $ nfSendChan reply elected
                          return $ maybe st
                                         (\pid -> st {nodeCache = Map.insert gname [pid] nodeCache})
                                         elected
                      -- Answer the caller first, then re-raise so the loop
                      -- logs the failure as before.
                      Left reason ->
                       do lift $ nfSendChan reply Nothing
                          throwE reason
      where
        gname = globalsNode </> name
        getElected =
         do children <- getGlobalIds conn name
            case children of
                [] -> return Nothing
                (first: _) ->
                    -- The watch must clear the key the entry is cached
                    -- under (gname), otherwise the cache is never refreshed.
                    Just <$> getPid conn (globalsNode </> name </> first)
                                         (Just $ watchCache st gname)
    handle st@State{..} (CheckCandidate name) =
        case Map.lookup name candidates of
            Just (myid, staged) -> snd <$> mayElect st name myid staged
            Nothing             -> return st
    handle st Exit = return st

    -- Node-name form of a pid: its 'show' output without the first six
    -- characters.
    pretty pid = drop 6 (show pid)
-- | Handle a 'GlobalCandidate' command: stage @proc@ as this controller's
-- candidate for the global @name@ and answer on @reply@ with the pid that
-- holds the name.
--
-- * If a candidate is already staged and a holder is cached, the old
--   staged process is terminated, @proc@ replaces it and the cached holder
--   is reported.
-- * If a candidate is already staged but no holder is cached, the election
--   is re-evaluated with the existing election id.
-- * Otherwise a new election node is created and the election evaluated.
--
-- The reply port is always answered, including when Zookeeper operations
-- fail, because the caller blocks on it.
handleGlobalCandidate :: Config -> State
                      -> String -> ProcessId -> SendPort (Either String ProcessId)
                      -> ExceptT ZKError Process State
handleGlobalCandidate Config{..} st@State{..} name proc reply
    | Just (myid, staged) <- Map.lookup name candidates =
          case Map.lookup (globalsNode </> name) nodeCache of
              Just (pid : _) -> lift $
                                 do exit staged "New candidate staged."
                                    nfSendChan reply (Right pid)
                                    return st {candidates = Map.insert name (myid, proc) candidates}
              _              -> respondElect myid proc (Just staged)
    | otherwise =
         do eid <- lift . runExceptT $ registerGlobalId proc
            case eid of
                Right myid -> respondElect myid proc Nothing
                -- Tell the caller about the failure, then re-raise so the
                -- controller loop logs it.
                Left reason ->
                 do lift $ nfSendChan reply (Left $ show reason)
                    throwE reason
      where
        -- Evaluate the election and answer; a previously staged process is
        -- terminated only after the election succeeded.
        respondElect myid staged mprev = lift $
         do eresult <- runExceptT $ mayElect st name myid staged
            case eresult of
                Right (pid, st') ->
                 do nfSendChan reply (Right pid)
                    forM_ mprev (`exit` "New candidate staged.")
                    return st'
                Left reason ->
                 do nfSendChan reply (Left $ show reason)
                    return st
        -- Create an ephemeral sequential node holding the candidate's pid
        -- and return its last path component as our election id.
        registerGlobalId staged =
         do let pname = globalsNode </> name
            createAssert conn pname Nothing acl []
            node <- create conn (pname </> "")
                                (pidBytes staged)
                                acl
                                [Ephemeral, Sequence] >>= hoistEither
            return $ extractId node
          where
            -- Everything after the last '/' of the created path.
            extractId s = trimId (reverse s) ""
              where
                trimId [] _ = error $ "end of string without delimiter / in " ++ s
                trimId ('/' : _) str = str
                trimId (n : rest) str = trimId rest (n : str)
-- | Evaluate the election for the global @name@ from the point of view of
-- our candidate (election id @myid@, staged process @staged@).
--
-- The sorted election ids are read from Zookeeper. If ours is first, the
-- staged process is sent 'Elect', monitored and cached as the holder.
-- Otherwise a watch is placed on the id directly ahead of ours (so that we
-- re-check when it disappears), and the pid stored in the first node is
-- returned and cached as the current holder.
--
-- Fails with 'NoNodeError' when no election ids are found or when our own
-- id is not among them.
mayElect :: State -> String -> String -> ProcessId -> ExceptT ZKError Process (ProcessId, State)
mayElect st@State{..} name myid staged =
 do others <- getGlobalIds conn name
    case others of
        -- getGlobalIds yields [] on any failure, so an empty list must be
        -- handled rather than pattern-matched away.
        [] -> throwE NoNodeError
        (first : _)
            | myid == first ->
                 do lift $ nfSend staged Elect
                    st' <- lift $ cacheNMonitor staged
                    return (staged, st')
            | otherwise ->
                 do prev <- maybe (throwE NoNodeError) return (findPrev others)
                    -- When the node ahead of us is the leader itself, the
                    -- watch is attached while reading its pid; otherwise it
                    -- is attached to our predecessor separately.
                    watchFirst <- if prev == first then return $ Just watchPid
                                  else do void $ liftIO (ZK.exists
                                                            conn
                                                            (globalsNode </> name </> prev)
                                                            (Just watchPid)) >>= hoistEither
                                          return Nothing
                    pid <- getPid conn (globalsNode </> name </> first) watchFirst
                    return (pid, stCache pid)
  where
    -- The id immediately before ours in the sorted list, if we are in it.
    findPrev (prev : next : rest)
        | next == myid = Just prev
        | otherwise    = findPrev (next : rest)
    findPrev _ = Nothing
    -- Ask the controller to re-check the election once the watched node is
    -- deleted.
    watchPid _ ZK.DeletedEvent ZK.ConnectedState _ =
        proxy $ nfSend spid (CheckCandidate name)
    watchPid _ _ _ _ = return ()
    stCandidate = st{candidates = Map.insert name (myid, staged) candidates}
    -- Cache the pid of the current holder (not our staged candidate, which
    -- has not been elected) under the global's cache key.
    stCache holder =
        stCandidate {nodeCache = Map.insert (globalsNode </> name) [holder] nodeCache}
    -- We won: monitor the elected process, record its election node for
    -- cleanup, and cache it as the holder of the name.
    cacheNMonitor pid =
     do void $ monitor pid
        liftIO $ void $ ZK.exists conn (globalsNode </> name </> myid) (Just $ watchCache st (globalsNode </> name))
        return stCacheMon
      where
        stCacheMon =
             let apGlobal (a', b') (a, b) = (a' ++ a, b' ++ b) in
             st{ candidates = Map.delete name candidates
               , monPids = Map.insertWith apGlobal
                                          pid
                                          ([],[name </> myid]) monPids
               , nodeCache = Map.insert (globalsNode </> name) [pid] nodeCache
               }
-- | Sorted child names of the election node for the global @name@.
-- A missing node yields the empty list; any other failure is printed to
-- stdout and also yields the empty list.
getGlobalIds :: MonadIO m
             => Zookeeper -> String -> ExceptT ZKError m [String]
getGlobalIds conn name = liftIO $
    ZK.getChildren conn (globalsNode </> name) Nothing
        >>= either failed (return . sort)
  where
    failed NoNodeError = return []
    failed reason =
        putStrLn ("Could not fetch globals for " ++ name ++ " - " ++ show reason)
            >> return []
-- | Read the node @name@ and decode its content as a 'ProcessId',
-- optionally leaving a watch on the node. Fails with the Zookeeper error,
-- or with 'NothingError' when the node carries no data.
getPid :: MonadIO m
       => Zookeeper -> String -> Maybe Watcher
       -> ExceptT ZKError m ProcessId
getPid conn name watcher = liftIO (ZK.get conn name watcher) >>= unpack
  where
    unpack (Left reason)        = throwE reason
    unpack (Right (Just bs, _)) = return (decode (BL.fromStrict bs))
    unpack (Right _)            = throwE NothingError
-- | Decode the 'ProcessId' stored in every child of @node@. The optional
-- watcher is attached to the child listing. A missing @node@ counts as
-- "no children"; the first child that cannot be read aborts the whole
-- lookup with that error.
getChildPids :: MonadIO m
             => Zookeeper -> String -> Maybe Watcher
             -> m (Either ZKError [ProcessId])
getChildPids rzh node watcher = liftIO $
    ZK.getChildren rzh node watcher >>= \listing ->
        case listing of
            Left NoNodeError -> return (Right [])
            Left reason      -> return (Left reason)
            Right children   -> runExceptT (mapM fetch children)
  where
    fetch child =
     do eresult <- liftIO $ ZK.get rzh (node </> child) Nothing
        case eresult of
            Right (Just bs, _) -> return (decode (BL.fromStrict bs))
            Right (Nothing, _) -> throwE NothingError
            Left reason        -> throwE reason
-- | Zookeeper watcher that asks the controller to drop the cache entry
-- @node@. It only reacts to events delivered in the connected state.
watchCache :: State -> String -> a -> b -> ZK.State -> c -> IO ()
watchCache st node _ _ zkState _ =
    case zkState of
        ZK.ConnectedState -> proxy st $ nfSend (spid st) (ClearCache node)
        _                 -> return ()
-- | Start a management agent that mirrors local name registrations into
-- Zookeeper: whenever a name starting with 'registerPrefix' is registered
-- locally, it is also registered through 'registerZK'. The agent's own
-- registration is ignored. Pauses for 10000 microseconds after starting
-- the agent.
watchRegistration :: Config -> Process ()
watchRegistration Config{..} =
 do void $ mxAgent (MxAgentId "zookeeper:name:listener")
                   ([] :: [MxEvent])
                   [mxSink $ \ev -> onEvent ev >> mxReady]
    liftIO $ threadDelay 10000
  where
    onEvent (MxRegistered _ "zookeeper:name:listener") = return ()
    onEvent (MxRegistered pid name')
        | registerPrefix `isPrefixOf` name' = liftMX $
            registerZK name' pid >>= either (reportFailure name') return
    onEvent _ = return ()

    reportFailure name' reason =
        logError $ "Automatic registration failed for name: "
              ++ name' ++ " - " ++ reason
-- | Wait until a process is locally registered under 'controller'.
--
-- Polls every 10000 microseconds. @timeout@ is the remaining budget in the
-- same unit; each unsuccessful poll subtracts one polling interval from it.
-- Returns @Just ()@ as soon as the controller is found and @Nothing@ once
-- the budget is used up.
waitController :: Int -> Process (Maybe ())
waitController timeout =
 do res <- whereis controller
    case res of
        Nothing ->
            do let timeleft = timeout - 10000
               if timeleft <= 0 then return Nothing
                  else do liftIO (threadDelay 10000)
                          waitController timeleft
        Just _ -> return (Just ())
-- | Send 'Exit' to the locally registered controller; logs a message via
-- 'say' when no controller is registered.
stopController :: Process ()
stopController =
    whereis controller >>= maybe notFound (`nfSend` Exit)
  where
    notFound = say "Could not find controller to stop it."
-- | Lift a pure 'Either' into 'ExceptT'.
hoistEither :: Monad m => Either e a -> ExceptT e m a
hoistEither result = ExceptT (return result)
-- | Join two Zookeeper path segments with a single @/@.
(</>) :: String -> String -> String
l </> r = concat [l, "/", r]
-- | Parent node of all service registrations
-- (@servicesNode </> name </> pid@).
servicesNode :: String
servicesNode = rootNode </> "services"
-- | Parent node under which every controller publishes its own pid.
controllersNode :: String
controllersNode = rootNode </> "controllers"
-- | Parent node of the per-name election nodes used for global names.
globalsNode :: String
globalsNode = rootNode </> "globals"
-- | Root of every node this module creates in Zookeeper.
rootNode :: String
rootNode = "/distributed-process"
-- | Create a node and treat "node already exists" as success; every other
-- Zookeeper error is raised in 'ExceptT'.
createAssert :: MonadIO m
             => Zookeeper -> String -> Maybe BS.ByteString -> AclList -> [CreateFlag]
             -> ExceptT ZKError m ()
createAssert z n d a f =
 do outcome <- create z n d a f
    case outcome of
        Left NodeExistsError -> return ()
        Left reason          -> throwE reason
        Right _              -> return ()
-- | Binary-encode a pid as strict bytes, wrapped in 'Just' so it can be
-- passed directly as node data to 'create'.
pidBytes :: ProcessId -> Maybe BS.ByteString
pidBytes pid = Just (BL.toStrict (encode pid))
-- | Name under which the controller process registers itself in the local
-- registry; 'callZK', 'waitController' and 'stopController' look it up by
-- this name.
controller :: String
controller = "zookeeper:controller"
-- | 'unsafeSendChan' with a message that is deep-evaluated when it is
-- first demanded.
nfSendChan :: (Binary a, Typeable a, NFData a) => SendPort a -> a -> Process ()
nfSendChan !port !msg =
    let forced = msg `deepseq` msg
    in unsafeSendChan port forced
-- | 'unsafeSend' with a message that is deep-evaluated when it is first
-- demanded.
nfSend :: (Binary a, Typeable a, NFData a) => ProcessId -> a -> Process ()
nfSend !pid !msg =
    let forced = msg `deepseq` msg
    in unsafeSend pid forced
-- | Synchronous request to the local controller: build the command around
-- a fresh reply port, send it and block until the answer arrives. The
-- caller is linked to the controller for the duration of the call.
--
-- The pattern bind on 'whereis' fails when no controller is registered.
-- NOTE(review): the final 'unlink' also removes a link to the controller
-- that the caller may have set up before this call - confirm that callers
-- do not rely on such a link.
callZK :: Serializable a => (SendPort a -> Command) -> Process a
callZK command =
 do Just ctrl <- whereis controller
    (replyTo, inbox) <- newChan
    link ctrl
    nfSend ctrl (command replyTo)
    answer <- receiveChan inbox
    unlink ctrl
    return answer
-- | Spawn a helper process and return a function that runs a 'Process'
-- action on it from plain 'IO', blocking until the action's result is
-- available. The helper and the calling process are linked to each other.
--
-- Requests and results travel through one pair of MVars.
-- NOTE(review): concurrent callers of the returned function are not
-- serialised, so two overlapping calls could pick up each other's result -
-- confirm it is only used from one thread at a time.
spawnLinkedProxy :: Process (Process a -> IO a)
spawnLinkedProxy =
 do action <- liftIO newEmptyMVar
    result <- liftIO newEmptyMVar
    self <- getSelfPid
    pid <- spawnLocal $
        -- Forever: take the next action, run it, hand back its result.
        let loop = join (liftIO $ takeMVar action)
                   >>= liftIO . putMVar result
                   >> loop in link self >> loop
    link pid
    return (\f -> putMVar action f >> takeMVar result)
-- | 'ZK.create' lifted into any 'MonadIO'; the arguments are passed
-- through unchanged.
create :: MonadIO m => Zookeeper -> String -> Maybe BS.ByteString
       -> AclList -> [CreateFlag] -> m (Either ZKError String)
create z n d a flags = liftIO (ZK.create z n d a flags)
-- | 'bootstrapWith' using 'defaultConfig'.
bootstrap :: HostName
          -> ServiceName
          -> String
          -> RemoteTable
          -> Process ()
          -> IO ()
bootstrap host port zservs rtable proc =
    bootstrapWith defaultConfig host port zservs rtable proc
-- | Create a TCP transport on @host@/@port@, start a local node with
-- @rtable@, launch the Zookeeper controller for the servers in @zservs@
-- and run @proc@ once the controller is registered. The controller is
-- asked to stop afterwards and the transport is closed on the way out.
--
-- Throws the transport's error if it cannot be created. The main process
-- dies if the controller does not appear within the 'waitController'
-- budget of 100000.
bootstrapWith :: Config
              -> HostName
              -> ServiceName
              -> String
              -> RemoteTable
              -> Process ()
              -> IO ()
bootstrapWith config host port zservs rtable proc =
    bracket openTransport closeTransport exec
  where
    openTransport =
        createTransport host port defaultTCPParameters >>= either throwIO return
    exec tcp =
     do node <- newLocalNode tcp rtable
        runProcess node withController
    -- Runs on the new node: start the controller, wait for it, run the
    -- user's process, then shut the controller down.
    withController =
     do zkpid <- spawnLocal (zkControllerWith config zservs)
        link zkpid
        found <- waitController 100000
        maybe (die "Timeout waiting for Zookeeper controller to start.")
              (const proc)
              found
        stopController