{-# LANGUAGE DeriveFunctor, QuasiQuotes, Rank2Types, ScopedTypeVariables #-}
-- | A module with a function to support caching the output of your parallel tasks.
module Control.Concurrent.ParallelTasks.Cache (parMapCache) where

import Control.Applicative ((<$>), (<*>))
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TMVar (newTMVar, putTMVar, takeTMVar)
import Control.DeepSeq (NFData, force)
import Control.Exception as E(bracket_, catch, evaluate, IOException)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.ST (ST, runST)
import qualified Data.ByteString as BS
import Data.Int (Int64)
import Data.Serialize
import Data.String.Here.Interpolated (i)
import Data.Time.Clock (getCurrentTime)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector as V
import Data.Vector.Algorithms.Intro as MU
import System.IO (Handle, IOMode(..), SeekMode(..), hClose, hSeek, hTell, openFile, withFile)
import System.IO (hFlush, hPutStrLn)

import Control.Concurrent.ParallelTasks.Base (ExtendedParTaskOpts(..), ParTaskOpts(..),
  TaskOutcome(..), parallelTasks)

-- | Position of one serialised value inside the payload file: (start offset, length in bytes).
type Location = (Int64, Int64) -- start, length

-- | Common path prefix of the two cache files; see 'indexFile' and 'payloadFile'.
type CacheStem = String

-- | What happened to a single task, recorded for the statistics written to the log.
data CacheOutcome key = CacheHit | CacheMissSuccess | CacheMissTookTooLong key

-- | True only for 'CacheHit'.
isCacheHit :: CacheOutcome a -> Bool
isCacheHit CacheHit = True
isCacheHit _ = False

-- | True only for 'CacheMissSuccess'.
isCacheMissSuccess :: CacheOutcome a -> Bool
isCacheMissSuccess CacheMissSuccess = True
isCacheMissSuccess _ = False

-- Cache index file structure:
--   Int64 (number of keys)
--   Then, that many lots of:
--     key, Int64 (start, relative to payload), Int64 (length)
-- Cache payload file structure:
--   Payload

-- | Path of the index file for the given cache stem.
indexFile :: CacheStem -> FilePath
indexFile = (++ "-index")

-- | Path of the payload file for the given cache stem.
payloadFile :: CacheStem -> FilePath
payloadFile = (++ "-payload")

-- | Reads the key table from the index file.  If the file cannot be read (an 'IOException',
-- e.g. because it does not exist yet) the result is an empty table.
readKeysFromCache :: (U.Unbox key, Serialize key) => CacheStem -> IO (U.Vector (key, Location))
readKeysFromCache cacheStem =
  (readKeysFromCache' <$> BS.readFile (indexFile cacheStem))
    `E.catch` (\(_e :: IOException) -> return $ U.fromList [])

-- | Decodes the contents of an index file.  A header or table that fails to decode
-- (for instance a truncated file) is treated as an empty cache rather than an error.
readKeysFromCache' :: (U.Unbox key, Serialize key) => BS.ByteString -> U.Vector (key, Location)
readKeysFromCache' origFull =
  case runGet get count of
    -- A missing or short header is handled like a corrupt table: start with no keys.
    Left _ -> U.empty
    Right (keysAmount :: Int64) ->
      either (const U.empty) id $
        runGet (U.replicateM (fromIntegral keysAmount) getKeyLocation) table
  where
    -- The first 8 bytes hold the Int64 key count; the rest is the table itself.
    (count, table) = BS.splitAt 8 origFull
    getKeyLocation = (,) <$> get <*> ((,) <$> get <*> get)

-- | Makes keys available during run, and on exit, adds new values to cache file.
--
-- The inner action receives the previously cached keys, a function that reads a cached value
-- from the payload file, and a function that appends a new value to the payload file and
-- remembers its key.  Once the inner action returns, old and new keys are merged, sorted and
-- written to the index file.  @maxNewKeys@ is the capacity reserved for new keys.
withCache :: forall key value a. (Ord key, MU.Unbox key, Serialize key, NFData value, Serialize value)
  => CacheStem
  -> Handle
  -> Int
  -> (U.Vector (key, Location) -> (Location -> IO value) -> (key -> value -> IO ()) -> IO a)
  -> IO a
withCache cacheStem logHandle maxNewKeys inner =
  withFile (payloadFile cacheStem) ReadWriteMode $ \payloadHandle -> do
    -- Get existing keys:
    prevKeys <- readKeysFromCache cacheStem
    -- Allocate space for as many new keys as we might need (length of tasks array)
    newKeys <- MU.new maxNewKeys
    (newKeysVar, mutex) <- atomically $ (,) <$> newTMVar 0 <*> newTMVar ()
    let -- Runs an action while holding the payload-file mutex.  The mutex is released even
        -- if the action throws or the thread is killed, so other workers cannot be left
        -- blocked on it forever.
        withMutex :: IO b -> IO b
        withMutex = bracket_ (atomically $ takeTMVar mutex) (atomically $ putTMVar mutex ())

        readValue :: Location -> IO value
        readValue (start, len) = do
          -- With the mutex, seek and read from the cache payload file:
          val <- withMutex $ do
            hSeek payloadHandle AbsoluteSeek (toInteger start)
            BS.hGet payloadHandle (fromIntegral len)
          -- Force evaluation to make sure the conversion is done now, not ages down the line:
          evaluate $ either error force $ runGet get val

        writeValue :: key -> value -> IO ()
        writeValue k v = do
          -- Serialise before taking the mutex, to keep the critical section short:
          newPayload <- evaluate $ runPut (put v)
          -- With the mutex, seek and write to the cache payload file:
          withMutex $ do
            hSeek payloadHandle SeekFromEnd 0
            start <- hTell payloadHandle
            BS.hPut payloadHandle newPayload
            -- Peek at the counter without leaving it empty, so an exception at any point
            -- cannot make the final takeTMVar below block.  Holding the mutex guarantees
            -- that nobody else changes it in between.
            n <- atomically $ do
              cur <- takeTMVar newKeysVar
              putTMVar newKeysVar cur
              return cur
            MU.write newKeys n (k, (fromInteger start, fromIntegral $ BS.length newPayload))
            -- Only count the key once its table entry is in place:
            atomically $ takeTMVar newKeysVar >> putTMVar newKeysVar (succ n)
    result <- inner prevKeys readValue writeValue
    -- At the end, we get all the keys together, sort them and write them all out to the index file:
    printTime logHandle "Combining keys "
    numNewKeys <- atomically $ takeTMVar newKeysVar
    let endPrevKeys = U.length prevKeys
    joinedKeys <- flip MU.unsafeGrow numNewKeys =<< U.unsafeThaw prevKeys
    mapM_ (\n -> MU.read newKeys n >>= MU.write joinedKeys (n + endPrevKeys)) [0 .. numNewKeys - 1]
    printTime logHandle "Sorting keys "
    MU.sort joinedKeys
    frozenJoinedKeys <- U.unsafeFreeze joinedKeys
    printTime logHandle "Writing index "
    withFile (indexFile cacheStem) WriteMode $ \indexHandle -> do
      BS.hPut indexHandle $ runPut $ put (fromIntegral (U.length frozenJoinedKeys) :: Int64)
      U.mapM_ (BS.hPut indexHandle . runPut . (\(k, l) -> put k >> put (fst l) >> put (snd l)))
        frozenJoinedKeys
    return result

-- | Writes the message followed by the current time to the handle, then flushes.
printTime :: Handle -> String -> IO ()
printTime h msg = getCurrentTime >>= hPutStrLnFlush h . (msg ++) . show

-- | Binary search for a key in a vector of (key, value) pairs.  The search relies on the
-- vector being sorted by key in ascending order.
binarySearch :: (Ord key, MU.Unbox key, MU.Unbox v) => key -> U.Vector (key, v) -> Maybe v
binarySearch tgt v = go 0 (U.length v - 1)
  where
    go imin imax
      | imax < imin = Nothing
      | otherwise =
          let -- Written this way so that the midpoint cannot overflow on very large caches
              imid = imin + (imax - imin) `div` 2
              (k, x) = v U.! imid
          in case compare k tgt of
               GT -> go imin (imid - 1)
               LT -> go (imid + 1) imax
               EQ -> Just x

-- | A function that performs caching (between runs of the same tasks) to help when running the same analysis task
-- many times.
--
-- Imagine that you have a program where you want to do some map-reduce work.  The mapping takes a long time, but you
-- are working on the reduce part.  You don't want to have to redo the mapping every time you run your program;
-- you can use this cache functionality to save the results of the mapping between program runs.  Alternatively, you
-- may want to analyse only part of your data at first (for speed) then slowly expand to the rest of the data set.
-- Caching allows you to re-use the results you have already calculated.
--
-- There are three main concepts in the type signature.  @input@ is a type containing all the information needed
-- to perform the task and produce the output.  This may involve file handles or functions or whatever.  The @key@
-- type is generally smaller, and is the smallest possible unique identifier for a corresponding output.  This might
-- be the primary key of a database record, or an input filename.  (Obviously, in some cases, @input = key@; that
-- makes life easy).  The @output@ type is the output of the task.
--
-- In order to serialise the cache to a file, both @key@ and @output@ have to be instances of @Serialize@.  To allow
-- efficient unboxing of a vector, we require an @Unbox@ instance for @key@ (contact me if you think this is too onerous),
-- and to ensure strict reading from the cache we require @NFData@ for output.
--
-- Remember that @parMapCache@ doesn't know when your cache is invalid (e.g. because you've altered the processing algorithm
-- that you are passing to this function), and will blindly use it if it finds it.  It's your responsibility to remove
-- the cache when it becomes invalid.
parMapCache :: forall input output key m.
  (MonadIO m, Ord key, Show key, MU.Unbox key, NFData output, Serialize key, Serialize output)
  => ParTaskOpts m output -- ^ The parallel task options for running these tasks in parallel
  -> FilePath -- ^ The directory in which to store the cache files (\"cache-index\" and \"cache-payload\")
              -- and the log file (\"parmap-log\").  If you have multiple distinct parMapCache tasks
              -- and you don't want them overlapping, pass a different directory for each.
              -- (This is definitely a good idea, because if your two functions have an identical
              -- serialised @key@ value, you'll be in all sorts of trouble!)
  -> (input -> key) -- ^ The function to map inputs to keys
  -> (input -> m output) -- ^ The actual function to calculate an output from an input.  Note that despite
                         -- the NFData instance on output, we do not force the evaluation of output;
                         -- that is left to you to do inside this function.
  -> [input] -- ^ The list of inputs to process
  -> m (MV.IOVector output) -- ^ The vector of outputs.
parMapCache opts dir getKey process inputs = do
  -- vOutcome is for statistics, holds CacheOutcome values
  vOutcome <- liftIO $ MV.new (length inputs)
  logFile <- liftIO $ openFile (dir ++ "/parmap-log") WriteMode
  let fullOpts = (ExtendedParTaskOpts opts ($ logFile)
        -- When we write to the results array, we also write to our outcomes array:
        (\t n outcome -> case outcome of
            Success -> do
              -- Keep the outcome recorded by processWithCache, just fill in the time:
              (_, x) <- MV.read vOutcome n
              MV.write vOutcome n (t, x)
              return Nothing
            TookTooLong -> do
              let key = getKey $ inputs !! n
              MV.write vOutcome n (t, CacheMissTookTooLong key)
              return $ Just [i|*** Killed task with key ${show key} for taking too long|]
        ))
  run <- wrapWorker opts
  results <- liftIO $ withCache (dir ++ "/cache") logFile (length inputs) $
    \cachedKeys readValue saveResult -> run $ parallelTasks fullOpts
      (zipWith (processWithCache vOutcome cachedKeys readValue saveResult) [0..] inputs)
  -- Print all the statistics at the end:
  liftIO $ do
    outcomes <- V.unsafeFreeze vOutcome
    let hits = fstFilter isCacheHit outcomes
        missSuccesses = fstFilter isCacheMissSuccess outcomes
    hPutStrLn logFile [i|Complete; hits: ${V.length hits}, misses: ${V.length missSuccesses}, timed out: ${V.length outcomes - V.length hits - V.length missSuccesses}|]
    hPutStrLn logFile [i|Average cache hit time: ${average hits}|]
    hPutStrLn logFile [i|Average successful task (cache miss) time: ${average missSuccesses}|]
    hPutStrLn logFile [i|Median successful task (cache miss) time: ${median missSuccesses}|]
    hPutStrLn logFile [i|Longest successful task (cache miss) time: ${maximumV missSuccesses}|]
    hPutStrLn logFile "Details of killed tasks:"
    sequence_ [hPutStrLn logFile [i| Killed task, key: ${show k}|]
              | (_, CacheMissTookTooLong k) <- V.toList outcomes]
    hClose logFile
  return results
  where
    -- Looks for a given key in the cache.  If it finds it, reads the associated value and returns it.
    -- If it doesn't find it, calculates it and marks it for addition to the cache.
    processWithCache :: MV.IOVector (Double, CacheOutcome key)
                     -> U.Vector (key, Location)
                     -> ((Int64, Int64) -> IO output)
                     -> (key -> output -> IO ())
                     -> Int
                     -> input
                     -> m output
    processWithCache vOutcome cachedKeys readValue saveResult n x =
      case binarySearch theKey cachedKeys of
        Just resultLoc -> liftIO $ do
          MV.write vOutcome n (0, CacheHit)
          readValue resultLoc
        Nothing -> do
          result <- process x
          liftIO $ MV.write vOutcome n (0, CacheMissSuccess)
          liftIO $ saveResult theKey result
          return result
      where
        theKey = getKey x

    -- Arithmetic mean; NaN (0 / 0) for an empty vector.
    average :: V.Vector Double -> Double
    average xs = V.foldr (+) 0 xs / fromIntegral (V.length xs)

    -- Largest element, or 0 for an empty vector.
    maximumV :: V.Vector Double -> Double
    maximumV = V.foldr max 0

    -- Median of the values; NaN for an empty vector.  Sorts a copy, so the argument is untouched.
    median :: V.Vector Double -> Double
    median = median' . (\v -> runST (stSort v))
      where
        stSort :: V.Vector Double -> (forall s. ST s (V.Vector Double))
        stSort orig = do
          copy <- V.thaw orig
          MU.sort copy
          V.unsafeFreeze copy

        -- Median of an already-sorted vector.
        median' :: V.Vector Double -> Double
        median' v
          | V.null v = 0 / 0 -- NaN
          | odd len = v V.! mid
          -- Even length: mean of the two middle elements, at indices mid - 1 and mid
          | otherwise = ((v V.! (mid - 1)) + (v V.! mid)) / 2
          where
            len = V.length v
            mid = len `div` 2

    -- Keeps the first components of those pairs whose second component satisfies the predicate.
    fstFilter :: (b -> Bool) -> V.Vector (a, b) -> V.Vector a
    fstFilter f = V.map fst . V.filter (f . snd)

-- | A version of putStrLn that flushes after writing (and works in any MonadIO monad)
hPutStrLnFlush :: MonadIO m => Handle -> String -> m ()
hPutStrLnFlush h s = liftIO $ hPutStrLn h s >> hFlush h