module Control.Distributed.Process.Execution.Exchange.Broadcast
  (
    broadcastExchange
  , broadcastExchangeT
  , broadcastClient
  , bindToBroadcaster
  , BroadcastExchange
  ) where
import Control.Concurrent.STM (STM, atomically)
import Control.Concurrent.STM.TChan
  ( TChan
  , newBroadcastTChanIO
  , dupTChan
  , readTChan
  , writeTChan
  )
import Control.DeepSeq (NFData)
import Control.Distributed.Process
  ( Process
  , MonitorRef
  , ProcessMonitorNotification(..)
  , ProcessId
  , SendPort
  , processNodeId
  , getSelfPid
  , getSelfNode
  , liftIO
  , newChan
  , sendChan
  , unsafeSend
  , unsafeSendChan
  , receiveWait
  , match
  , matchIf
  , die
  , handleMessage
  , Match
  )
import qualified Control.Distributed.Process as P
import Control.Distributed.Process.Serializable()
import Control.Distributed.Process.Execution.Exchange.Internal
  ( startExchange
  , configureExchange
  , Message(..)
  , Exchange(..)
  , ExchangeType(..)
  , applyHandlers
  )
import Control.Distributed.Process.Extras.Internal.Types
  ( Channel
  , ServerDisconnected(..)
  )
import Control.Distributed.Process.Extras.Internal.Unsafe 
  ( PCopy
  , pCopy
  , pUnwrap
  , matchChanP
  , InputStream(Null)
  , newInputStream
  )
import Control.Monad (forM_, void)
import Data.Accessor
  ( Accessor
  , accessor
  , (^:)
  )
import Data.Binary
import qualified Data.Foldable as Foldable (toList)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Typeable (Typeable)
import GHC.Generics
-- | Binding request carrying a typed channel. 'broadcastClient' sends this
-- when the exchange lives on a different node from the caller; the exchange
-- handles it in @handleBindPort@.
data BindPort = BindPort
  { portClient :: !ProcessId
    -- ^ pid of the process asking to be bound
  , portSend   :: !(SendPort Message)
    -- ^ send end of the channel the exchange should write routed messages to
  }
  deriving (Typeable, Generic)

instance Binary BindPort
instance NFData BindPort
-- | Binding request used when client and exchange share a node (see
-- 'broadcastClient'). It has no 'Binary' instance in this module; it is
-- wrapped with 'pCopy' by the sender and recovered with 'pUnwrap' by the
-- exchange.
data BindSTM = BindSTM
  { stmClient :: !ProcessId
    -- ^ pid of the process asking to be bound
  , stmSend   :: !(SendPort (PCopy (InputStream Message)))
    -- ^ channel on which the exchange replies with the client's input stream
  }
  deriving (Typeable)
-- | How the exchange pushes a routed 'Message' to one bound client; consumed
-- by 'writeToStream'.
data OutputStream
  = WriteChan (SendPort Message)
    -- ^ deliver via 'sendChan' on the client's typed channel
  | WriteSTM (Message -> STM ())
    -- ^ deliver by running an STM action built from the message
  | NoWrite
    -- ^ nothing to do per message (used for bindings that read from a
    -- duplicate of the broadcast channel instead)
  deriving (Typeable)
-- | A routing-table entry for one client.
--
-- Note that 'outputStream' and 'inputStream' are fields of the 'Binding'
-- constructor only; applying either selector to a 'PidBinding' is a runtime
-- error, so pattern match on the constructor first.
data Binding
  = Binding
      { outputStream :: !OutputStream
        -- ^ how the exchange writes to this client
      , inputStream  :: !(InputStream Message)
        -- ^ the stream handed to the client to read from
      }
  | PidBinding !ProcessId
    -- ^ client that gets message payloads forwarded straight to its pid
  deriving (Typeable)
-- | Positive reply the exchange sends to a client whose 'BindPort' request
-- was accepted.
data BindOk = BindOk deriving (Typeable, Generic)

