{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module      : Network.IRC.Client.Internal
-- Copyright   : (c) 2016 Michael Walker
-- License     : MIT
-- Maintainer  : Michael Walker <mike@barrucadu.co.uk>
-- Stability   : experimental
-- Portability : CPP, OverloadedStrings, ScopedTypeVariables
--
-- Most of the hairy code. This isn't all internal, due to messy
-- dependencies, but I've tried to make this as \"internal\" as
-- reasonably possible.
--
-- This module is NOT considered to form part of the public interface
-- of this library.
module Network.IRC.Client.Internal
  ( module Network.IRC.Client.Internal
  , module Network.IRC.Client.Internal.Lens
  , module Network.IRC.Client.Internal.Types
  ) where

import Control.Applicative ((<$>))
import Control.Concurrent (forkIO, killThread, myThreadId, threadDelay, throwTo)
import Control.Concurrent.STM (STM, atomically, readTVar, readTVarIO, writeTVar)
import Control.Concurrent.STM.TBMChan (TBMChan, closeTBMChan, isClosedTBMChan, isEmptyTBMChan, readTBMChan, writeTBMChan, newTBMChan)
import Control.Monad (forM_, unless, void, when)
import Control.Monad.Catch (SomeException, catch)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ask, runReaderT)
import Data.ByteString (ByteString)
import Data.Conduit (ConduitM, (.|), await, awaitForever, yield)
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import qualified Data.Set as S
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8, decodeUtf8With, encodeUtf8)
import Data.Text.Encoding.Error (lenientDecode)
import Data.Time.Clock (NominalDiffTime, UTCTime, addUTCTime, diffUTCTime, getCurrentTime)
import Data.Time.Format (formatTime)
import Data.Void (Void)
import Network.IRC.Conduit (Event(..), Message(..), Source(..), floodProtector, rawMessage, toByteString)

#if MIN_VERSION_time(1,5,0)
import Data.Time.Format (defaultTimeLocale)
#else
import System.Locale (defaultTimeLocale)
#endif

import Network.IRC.Client.Internal.Lens
import Network.IRC.Client.Internal.Types
import Network.IRC.Client.Lens


-------------------------------------------------------------------------------
-- * Configuration

-- | Build a 'ConnectionConfig' for the given server, using the supplied
-- function to run the network conduits.
--
-- Defaults: user name and real name are both @"irc-client"@, there is
-- no password, '_flood' is 1 (the value handed to 'floodProtector'),
-- and '_timeout' is 300 (compared against the time since the last
-- received message).
setupInternal
  :: (IO () -> ConduitM (Either ByteString (Event ByteString)) Void IO () -> ConduitM () (Message ByteString) IO () -> IO ())
  -- ^ Function to start the network conduits. It is given the session
  -- initialisation action, the sink for data received from the server,
  -- and the source of messages to send.
  -> IRC s ()
  -- ^ Connect handler
  -> (Maybe SomeException -> IRC s ())
  -- ^ Disconnect handler
  -> (Origin -> ByteString -> IO ())
  -- ^ Logging function
  -> ByteString
  -- ^ Server hostname
  -> Int
  -- ^ Server port
  -> ConnectionConfig s
setupInternal start onConnect onDisconnect logger hostname portNum =
  ConnectionConfig
    { _server       = hostname
    , _port         = portNum
    , _func         = start
    , _onconnect    = onConnect
    , _ondisconnect = onDisconnect
    , _logfunc      = logger
    -- Identity.
    , _username     = "irc-client"
    , _realname     = "irc-client"
    , _password     = Nothing
    -- Tuning.
    , _flood        = 1
    , _timeout      = 300
    }


-------------------------------------------------------------------------------
-- * Event loop

