module Control.Distributed.Process.Management.Internal.Table
  ( MxTableRequest(..)
  , MxTableId(..)
  , mxTableCoordinator
  , startTableCoordinator
  , delete
  , purge
  , clear
  , set
  , get
  , fetch
  ) where
import Control.Distributed.Process.Internal.Primitives
  ( receiveWait
  , receiveChan
  , match
  , matchAny
  , matchIf
  , matchChan
  , send
  , nsend
  , sendChan
  , getSelfPid
  , link
  , monitor
  , unwrapMessage
  , newChan
  , withMonitor
  )
import Control.Distributed.Process.Internal.Types
  ( Process
  , ProcessId
  , ProcessMonitorNotification(..)
  , SendPort
  , ReceivePort
  , Message
  , unsafeCreateUnencodedMessage
  )
import Control.Distributed.Process.Management.Internal.Types
  ( MxTableId(..)
  , MxAgentId(..)
  , MxAgentStart(..)
  , Fork)
import Control.Distributed.Process.Serializable (Serializable)
import Control.Monad.IO.Class (liftIO)
import Data.Accessor (Accessor, accessor, (^=), (^:))
import Data.Binary (Binary)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Typeable (Typeable)
import GHC.Generics
-- | Requests handled by a table process. Agent-addressed requests reach the
-- coordinator wrapped as @(MxAgentId, MxTableRequest)@.
--
-- Constructor order matters: the 'Binary' instance below uses the generic
-- default methods, which encode the constructor index, so reordering the
-- constructors changes the wire format.
data MxTableRequest
  = Delete
    -- ^ Stop the table process (its state is discarded).
  | Purge
    -- ^ Remove every entry, keeping the table running.
  | Clear !String
    -- ^ Remove the entry stored under the given key, if any.
  | Set !String !Message
    -- ^ Store a message under the given key, replacing any previous value.
  | Get !String !(SendPort (Maybe Message))
    -- ^ Send the entry for the given key (or 'Nothing') to the given port.
  deriving (Typeable, Generic)

-- Serialisation comes entirely from the Generic-based default methods.
instance Binary MxTableRequest
-- | State carried by a single table process.
data MxTableState = MxTableState
  { _name    :: !String
    -- ^ Table name; the coordinator sets it from the owning agent's id.
  , _entries :: !(Map String Message)
    -- ^ Stored messages, keyed by string.
  }
-- | The coordinator's registry: the pid of the table process for each agent.
type MxTables = Map MxAgentId ProcessId
-- | Registered name used to reach the table coordinator via 'nsend'.
-- NOTE(review): this module never registers the coordinator under this name;
-- presumably whoever spawns 'startTableCoordinator' does — confirm.
mxTableCoordinator :: String
mxTableCoordinator = "mx.table.coordinator"
-- | Ask the table identified by @tid@ to stop. The table process exits when
-- it handles the request, so its entries are lost. No reply is awaited.
delete :: MxTableId -> Process ()
delete tid = sendReq Delete tid
-- | Remove every entry from the table identified by @tid@. The table keeps
-- running. No reply is awaited.
purge :: MxTableId -> Process ()
purge tid = sendReq Purge tid
-- | Remove the entry stored under @key@ (if present) from the table @tid@.
-- No reply is awaited.
clear :: String -> MxTableId -> Process ()
clear key tid = sendReq (Clear key) tid
-- | Store @value@ under @key@ in the table @tid@, replacing any previous
-- entry for that key. No acknowledgement is awaited.
set :: String -> Message -> MxTableId -> Process ()
set key value tid = sendReq (Set key value) tid
-- | Look up @key@ in the table identified by the given 'MxTableId', blocking
-- until a reply arrives.
--
-- * 'MxForPid' tables are queried directly through 'get'.
-- * 'MxForAgent' tables are queried through the coordinator. A 'Get' request
--   carrying a fresh reply channel is sent with 'sendReq', and the answer is
--   read back from that channel.
--
-- Yields 'Nothing' when there is no entry, or when the stored message cannot
-- be unwrapped at type @a@.
--
-- NOTE(review): the agent path does not monitor the coordinator, so it waits
-- indefinitely if no reply is ever sent.
fetch :: forall a. (Serializable a)
      => MxTableId
      -> String
      -> Process (Maybe a)
fetch tid key =
  case tid of
    MxForPid pid -> get pid key
    MxForAgent _ -> do
      (replyTo, replies) <- newChan
        :: Process (SendPort (Maybe Message), ReceivePort (Maybe Message))
      sendReq (Get key replyTo) tid
      reply <- receiveChan replies
      case reply of
        Nothing  -> return Nothing
        Just msg -> unwrapMessage msg :: Process (Maybe a)
-- | Look up @key@ in the table process @pid@, blocking until it replies. The
-- lookup goes through 'safeFetch', which monitors the table while waiting.
--
-- Yields 'Nothing' if there is no entry, or if the stored message cannot be
-- unwrapped at type @a@.
get :: forall a. (Serializable a)
      => ProcessId
      -> String
      -> Process (Maybe a)
get pid key = safeFetch pid key >>= maybe (return Nothing) unwrapAs
  where
    unwrapAs :: Message -> Process (Maybe a)
    unwrapAs = unwrapMessage
