-- Copyright (c) 2014-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is distributed under the terms of a BSD license,
-- found in the LICENSE file.

{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Implementation of data-fetching operations.  Most users should
-- import "Haxl.Core" instead.
--
module Haxl.Core.Fetch
  ( dataFetch
  , dataFetchWithShow
  , uncachedRequest
  , cacheResult
  , cacheResultWithShow
  , cacheRequest
  , performFetches
  , performRequestStore
  , ShowReq
  ) where

import Control.Concurrent.STM
import Control.Exception as Exception
import Control.Monad
import Data.Either
import Data.Hashable
import Data.IORef
import Data.Int
import Data.List
import Data.Monoid
import Data.Typeable
import Data.Text (Text)
import qualified Data.Text as Text
import Text.Printf
#ifdef PROFILING
import GHC.Stack
#endif

import Haxl.Core.DataSource
import Haxl.Core.DataCache as DataCache
import Haxl.Core.Exception
import Haxl.Core.Flags
import Haxl.Core.Monad
import Haxl.Core.Profile
import Haxl.Core.RequestStore
import Haxl.Core.ShowP
import Haxl.Core.Stats
import Haxl.Core.StateStore
import Haxl.Core.Util

-- -----------------------------------------------------------------------------
-- Data fetching and caching

-- | Possible responses when checking the cache.
data CacheResult u a
  -- | The request hadn't been seen until now.
  = Uncached
      (ResultVar a)
      {-# UNPACK #-} !(IVar u a)

  -- | The request has been seen before, but its result has not yet been
  -- fetched.
  | CachedNotFetched
      {-# UNPACK #-} !(IVar u a)

  -- | The request has been seen before, and its result has already been
  -- fetched.
  | Cached (ResultVal a)

-- | Show functions for a request and its result.
type ShowReq r a = (r a -> String, a -> String)

-- Note [showFn]
--
-- Occasionally, for tracing purposes or generating exceptions, we need to
-- call 'show' on the request in a place where we *cannot* have a Show
-- dictionary.  (Because the function is a worker which is called by one of
-- the *WithShow variants that take explicit show functions via a ShowReq
-- argument.)  None of the functions that do this is exported, so this is
-- hidden from the Haxl user.

cachedWithInsert
  :: forall r a u.
     Typeable (r a)
  => (r a -> String)    -- See Note [showFn]
  -> (r a -> IVar u a -> DataCache (IVar u) -> DataCache (IVar u))
  -> Env u
  -> r a
  -> IO (CacheResult u a)
cachedWithInsert showFn insertFn Env{..} req = do
  cache <- readIORef cacheRef
  let
    doFetch = do
      ivar <- newIVar
      let !rvar = stdResultVar ivar completions
      writeIORef cacheRef $! insertFn req ivar cache
      return (Uncached rvar ivar)
  case DataCache.lookup req cache of
    Nothing -> doFetch
    Just (IVar cr) -> do
      e <- readIORef cr
      case e of
        IVarEmpty _ -> return (CachedNotFetched (IVar cr))
        IVarFull r -> do
          ifTrace flags 3 $ putStrLn $ case r of
            ThrowIO _ -> "Cached error: " ++ showFn req
            ThrowHaxl _ -> "Cached error: " ++ showFn req
            Ok _ -> "Cached request: " ++ showFn req
          return (Cached r)

-- | Make a ResultVar with the standard function for sending a CompletionReq
-- to the scheduler.
stdResultVar :: IVar u a -> TVar [CompleteReq u] -> ResultVar a
stdResultVar ivar completions =
  mkResultVar $ \r isChildThread -> do
    allocs <- if isChildThread
      then
        -- In a child thread, return the current allocation counter too,
        -- for correct tracking of allocation.
        getAllocationCounter
      else
        return 0
    atomically $ do
      cs <- readTVar completions
      writeTVar completions (CompleteReq r ivar allocs : cs)
{-# INLINE stdResultVar #-}

-- | Record the call stack for a data fetch in the Stats.  Only useful
-- when profiling.
logFetch :: Env u -> (r a -> String) -> r a -> IO ()
#ifdef PROFILING
logFetch env showFn req = do
  ifReport (flags env) 5 $ do
    stack <- currentCallStack
    modifyIORef' (statsRef env) $ \(Stats s) ->
      Stats (FetchCall (showFn req) stack : s)
#else
logFetch _ _ _ = return ()
#endif

-- | Performs actual fetching of data for a 'Request' from a 'DataSource'.
dataFetch :: (DataSource u r, Request r a) => r a -> GenHaxl u a
dataFetch = dataFetchWithInsert show DataCache.insert
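-- A minimal usage sketch for 'dataFetch'.  The request type here is
-- hypothetical ('UserReq', 'Id' and 'Name' are illustrative assumptions,
-- not part of this module), but it shows the typical shape of a
-- user-facing wrapper:
--
-- > data UserReq a where
-- >   GetName :: Id -> UserReq Name
-- >
-- > getName :: Id -> GenHaxl u Name
-- > getName ident = dataFetch (GetName ident)
--
-- Identical requests issued in the same Haxl computation are served from
-- the cache, and distinct requests blocked in the same round are batched
-- into one 'BlockedFetches' for the data source.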
-- | Performs actual fetching of data for a 'Request' from a 'DataSource',
-- using the given show functions for requests and their results.
dataFetchWithShow
  :: (DataSource u r, Eq (r a), Hashable (r a), Typeable (r a))
  => ShowReq r a
  -> r a
  -> GenHaxl u a
dataFetchWithShow (showReq, showRes) = dataFetchWithInsert showReq
  (DataCache.insertWithShow showReq showRes)

-- | Performs actual fetching of data for a 'Request' from a 'DataSource',
-- using the given function to insert requests into the cache.
dataFetchWithInsert
  :: forall u r a.
     (DataSource u r, Eq (r a), Hashable (r a), Typeable (r a))
  => (r a -> String)    -- See Note [showFn]
  -> (r a -> IVar u a -> DataCache (IVar u) -> DataCache (IVar u))
  -> r a
  -> GenHaxl u a
dataFetchWithInsert showFn insertFn req =
  GenHaxl $ \env@Env{..} -> do
    -- First, check the cache
    res <- cachedWithInsert showFn insertFn env req
    ifProfiling flags $ addProfileFetch env req
    case res of
      -- This request has not been seen before
      Uncached rvar ivar -> do
        logFetch env showFn req
        --
        -- Check whether the data source wants to submit requests
        -- eagerly, or batch them up.
        --
        case schedulerHint userEnv :: SchedulerHint r of
          SubmitImmediately -> do
            (_, ios) <- performFetches 0 env
              [BlockedFetches [BlockedFetch req rvar]]
            when (not (null ios)) $
              error "bad data source: SubmitImmediately but returns FutureFetch"
          TryToBatch ->
            -- add the request to the RequestStore and continue
            modifyIORef' reqStoreRef $ \bs ->
              addRequest (BlockedFetch req rvar) bs
        --
        return $ Blocked ivar (Cont (getIVar ivar))

      -- Seen before but not fetched yet.  We're blocked, but we don't have
      -- to add the request to the RequestStore.
      CachedNotFetched ivar -> return $ Blocked ivar (Cont (getIVar ivar))

      -- Cached: either a result, or an exception
      Cached r -> done r

-- | A data request that is not cached.  This is not what you want for
-- normal read requests, because then multiple identical requests may
-- return different results, and this invalidates some of the
-- properties that we expect Haxl computations to respect: that data
-- fetches can be arbitrarily reordered, and identical requests can be
-- commoned up, for example.
--
-- 'uncachedRequest' is useful for performing writes, provided those
-- are done in a safe way - that is, not mixed with reads that might
-- conflict in the same Haxl computation.
--
-- If we are recording or running a test, we fall back to using
-- 'dataFetch'.  This allows us to store the request in the cache when
-- recording, which allows a transparent run afterwards.  Without this,
-- the test would try to call the data source during testing and that
-- would be an exception.
uncachedRequest
  :: (DataSource u r, Request r a) => r a -> GenHaxl u a
uncachedRequest req = do
  isRecordingFlag <- env (recording . flags)
  if isRecordingFlag /= 0
    then dataFetch req
    else GenHaxl $ \Env{..} -> do
      ivar <- newIVar
      let !rvar = stdResultVar ivar completions
      modifyIORef' reqStoreRef $ \bs ->
        addRequest (BlockedFetch req rvar) bs
      return $ Blocked ivar (Cont (getIVar ivar))
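-- A hedged sketch of using 'uncachedRequest' for a write.  The request
-- type is hypothetical ('LogReq' is assumed purely for illustration):
--
-- > data LogReq a where
-- >   LogEvent :: Text -> LogReq ()
-- >
-- > logEvent :: Text -> GenHaxl u ()
-- > logEvent msg = uncachedRequest (LogEvent msg)
--
-- Because nothing is cached, two calls to @logEvent "x"@ both reach the
-- data source, which is what we want for writes but would be wrong for
-- reads.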
-- | Transparently provides caching.  Useful for data sources that can
-- return immediately, but also caches values.  Exceptions thrown by
-- the IO operation (except for asynchronous exceptions) are
-- propagated into the Haxl monad and can be caught by 'catch' and
-- 'try'.
cacheResult :: Request r a => r a -> IO a -> GenHaxl u a
cacheResult = cacheResultWithInsert show DataCache.insert

-- | Transparently provides caching in the same way as 'cacheResult', but
-- uses the given functions to show requests and their results.
cacheResultWithShow
  :: (Eq (r a), Hashable (r a), Typeable (r a))
  => ShowReq r a -> r a -> IO a -> GenHaxl u a
cacheResultWithShow (showReq, showRes) = cacheResultWithInsert showReq
  (DataCache.insertWithShow showReq showRes)

-- Transparently provides caching, using the given function to insert
-- requests into the cache.
cacheResultWithInsert
  :: Typeable (r a)
  => (r a -> String)    -- See Note [showFn]
  -> (r a -> IVar u a -> DataCache (IVar u) -> DataCache (IVar u))
  -> r a
  -> IO a
  -> GenHaxl u a
cacheResultWithInsert showFn insertFn req val = GenHaxl $ \env -> do
  let !ref = cacheRef env
  cache <- readIORef ref
  case DataCache.lookup req cache of
    Nothing -> do
      eitherResult <- Exception.try val
      case eitherResult of
        Left e -> rethrowAsyncExceptions e
        _ -> return ()
      let result = eitherToResultThrowIO eitherResult
      ivar <- newFullIVar result
      writeIORef ref $! insertFn req ivar cache
      done result
    Just (IVar cr) -> do
      e <- readIORef cr
      case e of
        IVarEmpty _ -> corruptCache
        IVarFull r -> done r
  where
    corruptCache = raise . DataSourceError $ Text.concat
      [ Text.pack (showFn req)
      , " has a corrupted cache value: these requests are meant to"
      , " return immediately without an intermediate value. Either"
      , " the cache was updated incorrectly, or you're calling"
      , " cacheResult on a query that involves a blocking fetch."
      ]

-- | Inserts a request/result pair into the cache.  Throws an exception
-- if the request has already been issued, either via 'dataFetch' or
-- 'cacheRequest'.
--
-- This can be used to pre-populate the cache when running tests, to
-- avoid going to the actual data source and to ensure that results are
-- deterministic.
--
cacheRequest
  :: Request req a => req a -> Either SomeException a -> GenHaxl u ()
cacheRequest request result = GenHaxl $ \env -> do
  cache <- readIORef (cacheRef env)
  case DataCache.lookup request cache of
    Nothing -> do
      cr <- newFullIVar (eitherToResult result)
      writeIORef (cacheRef env) $! DataCache.insert request cr cache
      return (Done ())

    -- It is an error if the request is already in the cache.  We can't
    -- test whether the cached result is the same without adding an Eq
    -- constraint, and we don't necessarily have Eq for all results.
    _other -> raise $
      DataSourceError "cacheRequest: request is already in the cache"
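-- A sketch of pre-populating the cache for a deterministic test, reusing
-- the hypothetical 'GetName' request from the 'dataFetch' example above:
--
-- > populate :: GenHaxl u ()
-- > populate = do
-- >   cacheRequest (GetName 1) (Right "Alice")
-- >   cacheRequest (GetName 2)
-- >     (Left (toException (DataSourceError "no such user")))
--
-- A later @dataFetch (GetName 1)@ in the same computation is answered
-- from the cache and never reaches the real data source; fetching
-- @GetName 2@ rethrows the stored exception.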
performRequestStore
  :: forall u. Int -> Env u -> RequestStore u -> IO (Int, [IO ()])
performRequestStore n env reqStore =
  performFetches n env (contents reqStore)

-- | Issues a batch of fetches in a 'RequestStore'.  After
-- 'performFetches', all the requests in the 'RequestStore' are
-- complete, and all of the 'ResultVar's are full.
performFetches
  :: forall u. Int -> Env u -> [BlockedFetches u] -> IO (Int, [IO ()])
performFetches n env@Env{flags=f, statsRef=sref} jobs = do
  let !n' = n + length jobs

  t0 <- getTimestamp

  let
    roundstats =
      [ (dataSourceName (Proxy :: Proxy r), length reqs)
      | BlockedFetches (reqs :: [BlockedFetch r]) <- jobs ]

  ifTrace f 1 $
    printf "Batch data fetch (%s)\n" $
      intercalate (", " :: String) $
        map (\(name, num) -> printf "%d %s" num (Text.unpack name)) roundstats

  ifTrace f 3 $
    forM_ jobs $ \(BlockedFetches reqs) ->
      forM_ reqs $ \(BlockedFetch r _) -> putStrLn (showp r)

  let
    applyFetch i (BlockedFetches (reqs :: [BlockedFetch r])) =
      case stateGet (states env) of
        Nothing ->
          return (FetchToDo reqs (SyncFetch (mapM_ (setError e))))
          where
            e :: ShowP req => req a -> DataSourceError
            e req = DataSourceError $ "data source not initialized: " <> dsName
              <> ": "
              <> Text.pack (showp req)
        Just state ->
          return
            $ FetchToDo reqs
            $ (if report f >= 2
                then wrapFetchInStats sref dsName (length reqs)
                else id)
            $ wrapFetchInTrace i (length reqs) dsName
            $ wrapFetchInCatch reqs
            $ fetch state f (userEnv env)
      where
        dsName = dataSourceName (Proxy :: Proxy r)

  fetches <- zipWithM applyFetch [n..] jobs

  waits <- scheduleFetches fetches

  t1 <- getTimestamp
  let roundtime = fromIntegral (t1 - t0) / 1000000 :: Double

  ifTrace f 1 $
    printf "Batch data fetch done (%.2fs)\n" (realToFrac roundtime :: Double)

  return (n', waits)

data FetchToDo where
  FetchToDo :: [BlockedFetch req] -> PerformFetch req -> FetchToDo

-- Catch exceptions arising from the data source and stuff them into
-- the appropriate requests.  We don't want any exceptions propagating
-- directly from the data sources, because we want the exception to be
-- thrown by dataFetch instead.
--
wrapFetchInCatch :: [BlockedFetch req] -> PerformFetch req -> PerformFetch req
wrapFetchInCatch reqs fetch =
  case fetch of
    SyncFetch f ->
      SyncFetch $ \reqs -> f reqs `Exception.catch` handler
    AsyncFetch f ->
      AsyncFetch $ \reqs io -> f reqs io `Exception.catch` handler
      -- this might be wrong: if the outer fetch action throws an
      -- exception, then we don't know whether we have executed the
      -- inner 'io' or not.  If not, then we'll likely get some errors
      -- about "did not set result var" later, because we haven't
      -- executed some data fetches.  But we can't execute 'io' in the
      -- handler, because we might have already done it.  It isn't
      -- possible to do it completely right here, so we have to rely on
      -- data sources themselves to catch (synchronous) exceptions.
      -- Async exceptions aren't a problem because we're going to
      -- rethrow them all the way to runHaxl anyway.
    FutureFetch f ->
      FutureFetch $ \reqs -> f reqs `Exception.catch`
        (\e -> handler e >> return (return ()))
    BackgroundFetch f ->
      BackgroundFetch $ \reqs -> f reqs `Exception.catch` handler
  where
    handler :: SomeException -> IO ()
    handler e = do
      rethrowAsyncExceptions e
      mapM_ (forceError e) reqs

    -- Set the exception even if the request already had a result.
    -- Otherwise we could be discarding an exception.
    forceError e (BlockedFetch _ rvar) =
      putResult rvar (except e)
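-- Note: in 'performFetches' the wrappers compose outside-in, so an
-- initialized data source effectively runs as
--
-- >   wrapFetchInStats sref dsName n
-- >     (wrapFetchInTrace i n dsName
-- >       (wrapFetchInCatch reqs (fetch state f (userEnv env))))
--
-- i.e. exceptions are caught closest to the data source, and the stats
-- measurement includes the tracing and exception-handling layers.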
wrapFetchInStats
  :: IORef Stats
  -> Text
  -> Int
  -> PerformFetch req
  -> PerformFetch req
wrapFetchInStats !statsRef dataSource batchSize perform = do
  case perform of
    SyncFetch f ->
      SyncFetch $ \reqs -> do
        fail_ref <- newIORef 0
        (t0, t, alloc, _) <- statsForIO (f (map (addFailureCount fail_ref) reqs))
        failures <- readIORef fail_ref
        updateFetchStats t0 t alloc batchSize failures
    AsyncFetch f -> do
      AsyncFetch $ \reqs inner -> do
        inner_r <- newIORef (0, 0)
        fail_ref <- newIORef 0
        let inner' = do
              (_, t, alloc, _) <- statsForIO inner
              writeIORef inner_r (t, alloc)
            reqs' = map (addFailureCount fail_ref) reqs
        (t0, totalTime, totalAlloc, _) <- statsForIO (f reqs' inner')
        (innerTime, innerAlloc) <- readIORef inner_r
        failures <- readIORef fail_ref
        -- subtract the time and allocation spent in the inner
        -- computation, so only the fetch's own cost is recorded
        updateFetchStats t0 (totalTime - innerTime) (totalAlloc - innerAlloc)
          batchSize failures
    FutureFetch submit ->
      FutureFetch $ \reqs -> do
        fail_ref <- newIORef 0
        let reqs' = map (addFailureCount fail_ref) reqs
        (t0, submitTime, submitAlloc, wait) <- statsForIO (submit reqs')
        return $ do
          (_, waitTime, waitAlloc, _) <- statsForIO wait
          failures <- readIORef fail_ref
          updateFetchStats t0 (submitTime + waitTime)
            (submitAlloc + waitAlloc) batchSize failures
    BackgroundFetch io -> do
      BackgroundFetch $ \reqs -> do
        startTime <- getTimestamp
        io (map (addTimer startTime) reqs)
  where
    statsForIO io = do
      prevAlloc <- getAllocationCounter
      (t0, t, a) <- time io
      postAlloc <- getAllocationCounter
      -- the allocation counter counts down, so the bytes allocated by
      -- 'io' is the difference prevAlloc - postAlloc
      return (t0, t, fromIntegral $ prevAlloc - postAlloc, a)

    addTimer t0 (BlockedFetch req (ResultVar fn)) =
      BlockedFetch req $ ResultVar $ \result isChildThread -> do
        t1 <- getTimestamp
        updateFetchStats t0 (t1 - t0)
          0 -- allocs: we can't measure this easily for BackgroundFetch
          1 -- batch size: we don't know if this is a batch or not
          (if isLeft result then 1 else 0) -- failures
        fn result isChildThread

    updateFetchStats :: Timestamp -> Microseconds -> Int64 -> Int -> Int -> IO ()
    updateFetchStats start time space batch failures = do
      let this = FetchStats { fetchDataSource = dataSource
                            , fetchBatchSize = batch
                            , fetchStart = start
                            , fetchDuration = time
                            , fetchSpace = space
                            , fetchFailures = failures }
      atomicModifyIORef' statsRef $ \(Stats fs) -> (Stats (this : fs), ())

    addFailureCount :: IORef Int -> BlockedFetch r -> BlockedFetch r
    addFailureCount ref (BlockedFetch req (ResultVar fn)) =
      BlockedFetch req $ ResultVar $ \result isChildThread -> do
        when (isLeft result) $ atomicModifyIORef' ref (\r -> (r + 1, ()))
        fn result isChildThread

wrapFetchInTrace
  :: Int
  -> Int
  -> Text
  -> PerformFetch req
  -> PerformFetch req
#ifdef EVENTLOG
wrapFetchInTrace i n dsName f =
  case f of
    SyncFetch io -> SyncFetch (wrapF "Sync" io)
    AsyncFetch fio -> AsyncFetch (wrapF "Async" . fio . unwrapF "Async")
  where
    d = Text.unpack dsName
    wrapF :: String -> IO a -> IO a
    wrapF ty = bracket_ (traceEventIO $ printf "START %d %s (%d %s)" i d n ty)
                        (traceEventIO $ printf "STOP %d %s (%d %s)" i d n ty)
    unwrapF :: String -> IO a -> IO a
    unwrapF ty = bracket_ (traceEventIO $ printf "STOP %d %s (%d %s)" i d n ty)
                          (traceEventIO $ printf "START %d %s (%d %s)" i d n ty)
#else
wrapFetchInTrace _ _ _ f = f
#endif

time :: IO a -> IO (Timestamp, Microseconds, a)
time io = do
  t0 <- getTimestamp
  a <- io
  t1 <- getTimestamp
  return (t0, t1 - t0, a)
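-- A tiny usage sketch of 'time'; @action@ is a hypothetical placeholder
-- for any IO computation being measured:
--
-- > (start, elapsed, r) <- time action
-- > -- start :: Timestamp, elapsed :: Microseconds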
-- | Start all the async fetches first, then perform the sync fetches
-- before getting the results of the async fetches.
scheduleFetches :: [FetchToDo] -> IO [IO ()]
scheduleFetches fetches = do
  fully_async_fetches
  waits <- future_fetches
  async_fetches sync_fetches
  return waits
  where
    fully_async_fetches :: IO ()
    fully_async_fetches =
      sequence_ [f reqs | FetchToDo reqs (BackgroundFetch f) <- fetches]

    future_fetches :: IO [IO ()]
    future_fetches =
      sequence [f reqs | FetchToDo reqs (FutureFetch f) <- fetches]

    async_fetches :: IO () -> IO ()
    async_fetches =
      compose [f reqs | FetchToDo reqs (AsyncFetch f) <- fetches]

    sync_fetches :: IO ()
    sync_fetches =
      sequence_ [f reqs | FetchToDo reqs (SyncFetch f) <- fetches]
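-- Schematically, for one fetch of each flavour, 'scheduleFetches' runs:
--
-- > bg reqs                  -- BackgroundFetch: started, never waited for
-- > wait <- future reqs      -- FutureFetch: submitted; 'wait' is returned
-- > async reqs (sync reqs)   -- sync fetches overlap the async fetch
-- > return [wait]
--
-- so the scheduler can keep running Haxl computations before forcing the
-- 'FutureFetch' results via the returned waits.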