-- | The event loop.
--
-- Registers with the server (@PASS@ if a password is configured, then
-- @USER@), runs the connect handler, and runs the connection function
-- with a sink for incoming events and a source that drains the send
-- queue. A watchdog thread throws 'Timeout' to this thread if no message
-- has been received for '_timeout'. When the connection function returns
-- or throws, the watchdog is stopped, 'disconnect' is called, and the
-- disconnect handler is run with the exception (if any).
runner :: IRC s ()
runner = do
  state <- getIRCState
  let cconf = _connectionConfig state

  -- Set the real- and user-name
  let theUser = get username cconf
  let theReal = get realname cconf
  let thePass = get password cconf

  -- Initialise the IRC session
  let initialise = flip runIRCAction state $ do
        liftIO . atomically $ writeTVar (_connectionState state) Connected
        mapM_ (\p -> sendBS $ rawMessage "PASS" [encodeUtf8 p]) thePass
        sendBS $ rawMessage "USER" [encodeUtf8 theUser, "-", "-", encodeUtf8 theReal]
        _onconnect cconf

  -- Outgoing messages are passed through the flood protector.
  antiflood <- liftIO $ floodProtector (_flood cconf)

  -- The time of the last received message, read by the watchdog below.
  lastReceived <- liftIO $ newIORef =<< getCurrentTime

  squeue <- liftIO . readTVarIO $ _sendqueue state

  let source = sourceTBMChan squeue
               .| antiflood
               .| logConduit (_logfunc cconf FromClient . toByteString)
  let sink   = forgetful
               .| logConduit (_logfunc cconf FromServer . _raw)
               .| eventSink lastReceived state

  -- Fork a watchdog thread to disconnect if the timeout elapses.
  mainTId <- liftIO myThreadId
  let time = _timeout cconf
  let timeoutThread = do
        now <- getCurrentTime
        prior <- readIORef lastReceived
        let remaining = time - diffUTCTime now prior
        -- 'threadDelay' takes MICROseconds. Sleep until the earliest
        -- moment the timeout could elapse, then re-check. Each sleep is
        -- capped at 60 seconds so the microsecond count always fits in
        -- an 'Int', even on 32-bit platforms.
        if remaining <= 0
          then throwTo mainTId Timeout
          else threadDelay (ceiling (min remaining 60 * 1000000)) >> timeoutThread
  timeoutTId <- liftIO (forkIO timeoutThread)

  -- Run the connection. The watchdog is killed on BOTH exit paths. If it
  -- outlived a failed connection, it would later throw 'Timeout' into
  -- this thread, for example into a new session started by 'reconnect'.
  (exc :: Maybe SomeException) <- liftIO $ catch
    (_func cconf initialise sink source >> killThread timeoutTId >> pure Nothing)
    (\e -> killThread timeoutTId >> pure (Just e))

  -- The connection has ended (remote close, timeout, or error): tear
  -- down and call the disconnect handler.
  disconnect
  _ondisconnect cconf exc

-- | Forget failed decodings: every 'Left' is dropped, and every 'Right'
-- payload is passed downstream unchanged.
forgetful :: Monad m => ConduitM (Either a b) b m ()
forgetful = awaitForever (either (const (return ())) yield)

-- | Block on receiving a message and invoke all matching handlers.
--
-- For each incoming event this:
--
-- * records the receipt time in the given 'IORef' (read by the timeout
--   watchdog in 'runner');
-- * decodes the event's payload from UTF-8 /leniently/. Invalid byte
--   sequences become U+FFFD instead of throwing, so one badly-encoded
--   line from the server cannot tear down the whole connection;
-- * unless 'isIgnored' says otherwise, runs every registered handler
--   whose matcher accepts the event. Handlers run synchronously, one
--   after another, in this thread.
--
-- The loop ends when the input ends or the connection state is
-- @Disconnected@ after handling an event.
eventSink :: MonadIO m => IORef UTCTime -> IRCState s -> ConduitM (Event ByteString) o m ()
eventSink lastReceived ircstate = loop where
  loop = do
    next <- await
    case next of
      Nothing  -> pure ()
      Just raw -> do
        -- Record the current time.
        liftIO $ getCurrentTime >>= writeIORef lastReceived

        -- Handle the event.
        let ev = decodeUtf8With lenientDecode <$> raw
        skip <- isIgnored ircstate ev
        unless skip . liftIO $ do
          iconf <- snapshot instanceConfig ircstate
          forM_ (get handlers iconf) $ \(EventHandler matcher handler) ->
            case matcher ev of
              Nothing -> pure ()
              Just x  -> void (runIRCAction (handler (_source ev) x) ircstate)

        -- If disconnected, do not loop.
        connState <- liftIO . atomically $ getConnectionState ircstate
        unless (connState == Disconnected) loop

