{-# LANGUAGE DeriveFunctor, QuasiQuotes, Rank2Types, ScopedTypeVariables #-}

-- | A module with a function to support caching the output of your parallel tasks.
module Control.Concurrent.ParallelTasks.Cache (parMapCache) where

import Control.Applicative ((<$>), (<*>))
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TMVar (newTMVar, putTMVar, readTMVar, takeTMVar)
import Control.DeepSeq (NFData, force)
import Control.Exception as E(bracket_, catch, evaluate, IOException)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.ST (ST, runST)
import qualified Data.ByteString as BS
import Data.Int (Int64)
import Data.Serialize
import Data.String.Here.Interpolated (i)
import Data.Time.Clock (getCurrentTime)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector as V
import Data.Vector.Algorithms.Intro as MU
import System.IO (Handle, IOMode(..), SeekMode(..), hClose, hSeek, hTell, openFile, withFile)
import System.IO (hFlush, hPutStrLn)

import Control.Concurrent.ParallelTasks.Base (ExtendedParTaskOpts(..), ParTaskOpts(..), TaskOutcome(..), parallelTasks)

-- | Position of one serialised value inside the payload file:
-- (byte offset of its first byte, number of bytes).
type Location = (Int64, Int64)

-- | Path prefix shared by a cache's index file and payload file.
type CacheStem = String

-- | How a single task was satisfied; used only for the end-of-run statistics.
data CacheOutcome key
  = CacheHit
  | CacheMissSuccess
  | CacheMissTookTooLong key

-- | 'True' exactly for 'CacheHit'.
isCacheHit :: CacheOutcome a -> Bool
isCacheHit outcome = case outcome of
  CacheHit -> True
  _        -> False

-- | 'True' exactly for 'CacheMissSuccess'.
isCacheMissSuccess :: CacheOutcome a -> Bool
isCacheMissSuccess outcome = case outcome of
  CacheMissSuccess -> True
  _                -> False

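-- Cache index file structure:
-- Int64 (number of keys)
-- Then, that many entries of:
--   key, Int64 (start: byte offset into the payload file), Int64 (length in bytes)
-- Entries are written sorted by key.
-- Cache payload file structure:
-- The serialised values, concatenated in the order they were written.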

-- | Path of the index file (keys and payload locations) for a cache stem.
indexFile :: CacheStem -> FilePath
indexFile stem = stem ++ "-index"

-- | Path of the payload file (serialised values) for a cache stem.
payloadFile :: CacheStem -> FilePath
payloadFile stem = stem ++ "-payload"

-- | Load and decode the key index for a cache stem. If reading the index file
-- throws an 'IOException' (for example because it does not exist yet), the
-- cache is treated as empty.
readKeysFromCache :: (U.Unbox key, Serialize key) => CacheStem -> IO (U.Vector (key, Location))
readKeysFromCache cacheStem =
  E.catch (readKeysFromCache' <$> BS.readFile path)
          (\(_ :: IOException) -> return U.empty)
  where
    path = indexFile cacheStem

-- | Decode the raw contents of an index file into @(key, (start, length))@
-- entries. The expected layout is an 'Int64' entry count followed by that many
-- (key, start, length) records.
--
-- Any decoding failure yields an empty vector, so a damaged cache behaves like
-- a cold cache. This covers an empty or truncated file as well as a bad table.
-- Previously a short count header called 'error'. Because this decoding runs
-- lazily, that error escaped the 'IOException' handler in 'readKeysFromCache'
-- and crashed the run.
readKeysFromCache' :: (U.Unbox key, Serialize key) => BS.ByteString -> U.Vector (key, Location)
readKeysFromCache' origFull = either (const U.empty) id $ runGet getIndex origFull
  where
    -- Count first, then exactly that many entries; decoded in a single Get so
    -- that a failure in either part is handled the same way.
    getIndex = do
      keysAmount <- get :: Get Int64
      U.replicateM (fromIntegral keysAmount) getKeyLocation
    getKeyLocation = (,) <$> get <*> ((,) <$> get <*> get)

-- | Run an action with access to the on-disk cache stored under @cacheStem@,
-- then merge any new entries into the index file.
--
-- The action receives:
--
--   * the entries already in the index file (the index is written sorted by key);
--   * a function that reads and decodes the value stored at a 'Location'; and
--   * a function that appends a new key/value pair to the payload file.
--
-- When the action returns, the old entries and the new ones are combined,
-- sorted and written out as the new index file. At most @maxNewKeys@ values may
-- be written during one call, because one slot per new key is preallocated.
-- Progress messages with timestamps go to @logHandle@.
--
-- All access to the shared payload handle happens under a lock. The lock is
-- taken with 'bracket_', so it is always released, even if a thread is
-- interrupted while reading or writing. Otherwise one killed thread would leave
-- the lock held and block every other task.
withCache :: forall key value a. (Ord key, MU.Unbox key, Serialize key, NFData value, Serialize value) =>
             CacheStem -> Handle -> Int -> (U.Vector (key, Location) -> (Location -> IO value) -> (key -> value -> IO ()) -> IO a) -> IO a
withCache cacheStem logHandle maxNewKeys inner = withFile (payloadFile cacheStem) ReadWriteMode $ \payloadHandle -> do
  -- Get existing keys:
  prevKeys <- readKeysFromCache cacheStem
  -- Allocate space for as many new keys as we might need (length of tasks array)
  newKeys <- MU.new maxNewKeys
  -- newKeysVar counts the filled slots of newKeys; mutex is the payload lock.
  -- The counter and newKeys are only modified while the lock is held.
  (newKeysVar, mutex) <- atomically $ (,) <$> newTMVar 0 <*> newTMVar ()
  let withLock :: IO r -> IO r
      withLock = bracket_ (atomically $ takeTMVar mutex) (atomically $ putTMVar mutex ())

      readValue :: Location -> IO value
      readValue (start, len) = do
        -- With the lock held, seek and read from the cache payload file:
        val <- withLock $ do
          hSeek payloadHandle AbsoluteSeek (toInteger start)
          BS.hGet payloadHandle (fromIntegral len)
        -- Force evaluation to make sure the conversion is done now, not ages down the line:
        evaluate $ either error force $ runGet get val

      writeValue :: key -> value -> IO ()
      writeValue k v = do
        -- Serialise before taking the lock so other tasks are not held up by it.
        newPayload <- evaluate $ runPut (put v)
        withLock $ do
          -- Append the payload to the end of the payload file:
          hSeek payloadHandle SeekFromEnd 0
          start <- hTell payloadHandle
          BS.hPut payloadHandle newPayload
          -- Fill the next free slot, and only then bump the counter, so an
          -- interrupted write never leaves an unfilled slot counted as used.
          n <- atomically $ readTMVar newKeysVar
          MU.write newKeys n (k, (fromInteger start, fromIntegral $ BS.length newPayload))
          atomically $ takeTMVar newKeysVar >>= putTMVar newKeysVar . succ
  result <- inner prevKeys readValue writeValue
  -- At the end, we get all the keys together, sort them and write them all out to the index file:
  printTime logHandle "Combining keys "
  numNewKeys <- atomically $ takeTMVar newKeysVar
  let endPrevKeys = U.length prevKeys
  joinedKeys <- flip MU.unsafeGrow numNewKeys =<< U.unsafeThaw prevKeys
  mapM_ (\n -> MU.read newKeys n >>= MU.write joinedKeys (n + endPrevKeys)) [0 .. numNewKeys - 1]
  printTime logHandle "Sorting keys "
  MU.sort joinedKeys
  frozenJoinedKeys <- U.unsafeFreeze joinedKeys
  printTime logHandle "Writing index "
  withFile (indexFile cacheStem) WriteMode $ \indexHandle -> do
    BS.hPut indexHandle $ runPut $ put (fromIntegral (U.length frozenJoinedKeys) :: Int64)
    U.mapM_ (BS.hPut indexHandle . runPut . (\(k, l) -> put k >> put (fst l) >> put (snd l))) frozenJoinedKeys
  return result
  
-- | Write @msg@ immediately followed by the current time to the handle, then
-- flush it.
printTime :: Handle -> String -> IO ()
printTime h msg = do
  now <- getCurrentTime
  hPutStrLnFlush h (msg ++ show now)


-- | Find @tgt@ in a vector of @(key, value)@ pairs and return the value paired
-- with it, or 'Nothing' if the key is absent.
--
-- The vector must be sorted by key in ascending order; on an unsorted vector
-- the result is unspecified. If the key occurs more than once, any one of its
-- values may be returned.
binarySearch :: (Ord key, MU.Unbox key, MU.Unbox v) => key -> U.Vector (key, v) -> Maybe v
binarySearch tgt v = go 0 (U.length v - 1)
  where
    -- Search the inclusive index range [lo, hi].
    go lo hi
      | hi < lo = Nothing
      | otherwise =
          -- lo + (hi - lo) `div` 2 cannot overflow, unlike (lo + hi) `div` 2.
          let mid = lo + (hi - lo) `div` 2
              (k, x) = v U.! mid
          in case compare k tgt of
               GT -> go lo (mid - 1)
               LT -> go (mid + 1) hi
               EQ -> Just x

-- | A function that performs caching (between runs of the same tasks) to help when running the same analysis task
-- many times.
--
-- Imagine that you have a program where you want to do some map-reduce work.  The mapping takes a long time, but you
-- are working on the reduce part.  You don't want to have to redo the mapping every time you run your program;
-- you can use this cache functionality to save the results of the mapping between program runs.  Alternatively, you
-- may want to analyse only part of your data at first (for speed) then slowly expand to the rest of the data set.
-- Caching allows you to re-use the results you have already calculated.
--
-- There are three main concepts in the type signature.  @input@ is a type containing all the information needed
-- to perform the task and produce the output.  This may involve file handles or functions or whatever.  The @key@
-- type is generally smaller, and is the smallest possible unique identifier for a corresponding output.  This might
-- be the primary key of a database record, or an input filename.  (Obviously, in some cases, @input = key@; that
-- makes life easy).  The @output@ type is the output of the task.
--
-- In order to serialise the cache to a file, both @key@ and @output@ have to be instances of @Serialize@.  To allow
-- efficient unboxing of a vector, we require an @Unbox@ instance for @key@ (contact me if you think this is too onerous),
-- and to ensure strict reading from the cache we require @NFData@ for output.
--
-- Remember that @parMapCache@ doesn't know when your cache is invalid (e.g. because you've altered the processing algorithm
-- that you are passing to this function), and will blindly use it if it finds it.  It's your responsibility to remove
-- the cache when it becomes invalid. 
parMapCache :: forall input output key m. (MonadIO m, Ord key, Show key, MU.Unbox key, NFData output, Serialize key, Serialize output) =>
  ParTaskOpts m output
  -- ^ The parallel task options for running these tasks in parallel
  -> FilePath
  -- ^ The directory in which to store the cache files (\"cache-index\" and \"cache-payload\")
  -- and the log file (\"parmap-log\").  If you have multiple distinct parMapCache tasks
  -- and you don't want them overlapping, pass a different directory for each.
  -- (This is definitely a good idea, because if your two functions have an identical
  -- serialised @key@ value, you'll be in all sorts of trouble!)
  -> (input -> key)
  -- ^ The function to map inputs to keys
  -> (input -> m output)
  -- ^ The actual function to calculate an output from an input.  Note that despite
  -- the NFData instance on output, we do not force the evaluation of output;
  -- that is left to you to do inside this function.
  -- (On a cache miss, the output is serialised for the cache straight after this
  -- function returns. That evaluates as much of it as its @Serialize@ instance
  -- demands. Outputs read back from the cache are fully forced with 'force'.)
  -> [input]
  -- ^ The list of inputs to process
  -> m (MV.IOVector output)
  -- ^ The vector of outputs.
parMapCache opts dir getKey process inputs
       -- vOutcome is for statistics, holds CacheOutcome values
  = do vOutcome <- liftIO $ MV.new (length inputs)
       logFile <- liftIO $ openFile (dir ++ "/parmap-log") WriteMode
       -- NOTE(review): logFile is only closed by the hClose at the end of the
       -- statistics block below, so it stays open if an exception escapes first.
       let fullOpts = (ExtendedParTaskOpts opts
             -- NOTE(review): presumably hands the log handle to parallelTasks;
             -- confirm against the ExtendedParTaskOpts definition.
             ($ logFile)
             -- When we write to the results array, we also write to our outcomes array:
             -- the callback receives t (stored as the task's time in vOutcome), the task
             -- index n and the TaskOutcome. On Success it keeps the CacheOutcome that
             -- processWithCache already stored and replaces the placeholder time with t.
             (\t n outcome -> case outcome of
                 Success -> do (_, x) <- MV.read vOutcome n
                               MV.write vOutcome n (t, x)
                               return Nothing
                 -- On TookTooLong it records the key and returns a message (presumably
                 -- logged by parallelTasks). NOTE(review): `inputs !! n` is O(n) on a list.
                 TookTooLong -> do let key = getKey $ inputs !! n
                                   MV.write vOutcome n (t, CacheMissTookTooLong key)
                                   return $ Just [i|*** Killed task with key ${show key} for taking too long|]
             ))
       -- withCache's callback runs in IO, so `run` must turn the result of
       -- parallelTasks into an IO action.
       run <- wrapWorker opts
       -- All tasks run inside withCache; newly computed outputs are added to the
       -- on-disk index once the callback returns.
       results <- liftIO $ withCache (dir ++ "/cache") logFile (length inputs) $
                            \cachedKeys readValue saveResult -> run $
                               parallelTasks fullOpts (zipWith (processWithCache vOutcome cachedKeys readValue saveResult) [0..] inputs)
       -- Print all the statistics at the end:
       -- ("timed out" is every task that is neither a cache hit nor a successful miss.)
       liftIO $ do
         outcomes <- V.unsafeFreeze vOutcome
         let hits = fstFilter isCacheHit outcomes
             missSuccesses =  fstFilter isCacheMissSuccess outcomes
         hPutStrLn logFile [i|Complete; hits: ${V.length hits}, misses: ${V.length missSuccesses}, timed out: ${V.length outcomes - V.length hits - V.length missSuccesses}|]
         hPutStrLn logFile [i|Average cache hit time: ${average hits}|]
         hPutStrLn logFile [i|Average successful task (cache miss) time: ${average missSuccesses}|]
         hPutStrLn logFile [i|Median successful task (cache miss) time: ${median missSuccesses}|]
         hPutStrLn logFile [i|Longest successful task (cache miss) time: ${maximumV missSuccesses}|]
         hPutStrLn logFile "Details of killed tasks:"
         sequence_ [hPutStrLn logFile [i|  Killed task, key: ${show k}|] | (_, CacheMissTookTooLong k) <- V.toList outcomes]
         hClose logFile
       return results
  where
    -- Looks for a given key in the cache.  If it finds it, reads the associated value and returns it.
    -- If it doesn't find it, calculates it and marks it for addition to the cache.
    -- The time 0 written to vOutcome here is a placeholder; the Success callback in
    -- fullOpts overwrites it with the t it is given.
    processWithCache :: MV.IOVector (Double, CacheOutcome key) -> U.Vector (key, Location) -> ((Int64, Int64) -> IO output) -> (key -> output -> IO ()) -> Int -> input -> m output
    processWithCache vOutcome cachedKeys readValue saveResult n x = case binarySearch theKey cachedKeys of
      Just resultLoc -> liftIO $ do MV.write vOutcome n (0, CacheHit)
                                    readValue resultLoc
      Nothing -> do result <- process x
                    liftIO $ MV.write vOutcome n (0, CacheMissSuccess)
                    liftIO $ saveResult theKey result
                    return result
      where
        theKey = getKey x

-- | Arithmetic mean of the vector. An empty vector gives NaN (0 / 0).
average :: V.Vector Double -> Double
average xs = total / count
  where
    -- Right fold, so the summation order matches the original implementation.
    total = V.foldr (+) 0 xs
    count = fromIntegral (V.length xs)

-- | Largest element of the vector, with 0 as the starting value: an empty
-- vector gives 0, and so does a vector whose elements are all negative.
maximumV :: V.Vector Double -> Double
maximumV times = V.foldr larger 0 times
  where
    larger t best = max t best

-- | Median of the vector: the middle element for an odd length, the mean of
-- the two middle elements for an even length, and NaN for an empty vector.
-- The argument is left untouched; a sorted copy is used.
median :: V.Vector Double -> Double
median = median' . (\v -> runST (stSort v))
  where
    -- Sort a thawed copy of the input inside ST and freeze the result.
    stSort :: V.Vector Double -> (forall s. ST s (V.Vector Double))
    stSort orig = do
      copy <- V.thaw orig
      MU.sort copy
      V.unsafeFreeze copy

    -- Median of an already-sorted vector.
    median' :: V.Vector Double -> Double
    median' v
      | V.null v = 0 / 0 -- NaN (1 / 0 would be +Infinity)
      | odd len = v V.! half
      -- Even length: the two middle elements sit at indices half - 1 and half.
      -- Using half and half + 1 picks the wrong pair, and it indexes past the
      -- end of a two-element vector.
      | otherwise = (v V.! (half - 1) + v V.! half) / 2
      where
        len = V.length v
        half = len `div` 2

-- | Keep the pairs whose second component satisfies the predicate and return
-- their first components, in their original order.
fstFilter :: (b -> Bool) -> V.Vector (a, b) -> V.Vector a
fstFilter keep = V.map fst . V.filter (keep . snd)
            
-- | Like 'hPutStrLn', but flushes the handle afterwards so the line is
-- written out immediately. Usable from any 'MonadIO' monad.
hPutStrLnFlush :: MonadIO m => Handle -> String -> m ()
hPutStrLnFlush h s = liftIO $ do
  hPutStrLn h s
  hFlush h