-- | Ask the table process @pid@ for the raw message stored under @key@ and
-- wait for the answer.
--
-- The table is monitored while waiting. If a monitor notification for @pid@
-- arrives before a reply, the wait ends with 'Nothing'. A dead table is
-- therefore reported exactly like a missing key.
safeFetch :: ProcessId -> String -> Process (Maybe Message)
safeFetch pid key = do
  (sp, rp) <- newChan
  send pid $ Get key sp
  withMonitor pid $ do
    receiveWait [
        matchChan rp return
        -- The table died before replying. Answer 'Nothing' rather than a
        -- placeholder message: a placeholder such as an unencoded @()@ would
        -- be unwrapped as @Just ()@ by callers asking for type @()@, falsely
        -- signalling that a value is stored.
      , matchIf (\(ProcessMonitorNotification _ pid' _) -> pid' == pid)
                (\_ -> return Nothing)
      ]
-- | Deliver a table request without waiting for any reply.
--
-- The request goes straight to the table process for 'MxForPid'. For
-- 'MxForAgent' it goes to the coordinator registered as
-- 'mxTableCoordinator', tagged with the agent id.
sendReq :: MxTableRequest -> MxTableId -> Process ()
sendReq req tid = deliver req
  where
    deliver = resolve tid
-- | Turn a table id into a send function.
--
-- Agent tables are reached via the named coordinator, with the message
-- paired with the agent id. Pid tables are sent to directly.
resolve :: Serializable a => MxTableId -> (a -> Process ())
resolve tid =
  case tid of
    MxForAgent agent -> nsend mxTableCoordinator . (,) agent
    MxForPid   pid   -> send pid
-- | Run the table coordinator loop. This never returns.
--
-- The coordinator keeps the map from 'MxAgentId' to table process and
-- handles these messages, in this order:
--
--   * 'ProcessMonitorNotification': a table process died, so it is dropped
--     from the map.
--   * 'MxAgentStart': reply on the supplied channel with the agent's table
--     pid, creating the table first if necessary.
--   * @(MxAgentId, MxTableRequest)@: forward the request to the agent's
--     table, creating it if necessary. 'Get' is answered by the coordinator
--     itself via 'safeFetch', so the coordinator blocks until the table
--     replies or is seen to die.
--
-- Any other message is discarded. Tables are started with the supplied
-- 'Fork'. Each one links itself to the coordinator and is monitored by it.
startTableCoordinator :: Fork -> Process ()
startTableCoordinator fork = loop Map.empty
  where
    loop :: MxTables -> Process ()
    loop known = receiveWait (handlers known) >>= loop

    -- Message handlers; each yields the (possibly updated) table map.
    handlers known =
      [ match $ \(ProcessMonitorNotification _ dead _) ->
          return (Map.filter (/= dead) known)
      , match $ \(MxAgentStart replyTo agent) -> do
          (tablePid, known') <- lookupAgent known agent
          sendChan replyTo tablePid
          return known'
      , match $ \(agent, tReq :: MxTableRequest) ->
          case tReq of
            Get key replyTo -> do
              (tablePid, known') <- lookupAgent known agent
              safeFetch tablePid key >>= sendChan replyTo
              return known'
            _ -> handleRequest known (agent, tReq)
      , matchAny (const (return known))
      ]

    -- Forward a non-Get request to the agent's table, creating it on demand.
    handleRequest :: MxTables
                  -> (MxAgentId, MxTableRequest)
                  -> Process MxTables
    handleRequest known (agent, req) = do
      (tablePid, known') <- lookupAgent known agent
      send tablePid req
      return known'

    -- Find the agent's table, starting a new one if none is registered.
    lookupAgent :: MxTables -> MxAgentId -> Process (ProcessId, MxTables)
    lookupAgent known agent =
      maybe (launchNew agent known)
            (\tablePid -> return (tablePid, known))
            (Map.lookup agent known)

    -- Start an empty table for the agent and record it in the map.
    launchNew :: MxAgentId
              -> MxTables
              -> Process (ProcessId, MxTables)
    launchNew agent known = do
      let fresh = MxTableState { _name    = agentId agent
                               , _entries = Map.empty
                               }
      (tablePid, _) <- spawnSup (tableHandler fresh)
      return (tablePid, agent `seq` tablePid `seq` Map.insert agent tablePid known)

    -- Fork a child that links itself to us, then monitor it.
    -- Yields the child's pid and the monitor reference.
    spawnSup body = do
      self  <- getSelfPid
      child <- liftIO (fork (link self >> body))
      ref   <- monitor child
      return (child, ref)
-- | Main loop of a single table process.
--
-- Handles 'MxTableRequest's against the given state until a 'Delete' is
-- received, at which point the loop ends. Messages of any other type are
-- discarded and leave the state unchanged.
tableHandler :: MxTableState -> Process ()
tableHandler st0 = do
  next <- receiveWait [ match (step st0)
                      , matchAny (\_ -> return (Just st0))
                      ]
  maybe (return ()) tableHandler next
  where
    -- Apply one request; 'Nothing' means "stop the table".
    step :: MxTableState -> MxTableRequest -> Process (Maybe MxTableState)
    step st req =
      case req of
        Delete  -> return Nothing
        Purge   -> keep ((entries ^= Map.empty) st)
        Clear k -> keep ((entries ^: (k `seq` Map.delete k)) st)
        Set k v -> keep ((entries ^: (k `seq` v `seq` Map.insert k v)) st)
        Get k c -> getEntry k c st >> keep st
      where
        keep s = return (Just s)
-- | Reply on @replyTo@ with the message stored under @key@, or 'Nothing' if
-- the table holds no entry for that key.
getEntry :: String
         -> SendPort (Maybe Message)
         -> MxTableState
         -> Process ()
getEntry key replyTo MxTableState{..} =
  -- The lookup is pure, so send it directly (no need for `=<< return`).
  sendChan replyTo (Map.lookup key _entries)
-- | Accessor for a table's entry map, for use with '^=' and '^:'.
entries :: Accessor MxTableState (Map String Message)
entries = accessor getter setter
  where
    getter = _entries
    setter newEntries st = st { _entries = newEntries }