-- | Check if an event is ignored or not.
--
-- The ignore list holds @(nick, Maybe channel)@ pairs. A @(nick,
-- Nothing)@ entry ignores that nick everywhere. A @(nick, Just chan)@
-- entry ignores it only in that channel. Events from the server are
-- never ignored.
isIgnored :: MonadIO m => IRCState s -> Event Text -> m Bool
isIgnored ircstate ev = do
  ignoreList <- liftIO $ _ignore <$> readTVarIO (_instanceConfig ircstate)
  let keys = case _source ev of
        User n      -> [(n, Nothing)]
        Channel c n -> [(n, Nothing), (n, Just c)]
        Server _    -> []
  pure (any (`elem` ignoreList) keys)

-- | A conduit which logs everything which goes through it: each value
-- is passed to the logging function and then yielded downstream
-- unchanged.
logConduit :: MonadIO m => (a -> IO ()) -> ConduitM a a m ()
logConduit logf = awaitForever logAndForward where
  logAndForward x = liftIO (logf x) >> yield x

-- | Print messages to stdout, with the current time.
--
-- Each line is the timestamp, an arrow (@<---@ for messages from the
-- server, @--->@ otherwise), and the message rendered with 'show' minus
-- its surrounding quotes.
stdoutLogger :: Origin -> ByteString -> IO ()
stdoutLogger origin x = do
  stamp <- formatTime defaultTimeLocale "%c" <$> getCurrentTime
  let arrow = case origin of
        FromServer -> "<---"
        _          -> "--->"
      -- 'show' on a ByteString always yields a quoted string, so
      -- stripping the first and last characters is safe here.
      body = init (tail (show x))
  putStrLn (unwords [stamp, arrow, body])

-- | Append messages to a file, with the current time.
--
-- Uses the same line format as 'stdoutLogger': timestamp, arrow (@<---@
-- for messages from the server, @--->@ otherwise), and the message
-- rendered with 'show' minus its surrounding quotes, terminated by a
-- newline.
fileLogger :: FilePath -> Origin -> ByteString -> IO ()
fileLogger fp origin x = do
  now <- getCurrentTime

  -- The newline is appended after 'unwords' rather than included as a
  -- list element, which would leave a trailing space on every line.
  appendFile fp $ unwords
    [ formatTime defaultTimeLocale "%c" now
    , if origin == FromServer then "<---" else "--->"
    , init . tail $ show x
    ] ++ "\n"

-- | Do no logging: ignore both arguments and return immediately.
noopLogger :: a -> b -> IO ()
noopLogger = \_ _ -> pure ()


-------------------------------------------------------------------------------
-- * Messaging

-- | Send a message, encoding its text as UTF-8.
--
-- This is 'sendBS' applied to the encoded message, so it enqueues the
-- message on the send queue and blocks in the same situations.
send :: Message Text -> IRC s ()
send msg = sendBS (encodeUtf8 <$> msg)

-- | Send a raw message by writing it to the client's current send queue.
--
-- The queue is a bounded 'TBMChan'. Per the stm-chans documentation, the
-- write retries (blocks) while the queue is full, and is silently
-- discarded if the queue has been closed. The sending conduit in
-- 'runner' drains the queue through the flood protector.
sendBS :: Message ByteString -> IRC s ()
sendBS msg = do
  st <- getIRCState
  liftIO . atomically $ do
    queue <- readTVar (_sendqueue st)
    writeTBMChan queue msg


-------------------------------------------------------------------------------
-- * Disconnecting

