{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveGeneric #-}

module Spark.Core.Internal.ContextIOInternal(
  returnPure,
  createSparkSession,
  createSparkSession',
  executeCommand1,
  executeCommand1',
  checkDataStamps,
  updateSourceInfo,
  createComputation,
  computationStats
) where

import Control.Concurrent(threadDelay)
import Control.Lens((^.))
import Control.Monad.State(mapStateT, get)
import Control.Monad(forM, forM_)
import Data.Aeson(toJSON, FromJSON)
import Data.Functor.Identity(runIdentity)
import Data.Text(Text, pack)
import qualified Data.Text as T
import qualified Network.Wreq as W
import Network.Wreq(responseBody)
import Control.Monad.Trans(lift)
import Control.Monad.Logger(runStdoutLoggingT, LoggingT, logDebugN, logInfoN, MonadLoggerIO)
import System.Random(randomIO)
import Data.Word(Word8)
import Data.Maybe(mapMaybe)
import Control.Monad.IO.Class
import GHC.Generics
-- import Formatting
import Network.Wreq.Types(Postable)
import Data.ByteString.Lazy(ByteString)
import qualified Data.HashMap.Strict as HM
import qualified Data.HashSet as HS

import Spark.Core.Dataset
import Spark.Core.Internal.Client
import Spark.Core.Internal.ContextInternal
import Spark.Core.Internal.ContextStructures
import Spark.Core.Internal.DatasetFunctions(untypedLocalData, nodePath)
import Spark.Core.Internal.DatasetStructures(UntypedLocalData)
import Spark.Core.Internal.OpStructures(DataInputStamp(..))
import Spark.Core.Row
import Spark.Core.StructuresInternal
import Spark.Core.Try
import Spark.Core.Types
import Spark.Core.Internal.Utilities

-- | Lifts a state computation that runs over 'Identity' into 'SparkState',
-- re-wrapping the result of the pure step in the underlying monad.
returnPure :: forall a. SparkStatePure a -> SparkState a
returnPure = lift . mapStateT (pure . runIdentity)

{- | Creates a new Spark session.

This session is unique, and it will not try to reconnect to an existing
session.

When the configuration does not request a session name (empty text), a random
name is generated. The session is then registered with the server before
being returned.
-}
createSparkSession :: (MonadLoggerIO m) => SparkSessionConf -> m SparkSession
createSparkSession conf = do
  let requested = confRequestedSessionName conf
  sessionName <-
    if T.null requested
      then liftIO _randomSessionName
      else pure requested
  let session = _createSparkSession conf sessionName 0
  logDebugN $ "Creating spark session at url: " <> _sessionEndPoint session
  -- TODO get the current counter from remote
  _ensureSession session
  return session

{-| Variant of 'createSparkSession' that runs directly in 'IO', for simple
cases that do not require monad stacks. Log messages go through '_runLogger'.
-}
createSparkSession' :: SparkSessionConf -> IO SparkSession
createSparkSession' conf = _runLogger (createSparkSession conf)

{- |
Executes a command:
- performs the transforms and the optimizations in the pure state
- sends the computation to the backend
- waits for the terminal nodes to reach a final state
- commits the final results to the state

The resulting cell is then converted to the requested type; a conversion
failure is reported as an error in the returned 'Try'.

If any failure is detected that is internal to Karps, it returns an error.
If the error comes from an underlying library (http stack, programming failure),
an exception may be thrown instead.
-}
executeCommand1 :: forall a. (FromSQL a) =>
  LocalData a -> SparkState (Try a)
executeCommand1 ld =
  decode <$> executeCommand1' (untypedLocalData ld)
  where
    -- Converts the untyped cell into the requested Haskell value.
    decode tcell = tcell >>= tryEither . cellToValue

-- The main function to launch computations.
-- Each stage may fail with an error, in which case the remaining stages are
-- skipped and the error is returned (see '_ret').
executeCommand1' :: UntypedLocalData -> SparkState (Try Cell)
executeCommand1' ld = do
  logDebugN $ "executeCommand1': computing observable " <> show' ld
  -- Retrieve the computation graph
  _ret (buildComputationGraph ld) withGraph
  where
    -- Attach the source information to the graph.
    withGraph cg =
      updateSourceInfo cg >>= \cgWithSourceT -> _ret cgWithSourceT withSources
    -- Update the computations with the stamps, and build the computation.
    withSources cgWithSource =
      createComputation cgWithSource >>= \compt -> _ret compt run
    -- Run the computation.
    run comp = do
      session <- get
      _sendComputation session comp
      waitForCompletion comp

{-| Polls the backend for all the observables of a computation, then returns
the result associated with its terminal node.

An error is returned in the 'Try' (rather than going through 'missing') when
the computation does not have exactly one terminal node, and when the
collected results do not contain exactly one entry for that node.
-}
waitForCompletion :: Computation -> SparkState (Try Cell)
waitForCompletion comp = do
  -- We track all the observables, instead of simply the targets.
  let trackedNodes = getObservables comp <&> \n ->
        (nodeId n, nodePath n,
         unSQLType (nodeType n), nodePath n)
  nrs' <- _computationMultiStatus (cId comp) HS.empty trackedNodes
  -- Find the main result again in the list of everything.
  -- TODO: we actually do not need all the results, just target nodes.
  return $ case cTerminalNodeIds comp of
    [targetNid] -> case filter (\z -> fst z == targetNid) nrs' of
      [(_, tc)] -> tc
      l -> tryError $ "Expected single result, got " <> show' l
    -- TODO: handle the case of multiple terminal targets
    l -> tryError $ "waitForCompletion: missing multilist case with " <> show' l

{-| Exposed for debugging.

Logs the request, then fetches the statistics of the given computation using
the current session.
-}
computationStats ::
  ComputationID -> SparkState BatchComputationResult
computationStats cid =
  logDebugN ("computationStats: stats for " <> show' cid)
    >> get
    >>= \session -> _computationStats session cid

{-| Exposed for debugging.

Prepares a computation from a compute graph through the pure
'prepareComputation' step, run against the current state.
-}
createComputation :: ComputeGraph -> SparkState (Try Computation)
createComputation = returnPure . prepareComputation

{-| Exposed for debugging.

Looks up the stamps of the input sources read by the graph and inserts them
into the graph. A graph without sources is returned unchanged, without
contacting the server.
-}
updateSourceInfo :: ComputeGraph -> SparkState (Try ComputeGraph)
updateSourceInfo cg = case inputSourcesRead cg of
  [] -> return (pure cg)
  sources -> do
    logDebugN $ "updateSourceInfo: found sources " <> show' sources
    -- Get the source stamps. Any error at this point is considered fatal.
    stampsRet <- checkDataStamps sources
    logDebugN $ "updateSourceInfo: retrieved stamps " <> show' stampsRet
    -- The first failing stamp (if any) becomes the error of the whole update.
    return $ traverse _f stampsRet >>= insertSourceInfo cg


-- Continues with the given action when the first argument is a success, and
-- otherwise returns its error without running the action.
_ret :: Try a -> (a -> SparkState (Try b)) -> SparkState (Try b)
_ret attempt continue = either (return . Left) continue attempt

-- Moves the error (if any) of the second component outside of the pair.
_f :: (a, Try b) -> Try (a, b)
_f (x, t) = fmap (\u -> (x, u)) t

-- | One entry of the JSON list decoded by 'checkDataStamps' from the response
-- of the resource check end point.
--
-- The 'FromJSON' instance is derived generically, so the record field names
-- are the keys expected in the JSON payload: do not rename them without
-- checking the server side.
data StampReturn = StampReturn {
  -- | The path that was checked (wrapped as an 'HdfsPath' by '_parseStamp').
  stampReturnPath :: !Text,
  -- | An error message; only consulted when no stamp is present.
  stampReturnError :: !(Maybe Text),
  -- | The stamp of the resource (wrapped as a 'DataInputStamp' by
  -- '_parseStamp'). Takes precedence over the error when both are set.
  stampReturn :: !(Maybe Text)
} deriving (Eq, Show, Generic)

instance FromJSON StampReturn

{-| Given a list of paths, checks each of these paths on the file system of the
given Spark cluster to infer the status of these resources.

The primary role of this function is to check how recent these resources are
compared to some previous usage.

The paths are posted as JSON to the resource check end point of the current
session. Entries of the response that carry neither a stamp nor an error are
dropped from the result.
-}
checkDataStamps :: [HdfsPath] -> SparkState [(HdfsPath, Try DataInputStamp)]
checkDataStamps l = do
  url <- _sessionResourceCheck <$> get
  response <- liftIO $ do
    raw <- W.post (T.unpack url) (toJSON l)
    W.asJSON raw :: IO (W.Response [StampReturn])
  return . mapMaybe _parseStamp $ response ^. responseBody


-- Converts an entry returned by the server into a path and its stamp.
-- A stamp takes precedence over an error message. Entries that have neither
-- are discarded (no error being returned for now).
_parseStamp :: StampReturn -> Maybe (HdfsPath, Try DataInputStamp)
_parseStamp sr =
  let path = HdfsPath (stampReturnPath sr)
      tagged t = (path, t)
  in case stampReturn sr of
       Just s -> Just (tagged (pure (DataInputStamp s)))
       Nothing -> tagged . tryError <$> stampReturnError sr

-- Generates a session name made of the prefix "session" followed by 10
-- decimal digits, each obtained from a random byte reduced modulo 10.
_randomSessionName :: IO Text
_randomSessionName = do
  bytes <- forM [1 .. 10 :: Int] (const (randomIO :: IO Word8))
  let digits = concatMap (show . (`mod` 10)) bytes
  return (T.pack ("session" ++ digits))

-- The default logging stack: logging on top of plain IO.
type DefLogger a = LoggingT IO a

-- Runs a logging action in IO, sending the log messages to stdout.
_runLogger :: DefLogger a -> IO a
_runLogger action = runStdoutLoggingT action

-- Performs a POST request with the given payload at the given url.
_post :: (MonadIO m, Postable a) =>
  Text -> a -> m (W.Response ByteString)
_post url payload = liftIO (W.post (T.unpack url) payload)

-- Performs a GET request at the given url.
_get :: (MonadIO m) =>
  Text -> m (W.Response ByteString)
_get = liftIO . W.get . T.unpack

-- TODO move to more general utilities
-- Runs the given action repeatedly until the check function accepts its
-- result, and returns the converted value.
-- The Int is the delay in milliseconds between two consecutive polls; no
-- delay is applied before the first poll.
_pollMonad :: (MonadIO m) => m a -> Int -> (a -> Maybe b) -> m b
_pollMonad poll delayMillis check = loop
  where
    loop = poll >>= maybe retry return . check
    -- threadDelay expects microseconds.
    retry = liftIO (threadDelay (delayMillis * 1000)) >> loop


-- Builds a session value from a configuration, a session ID given as text and
-- an index. The last field of the session starts as an empty map.
_createSparkSession :: SparkSessionConf -> Text -> Integer -> SparkSession
_createSparkSession conf sessionId idx =
  SparkSession conf (LocalSessionId sessionId) idx HM.empty

-- The port found in the configuration of the session, rendered as text.
_port :: SparkSession -> Text
_port sess = pack (show (confPort (ssConf sess)))

-- The URL of the end point of a session:
-- <end point>:<port>/sessions/<session id>
_sessionEndPoint :: SparkSession -> Text
_sessionEndPoint sess =
  host <> ":" <> _port sess <> "/sessions/" <> sid
  where
    host = confEndPoint (ssConf sess)
    sid = unLocalSession (ssId sess)

-- The URL used to check the status of resources for a session:
-- <end point>:<port>/resources_status/<session id>
_sessionResourceCheck :: SparkSession -> Text
_sessionResourceCheck sess = base <> "/resources_status/" <> sid
  where
    base = confEndPoint (ssConf sess) <> ":" <> _port sess
    sid = unLocalSession (ssId sess)

-- The port of the session rendered as text. Same rendering as '_port'.
_sessionPortText :: SparkSession -> Text
_sessionPortText = _port

-- The URL of the computation end point:
-- <end point>:<port>/computations/<session id>/<computation id>
_compEndPoint :: SparkSession -> ComputationID -> Text
_compEndPoint sess compId =
  mconcat
    [ confEndPoint (ssConf sess), ":", _sessionPortText sess
    , "/computations/", unLocalSession (ssId sess)
    , "/", unComputationID compId
    ]

-- The URL of the status of a computation:
-- <end point>:<port>/computations_status/<session id>/<computation id>
_compEndPointStatus :: SparkSession -> ComputationID -> Text
_compEndPointStatus sess compId =
  server <> "/computations_status/" <> sid <> "/" <> cid
  where
    server = confEndPoint (ssConf sess) <> ":" <> _sessionPortText sess
    sid = unLocalSession (ssId sess)
    cid = unComputationID compId

-- Ensures that the server has instantiated a session with the given ID, by
-- posting to the "/create" path of the session end point.
-- The payload is a placeholder JSON value and the response is discarded.
_ensureSession :: (MonadLoggerIO m) => SparkSession -> m ()
_ensureSession session =
  () <$ _post (_sessionEndPoint session <> "/create") (toJSON 'a')


-- Sends the nodes of a computation (as JSON) to the "/create" path of the
-- computation end point. The response of the server is discarded.
_sendComputation :: (MonadLoggerIO m) => SparkSession -> Computation -> m ()
_sendComputation session comp = do
  let url = _compEndPoint session (cId comp) <> "/create"
  -- The leading space before "with" keeps the url readable in the log line.
  logInfoN $ "Sending computations at url: " <> url <> " with nodes: " <> show' (cNodes comp)
  _ <- _post url (toJSON (cNodes comp))
  return ()

-- Fetches the status of a single node of a computation.
-- The URL is the status end point of the computation followed by the pretty
-- form of the node path. The JSON body of the response is decoded as a
-- 'PossibleNodeStatus'.
_computationStatus :: (MonadLoggerIO m) =>
  SparkSession -> ComputationID -> NodePath -> m PossibleNodeStatus
_computationStatus session compId npath = do
  let url = _compEndPointStatus session compId <> prettyNodePath npath
  -- A single GET request is issued, and that same response is decoded
  -- (no extra request whose response would be thrown away).
  raw <- _get url
  status <- liftIO (W.asJSON raw :: IO (W.Response PossibleNodeStatus))
  return (status ^. responseBody)

-- Polls the status of a list of nodes until none of them is left to process,
-- and returns the results collected along the way.
--
-- Each round:
--  - drops the nodes that are already in the 'done' set,
--  - requests the status of every remaining node (one request per node),
--  - passes these statuses to 'updateCache', which returns the new results
--    and a list of (path, cache status) pairs that are only used for logging,
--  - if some nodes are still without a result, sleeps for the polling
--    interval of the session and recurses on these nodes.
--
-- The results of a round are prepended to the results of the following rounds.
--
-- TODO: not sure how this works when trying to make a fix point: is it going to
-- blow up the 'stack'?
_computationMultiStatus ::
   -- The computation being run
  ComputationID ->
  -- The set of nodes that have been processed in this computation, and ended
  -- with a success.
  -- TODO: should we do all the nodes processed in this computation?
  HS.HashSet NodeId ->
  -- The list of nodes for which we have not had completion information so far.
  -- The last component of each tuple is the path used to query the status; the
  -- other components are passed along unchanged with the status.
  [(NodeId, NodePath, DataType, NodePath)] ->
  SparkState [(NodeId, Try Cell)]
-- Nothing to track: no request is made and the state is not touched.
_computationMultiStatus _ _ [] = return []
_computationMultiStatus cid done l = do
  session <- get
  -- Find the nodes that still need processing (i.e. that have not previously
  -- finished with a success)
  let f (nid, _, _, _) = not $ HS.member nid done
  let needsProcessing = filter f l
  -- Poll a bunch of nodes to try to get a status update.
  -- The requests are run one after the other, in the order of the list.
  let statusl = _try (_computationStatus session cid) <$> needsProcessing :: [SparkState (NodeId, NodePath, DataType, PossibleNodeStatus)]
  status <- sequence statusl
  -- Update the state with the new data
  (updated, statusUpdate) <- returnPure $ updateCache cid status
  -- Log the cache status reported for each path.
  forM_ statusUpdate $ \(p, s) -> case s of
      NodeCacheSuccess ->
        logInfoN $ "_computationMultiStatus: " <> prettyNodePath p <> " finished"
      NodeCacheError ->
        logInfoN $ "_computationMultiStatus: " <> prettyNodePath p <> " finished (ERROR)"
      NodeCacheRunning ->
        logInfoN $ "_computationMultiStatus: " <> prettyNodePath p <> " running"
  -- Filter out the updated nodes, so that we do not ask for them again.
  let updatedNids = HS.union done (HS.fromList (fst <$> updated))
  let g (nid, _, _, _) = not $ HS.member nid updatedNids
  let stillNeedsProcessing = filter g needsProcessing
  -- Do not block uselessly if we have nothing else to do
  if null stillNeedsProcessing
  then return updated
  else do
    -- threadDelay expects microseconds, the configuration is in milliseconds.
    let delayMillis = confPollingIntervalMillis $ ssConf session
    _ <- liftIO $ threadDelay (delayMillis * 1000)
    -- TODO: this chaining is certainly not tail-recursive
    -- How much of a memory leak is it?
    reminder <- _computationMultiStatus cid updatedNids stillNeedsProcessing
    return $ updated ++ reminder

-- Applies a monadic function to the last component of a 4-tuple, and keeps
-- the first three components unchanged.
_try :: (Monad m) => (y -> m z) -> (x, x', x'', y) -> m (x, x', x'', z)
_try f (a, b, c, y) = do
  z <- f y
  return (a, b, c, z)

-- Fetches the statistics of a computation from its status end point, and
-- decodes the JSON body of the response.
_computationStats :: (MonadLoggerIO m) =>
  SparkSession -> ComputationID -> m BatchComputationResult
_computationStats session compId = do
  let url = _compEndPointStatus session compId <> "/" -- The final / is mandatory
  logDebugN $ "Sending computations stats request at url: " <> url
  let fetch = W.get (T.unpack url) >>= W.asJSON :: IO (W.Response BatchComputationResult)
  (^. responseBody) <$> liftIO fetch


-- Polls the status of a single node until it reaches a final state, waiting
-- for the polling interval of the session between two polls.
-- A success that carries a value gives a Right, a failure gives a Left; any
-- other status triggers another poll.
_waitSingleComputation :: (MonadLoggerIO m) =>
  SparkSession -> Computation -> NodePath -> m FinalResult
_waitSingleComputation session comp npath =
    _pollMonad currentStatus intervalMillis toFinal
  where
    currentStatus = _computationStatus session (cId comp) npath
    intervalMillis = confPollingIntervalMillis (ssConf session)
    toFinal :: PossibleNodeStatus -> Maybe FinalResult
    toFinal status = case status of
      NodeFinishedSuccess (Just s) _ -> Just (Right s)
      NodeFinishedFailure f -> Just (Left f)
      _ -> Nothing