instance Binary BindOk
instance NFData BindOk
-- | Negative reply to a 'BindPort' request, carrying a textual reason.
-- 'broadcastClient' terminates the caller with this value when it arrives.
data BindFail = BindFail !String deriving (Typeable, Generic)

instance Binary BindFail
instance NFData BindFail
-- | Tag sent (paired with the caller's pid) by 'bindToBroadcaster' to ask
-- for a 'PidBinding'.
data BindPlease = BindPlease deriving (Typeable, Generic)

instance Binary BindPlease
instance NFData BindPlease
-- | The exchange's routing table: at most one 'Binding' per client, keyed by
-- the client's 'ProcessId'.
type BroadcastClients = Map ProcessId Binding
-- | Internal state of a broadcast exchange.
data BroadcastEx = BroadcastEx
  { _routingTable :: !BroadcastClients
    -- ^ currently bound clients; use the 'routingTable' accessor to update
  , channel       :: !(TChan Message)
    -- ^ broadcast channel every routed message is written to; bindings
    -- created for 'BindSTM' clients read from duplicates of it
  }
-- | 'ExchangeType' specialised to the broadcast exchange's internal state;
-- this is the type 'broadcastExchangeT' produces.
type BroadcastExchange = ExchangeType BroadcastEx
-- | Build a fresh broadcast exchange definition with 'broadcastExchangeT'
-- and pass it to 'startExchange', returning the resulting 'Exchange'.
broadcastExchange :: Process Exchange
broadcastExchange = do
  exchangeType <- broadcastExchangeT
  startExchange exchangeType
-- | Create the 'ExchangeType' describing a broadcast exchange: an empty
-- routing table, a newly allocated broadcast 'TChan', and the 'apiConfigure'
-- and 'apiRoute' callbacks. Nothing is started here; see 'broadcastExchange'.
broadcastExchangeT :: Process BroadcastExchange
broadcastExchangeT = fmap mkExchangeType (liftIO newBroadcastTChanIO)
  where
    -- Assemble the exchange definition around the freshly created channel.
    mkExchangeType bch =
      ExchangeType
        { name        = "BroadcastExchange"
        , state       = BroadcastEx { _routingTable = Map.empty
                                    , channel       = bch
                                    }
        , configureEx = apiConfigure
        , routeEx     = apiRoute
        }
-- | Bind the calling process to a broadcast exchange and return the
-- 'InputStream' from which it can read routed messages.
--
-- The binding protocol depends on where the exchange process (@pid@) runs:
--
-- * same node as the caller: a 'BindSTM' request is sent wrapped in 'pCopy',
--   and the input stream is received back on a private channel;
--
-- * different node: a 'BindPort' request carrying a typed channel is sent,
--   and the caller waits for 'BindOk' (the stream then reads from that
--   channel) or 'BindFail' (the caller dies with the failure value).
--
-- In both cases the exchange is monitored while waiting; if it goes away the
-- caller dies with 'ServerDisconnected'. The monitor is always removed
-- before returning.
broadcastClient :: Exchange -> Process (InputStream Message)
broadcastClient ex@Exchange{..} = do
  myNode <- getSelfNode
  us     <- getSelfPid
  if processNodeId pid == myNode
     then bindLocal us
     else bindRemote us
  where
    -- Exchange is on our node: ask for an STM-backed stream.
    bindLocal :: ProcessId -> Process (InputStream Message)
    bindLocal self = do
      (sp, rp) <- newChan
      configureExchange ex $ pCopy (BindSTM self sp)
      awaitBinding [ matchChanP rp ]

    -- Exchange is elsewhere: bind a typed channel and wait for the verdict.
    bindRemote :: ProcessId -> Process (InputStream Message)
    bindRemote self = do
      (sp, rp) <- newChan :: Process (Channel Message)
      configureExchange ex $ BindPort self sp
      awaitBinding
        [ match (\BindOk -> return $ newInputStream $ Left rp)
        , match (\f@(BindFail _) -> die f)
        ]

    -- Monitor the exchange, wait for one of the given replies (or for the
    -- exchange to fail), and drop the monitor again whatever happens.
    awaitBinding :: [Match (InputStream Message)]
                 -> Process (InputStream Message)
    awaitBinding replies = do
      mRef <- P.monitor pid
      receiveWait (replies ++ [handleServerFailure mRef])
        `P.finally` P.unmonitor mRef
-- | Ask the exchange to bind the calling process by pid: a
-- @('BindPlease', self)@ pair is submitted via 'configureExchange'. No reply
-- is awaited here.
bindToBroadcaster :: Exchange -> Process ()
bindToBroadcaster ex@Exchange{} =
  getSelfPid >>= \self -> configureExchange ex (BindPlease, self)
-- | Routing callback: publish one message to every bound client.
--
-- The message is first written to the exchange's broadcast 'TChan', then
-- each routing-table entry is served in turn: a 'PidBinding' gets the
-- message payload forwarded to its pid, any other binding is written to
-- through its 'OutputStream'. The exchange state is returned unchanged.
apiRoute :: BroadcastEx -> Message -> Process BroadcastEx
apiRoute ex@BroadcastEx{ channel = bch, _routingTable = clients } msg = do
  liftIO . atomically $ writeTChan bch msg
  mapM_ deliver (Map.elems clients)
  return ex
  where
    -- Deliver the message to a single client according to its binding kind.
    deliver binding =
      case binding of
        PidBinding client -> P.forward (payload msg) client
        Binding out _     -> writeToStream out msg
-- | Configuration callback: interpret a raw control message and return the
-- (possibly updated) exchange state.
--
-- The message is offered to these handlers via 'applyHandlers':
--
-- 1. 'BindPort'   - bind a client through a typed channel;
-- 2. 'BindSTM'    - bind a client to a duplicate of the broadcast 'TChan';
-- 3. @('BindPlease', pid)@ - bind a client by pid (payloads are forwarded);
-- 4. 'ProcessMonitorNotification' - drop the binding of a dead client;
-- 5. a catch-all that leaves the state untouched.
--
-- NOTE(review): presumably 'applyHandlers' tries them in list order and
-- stops at the first that yields a 'Just' - confirm in Exchange.Internal.
--
-- Every kind of binding monitors its client, so that handler 4 can remove
-- the entry once the client terminates.
apiConfigure :: BroadcastEx -> P.Message -> Process BroadcastEx
apiConfigure ex msg = do
  -- 'BindSTM' is recovered with 'pUnwrap' rather than 'handleMessage'; it
  -- has no 'Binary' instance in this module.
  applyHandlers ex msg $ [ \m -> handleMessage m (handleBindPort ex)
                         , \m -> handleBindSTM ex m
                         , \m -> handleMessage m (handleBindPlease ex)
                         , \m -> handleMessage m (handleMonitorSignal ex)
                         , (const $ return $ Just ex)
                         ]
  where
    -- Register a pid binding unless the client is already bound. The client
    -- is monitored (as for the other binding kinds); without this a dead
    -- client's entry would stay in the routing table forever.
    handleBindPlease :: BroadcastEx
                     -> (BindPlease, ProcessId)
                     -> Process BroadcastEx
    handleBindPlease ex' (BindPlease, p) =
      case lookupBinding ex' p of
        Just _  -> return ex'
        Nothing -> do
          void $ P.monitor p
          return $ (routingTable ^: Map.insert p (PidBinding p)) ex'

    -- A monitored client went away: forget its binding.
    handleMonitorSignal :: BroadcastEx
                        -> ProcessMonitorNotification
                        -> Process BroadcastEx
    handleMonitorSignal bx (ProcessMonitorNotification _ p _) =
      return $ (routingTable ^: Map.delete p) bx

    -- Serve a 'BindSTM' request. Yields 'Nothing' when the message is not a
    -- 'BindSTM', otherwise replies on the request's channel with an input
    -- stream and yields the resulting state.
    handleBindSTM :: BroadcastEx -> P.Message -> Process (Maybe BroadcastEx)
    handleBindSTM ex' msg' = do
      bind' <- pUnwrap msg' :: Process (Maybe BindSTM)
      case bind' of
        Nothing -> return Nothing
        Just s  ->
          case lookupBinding ex' (stmClient s) of
            -- First request from this client: create the binding, then go
            -- round again so the reply is sent from the stored entry.
            Nothing -> createBinding ex' s >>= \ex'' -> handleBindSTM ex'' msg'
            -- Already bound with a stream: hand out the same stream again.
            Just (Binding _ istr) ->
              sendStream (stmSend s) istr >> return (Just ex')
            -- Already bound by pid: such an entry carries no input stream
            -- (the 'inputStream' selector would fail on it), so reply with
            -- a fresh stream and leave the pid binding as it is.
            Just (PidBinding _) -> do
              istr <- newClientStream ex'
              sendStream (stmSend s) istr
              return (Just ex')

    -- Monitor the client and store a binding whose input stream reads from
    -- a duplicate of the broadcast channel. 'NoWrite' because 'apiRoute'
    -- already writes every message to the broadcast channel itself.
    createBinding :: BroadcastEx -> BindSTM -> Process BroadcastEx
    createBinding bEx' BindSTM{..} = do
      void $ P.monitor stmClient
      istr <- newClientStream bEx'
      let bnd = Binding NoWrite istr
      return $ (routingTable ^: Map.insert stmClient bnd) bEx'

    -- Duplicate the broadcast channel and wrap its read side as a stream.
    newClientStream :: BroadcastEx -> Process (InputStream Message)
    newClientStream BroadcastEx{..} = do
      nch <- liftIO $ atomically $ dupTChan channel
      return $ newInputStream $ Right (readTChan nch)

    -- Reply to a 'BindSTM' client with its input stream, wrapped in 'pCopy'.
    sendStream :: SendPort (PCopy (InputStream Message))
               -> InputStream Message
               -> Process ()
    sendStream sp' istr = unsafeSendChan sp' $ pCopy istr

    -- Bind a client through its typed channel, or refuse with 'BindFail' if
    -- the client already has a binding of any kind.
    handleBindPort :: BroadcastEx -> BindPort -> Process BroadcastEx
    handleBindPort x@BroadcastEx{..} BindPort{..} = do
      let binding = lookupBinding x portClient
      case binding of
        Just _  -> unsafeSend portClient (BindFail "DuplicateBinding") >> return x
        Nothing -> do
          let istr = Null
          let ostr = WriteChan portSend
          let bound = Binding ostr istr
          void $ P.monitor portClient
          unsafeSend portClient BindOk
          return $ (routingTable ^: Map.insert portClient bound) x

    -- Look a client up in the routing table.
    lookupBinding :: BroadcastEx -> ProcessId -> Maybe Binding
    lookupBinding BroadcastEx{..} k = Map.lookup k $ _routingTable
-- | Push one message to a client through its 'OutputStream': send it on the
-- typed channel, run the STM writer atomically, or do nothing for 'NoWrite'.
writeToStream :: OutputStream -> Message -> Process ()
writeToStream out msg =
  case out of
    WriteChan sp    -> sendChan sp msg
    WriteSTM  write -> liftIO (atomically (write msg))
    NoWrite         -> return ()
-- | A 'Match' that fires on the monitor notification belonging to the given
-- 'MonitorRef' and terminates the calling process with 'ServerDisconnected',
-- carrying the reason reported in the notification. Notifications for other
-- monitor references are not matched.
handleServerFailure :: MonitorRef -> Match (InputStream Message)
handleServerFailure mRef = matchIf isOurMonitor giveUp
  where
    isOurMonitor (ProcessMonitorNotification ref _ _) = ref == mRef
    giveUp (ProcessMonitorNotification _ _ reason) =
      die (ServerDisconnected reason)
-- | 'Accessor' for the routing table held in the exchange state.
routingTable :: Accessor BroadcastEx BroadcastClients
routingTable = accessor getTable setTable
  where
    getTable            = _routingTable
    setTable table bex  = bex { _routingTable = table }