-- | Disconnect from the server.
--
-- If the client is @Connected@, this:
--
-- 1. marks it @Disconnecting@;
-- 2. waits (for at most a minute) until the send queue is empty or
--    closed;
-- 3. closes the send queue, which ends the sending conduit, and marks
--    the client @Disconnected@;
-- 4. throws 'Disconnect' to every managed thread and clears the set.
--
-- The connection itself is torn down by the connection function once
-- its conduits finish. Presumably this includes any TLS session; that
-- is not visible from here. If the client is already disconnected or
-- disconnecting, this does nothing.
disconnect :: IRC s ()
disconnect = do
  s <- getIRCState

  liftIO $ do
    -- Check and set the state in ONE transaction, so that only one of
    -- several concurrent callers performs the teardown.
    claimed <- atomically $ do
      connState <- readTVar (_connectionState s)
      case connState of
        Connected -> writeTVar (_connectionState s) Disconnecting >> pure True
        -- If already disconnected, or disconnecting, do nothing.
        _         -> pure False

    when claimed $ do
      -- Wait for all messages to be sent, or a minute has passed.
      timeoutBlock 60 . atomically $ do
        queue <- readTVar (_sendqueue s)
        (||) <$> isEmptyTBMChan queue <*> isClosedTBMChan queue

      -- Close the chan, which closes the sending conduit, and set
      -- the state to @Disconnected@.
      atomically $ do
        closeTBMChan =<< readTVar (_sendqueue s)
        writeTVar (_connectionState s) Disconnected

      -- Kill all managed threads. Don't wait for them to terminate
      -- here, as they might be masking exceptions and not pick up
      -- the 'Disconnect' for a while. The set is taken and cleared
      -- atomically first.
      tids <- atomically $ do
        running <- readTVar (_runningThreads s)
        writeTVar (_runningThreads s) S.empty
        pure running

      -- If the caller is itself a managed thread, 'throwTo' on itself
      -- raises synchronously and would abort this loop. So signal every
      -- other thread first, and the calling thread last.
      me <- myThreadId
      mapM_ (`throwTo` Disconnect) (S.delete me tids)
      when (me `S.member` tids) $ throwTo me Disconnect

-- | Disconnect from the server (this will wait for all messages to be
-- sent, or a minute to pass), and then connect again.
--
-- This can be called after the client has already disconnected, in
-- which case it will just connect again.
--
-- Like 'runClient' and 'runClientWith', this will not return until
-- the client terminates (ie, disconnects without reconnecting).
reconnect :: IRC s ()
reconnect = do
  disconnect

  -- 'disconnect' closes the old send queue, so install a fresh one
  -- (capacity 16) for the new session before starting it.
  queueVar <- _sendqueue <$> getIRCState
  liftIO . atomically $ newTBMChan 16 >>= writeTVar queueVar

  runner


-------------------------------------------------------------------------------
-- * Utils

-- | Interact with a client from the outside, by using its 'IRCState'.
--
-- Runs the 'IRC' action in 'IO' with the given state as its
-- environment, lifted into any 'MonadIO'.
runIRCAction :: MonadIO m => IRC s a -> IRCState s -> m a
runIRCAction ma st = liftIO (runReaderT (runIRC ma) st)

-- | Access the client state.
--
-- Returns the 'IRCState' that the current 'IRC' action is running
-- against, which is the reader environment obtained with 'ask'.
getIRCState :: IRC s (IRCState s)
getIRCState = ask

-- | Get the connection state from an IRC state, as an 'STM' read so it
-- can be combined with other transactional operations.
getConnectionState :: IRCState s -> STM ConnectionState
getConnectionState st = readTVar (_connectionState st)

-- | Block until an action is successful or a timeout is reached.
--
-- The check is run repeatedly until it returns 'True' or @dt@ has
-- elapsed since the call began. Between attempts the thread sleeps for
-- 10ms rather than spinning, so waiting (up to a minute in
-- 'disconnect') does not burn a CPU core.
timeoutBlock :: MonadIO m => NominalDiffTime -> IO Bool -> m ()
timeoutBlock dt check = liftIO $ do
  finish <- addUTCTime dt <$> getCurrentTime
  let wait = do
        now  <- getCurrentTime
        cond <- check
        when (now < finish && not cond) $ threadDelay pollInterval >> wait
  wait
  where
    -- Delay between checks, in microseconds (as 'threadDelay' expects).
    pollInterval = 10000

-- | A simple wrapper around a TBMChan. Each value read from the channel
-- is yielded down the conduit pipeline, and the source finishes once
-- the channel is closed and 'readTBMChan' reports no more values.
--
-- If the channel fills up, writers stall until values are read.
--
-- From stm-conduit-3.0.0 (by Clark Gaebel <cg.wowus.cg@gmail.com>)
sourceTBMChan :: MonadIO m => TBMChan a -> ConduitM () a m ()
sourceTBMChan ch = go where
  go = liftIO (atomically (readTBMChan ch))
         >>= maybe (pure ()) (\x -> yield x >> go)