module Data.Array.Accelerate.CUDA.Persistent (
KernelTable, KernelKey, KernelEntry(..),
new, lookup, insert, persist,
module_finalizer,
) where
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.CUDA.Context
import Data.Array.Accelerate.CUDA.FullList ( FullList )
import qualified Data.Array.Accelerate.CUDA.Debug as D
import qualified Data.Array.Accelerate.CUDA.FullList as FL
import Prelude hiding ( lookup )
import Numeric
import Data.Char
import System.IO
import System.FilePath
import System.Directory
import System.IO.Error
import System.Mem.Weak
import Control.Applicative
import Control.Concurrent
import Control.Exception
import Control.Monad.Trans
import Data.Version
import Data.Binary
import Data.Hashable
import Data.Binary.Get
import Data.ByteString ( ByteString )
import Data.ByteString.Internal ( w2c )
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.ByteString.Lazy.Internal as BL
import qualified Data.HashTable.IO as HT
import qualified Foreign.CUDA.Driver as CUDA
import Paths_accelerate_cuda
-- | Hash a compute capability by folding its major and then its minor
-- version number into the salt. (Orphan instance: neither the class nor
-- the type is defined in this module.)
instance Hashable CUDA.Compute where
  hashWithSalt salt (CUDA.Compute major minor) =
    hashWithSalt (hashWithSalt salt major) minor
-- | Serialise a compute capability as its major version followed by its
-- minor version; decoding reads them back in the same order. (Orphan
-- instance: neither the class nor the type is defined in this module.)
instance Binary CUDA.Compute where
  put (CUDA.Compute major minor) = do
    put major
    put minor
  get = do
    major <- get
    minor <- get
    return (CUDA.Compute major minor)
-- | Mutable hash table used by both kernel caches: 'HT.BasicHashTable' from
-- "Data.HashTable.IO".
type HashTable key val = HT.BasicHashTable key val
-- | The kernel table: an in-memory program cache paired with the index of
-- kernels persisted on disk. Each half is guarded by its own 'MVar'.
data KernelTable = KT !ProgramCache     -- entries known to this process ('lookup', 'insert')
                      !PersistentCache  -- keys listed in persistent.db ('restore', 'persist')
-- | Create a kernel table. It has an empty in-memory program cache, plus the
-- persistent cache restored from @persistent.db@ in the cache directory. The
-- cache directory is created first if it does not exist.
new :: IO KernelTable
new = do
  message "initialise kernel table"
  dir <- cacheDirectory
  createDirectoryIfMissing True dir
  programCache    <- HT.new >>= newMVar
  persistentCache <- restore (dir </> "persistent.db")
  return $! KT programCache persistentCache
-- | Look up a kernel, first in the in-memory program cache and then in the
-- persistent (on-disk) cache.
--
-- If the key is only in the persistent cache, the binary is read from
-- @cacheDirectory </> cacheFilePath key@ and loaded with 'CUDA.loadData'. It
-- is then wrapped in a 'KernelObject' for the given context, and inserted
-- into the program cache, so later lookups hit the first table directly.
-- Returns 'Nothing' if neither cache knows the key.
--
-- Locking: the program-cache 'MVar' is held for the whole call, and the
-- persistent-cache 'MVar' is taken inside it (lock order: program, then
-- persistent). No exception handling is done here: an exception from
-- 'BS.readFile' (e.g. the cache file was deleted) or from 'CUDA.loadData'
-- propagates to the caller.
lookup :: Context -> KernelTable -> KernelKey -> IO (Maybe KernelEntry)
lookup context (KT !kt_ref !pt_ref) !key = withMVar kt_ref $ \kt -> do
  v1 <- HT.lookup kt key
  case v1 of
    Just _ -> return v1
    Nothing -> withMVar pt_ref $ \pt -> do
      v2 <- HT.lookup pt key
      case v2 of
        Nothing -> return Nothing
        Just () -> do
          message "found/persistent"
          -- Path of the cached binary for this key.
          cubin <- (</>) <$> cacheDirectory <*> pure (cacheFilePath key)
          bin <- BS.readFile cubin
          !mdl <- CUDA.loadData bin
          let obj = KernelObject bin (FL.singleton (deviceContext context) mdl)
          -- Unload the module when it becomes unreachable.
          -- NOTE(review): addFinalizer on a value like a CUDA.Module (presumably a
          -- wrapped pointer) may fire early if GHC unboxes/copies it — confirm.
          addFinalizer mdl (module_finalizer (weakContext context) key mdl)
          HT.insert kt key obj
          return $! Just obj
-- | Add (or replace) an entry in the in-memory program cache. The write
-- happens while the program-cache lock is held. The persistent cache is left
-- untouched.
insert :: KernelTable -> KernelKey -> KernelEntry -> IO ()
insert (KT !programRef !_) !key !entry =
  withMVar programRef (\table -> HT.insert table key entry)
-- | Finaliser for a loaded module.
--
-- If the owning context is still alive, it is pushed and the module is
-- unloaded, with the context popped again even if unloading throws. If the
-- context has already been collected, only a GC debug message is emitted.
module_finalizer :: Weak CUDA.Context -> KernelKey -> CUDA.Module -> IO ()
module_finalizer weak_ctx key mdl =
  deRefWeak weak_ctx >>= maybe onDeadContext onLiveContext
  where
    name = cacheFilePath key

    onDeadContext =
      D.message D.dump_gc ("gc: finalise module/dead context: " ++ name)

    onLiveContext ctx = do
      D.message D.dump_gc ("gc: finalise module: " ++ name)
      bracket_ (CUDA.push ctx) CUDA.pop (CUDA.unload mdl)
-- | In-memory cache from kernel key to its entry in this process, guarded by
-- an 'MVar'.
type ProgramCache = MVar ( HashTable KernelKey KernelEntry )
-- | Key identifying a kernel: the target 'CUDA.Compute' capability paired
-- with a 'ByteString'. 'cacheFilePath' derives the on-disk file name from
-- both components.
-- NOTE(review): the ByteString is presumably the kernel's generated code or a
-- digest of it — confirm at the call sites.
type KernelKey = (CUDA.Compute, ByteString)
-- | An entry in the program cache.
data KernelEntry
  -- A file path plus an MVar ().
  -- NOTE(review): presumably a compilation still in progress, writing to the
  -- path and signalling the MVar when done — confirm against the compiler module.
  = CompileProcess !FilePath
                   !(MVar ())
  -- A loaded kernel: the binary image, and the module loaded from it for each
  -- context ('lookup' builds a singleton list for the current device context).
  | KernelObject !ByteString
                 !(FullList CUDA.Context CUDA.Module)
-- | Set of keys whose binaries are stored in the on-disk cache, as loaded by
-- 'restore' from persistent.db. Modelled as a table with unit values and
-- guarded by an 'MVar'.
type PersistentCache = MVar ( HashTable KernelKey () )
-- | Directory holding the persistent kernel cache. It is
-- @\<app-data>\/accelerate\/accelerate-cuda-\<version>\/cache@, so each
-- package version gets its own cache.
cacheDirectory :: IO FilePath
cacheDirectory = withVersion <$> getAppUserDataDirectory "accelerate"
  where
    withVersion home =
      home </> ("accelerate-cuda-" ++ showVersion version) </> "cache"
-- | Relative path of a kernel's cached binary inside the cache directory.
-- The path is a subdirectory named after the compute capability, then a file
-- name produced by Z-encoding the escaped key bytes.
cacheFilePath :: KernelKey -> FilePath
cacheFilePath (cap, key) = show cap </> zEncodeString escaped
  where
    -- Each byte is escaped with 'showLitChar' and prepended to the
    -- accumulator, so the escapes appear in reverse byte order. This is the
    -- established on-disk naming scheme; changing it orphans existing caches.
    escaped = BS.foldl (\acc w -> showLitChar (w2c w) acc) [] key
-- | A string produced by Z-encoding ('zEncodeString').
type EncodedString = String
-- | Z-encode a string so it can be used as a file name component.
--
-- The first character goes through 'encode_digit', so a leading digit is
-- hex-escaped and the result never starts with one. Every later character
-- goes through 'encode_ch'. Path separators such as @/@ are encoded, so the
-- result contains no directory structure.
zEncodeString :: String -> EncodedString
zEncodeString str = case str of
  []             -> []
  (first:others) -> encode_digit first ++ concatMap encode_ch others
-- | Characters that Z-encoding passes through unchanged: alphanumerics other
-- than @z@ and @Z@, which are the escape prefixes themselves.
unencodedChar :: Char -> Bool
unencodedChar c = c `notElem` "zZ" && isAlphaNum c
-- | Encode the first character of a Z-encoded string. Digits are always
-- hex-escaped, so the encoding never begins with a digit. Any other character
-- is encoded as usual by 'encode_ch'.
encode_digit :: Char -> EncodedString
encode_digit c =
  if isDigit c
    then encode_as_unicode_char c
    else encode_ch c
-- | Z-encode one character in a non-leading position.
--
-- Characters accepted by 'unencodedChar' pass through unchanged. A fixed set
-- of punctuation gets a two-letter code: @Z@-prefixed for brackets, colon and
-- @Z@ itself, and @z@-prefixed for the rest. Anything else becomes a
-- hexadecimal escape via 'encode_as_unicode_char'.
encode_ch :: Char -> EncodedString
encode_ch c
  | unencodedChar c = [c]
  | otherwise = case c of
      '('  -> "ZL"
      ')'  -> "ZR"
      '['  -> "ZM"
      ']'  -> "ZN"
      ':'  -> "ZC"
      'Z'  -> "ZZ"
      'z'  -> "zz"
      '&'  -> "za"
      '|'  -> "zb"
      '^'  -> "zc"
      '$'  -> "zd"
      '='  -> "ze"
      '>'  -> "zg"
      '#'  -> "zh"
      '.'  -> "zi"
      '<'  -> "zl"
      '-'  -> "zm"
      '!'  -> "zn"
      '+'  -> "zp"
      '\'' -> "zq"
      '\\' -> "zr"
      '/'  -> "zs"
      '*'  -> "zt"
      '_'  -> "zu"
      '%'  -> "zv"
      _    -> encode_as_unicode_char c
-- | Escape a character as @z@, then its code point in hexadecimal, then a
-- terminating @U@. A @0@ is inserted when the hex digits would start with a
-- letter, so the escape body always begins with a decimal digit.
encode_as_unicode_char :: Char -> EncodedString
encode_as_unicode_char c = 'z' : padded
  where
    hex = showHex (ord c) "U"
    -- showHex always yields at least one digit, so the fallback only ever
    -- sees a hex string starting with a letter.
    padded = case hex of
      d:_ | isDigit d -> hex
      _               -> '0' : hex
-- | Decode @n@ consecutive values.
--
-- Values are accumulated by prepending, so the result lists them in the
-- reverse of the order in which they were decoded. (In this module the only
-- caller, 'restore', inserts them into a hash table, so order is irrelevant.)
-- A zero or negative count decodes nothing and returns an empty list.
getMany :: Binary a => Int -> Get [a]
getMany n = go n []
  where
    go i xs
      | i <= 0    = return xs
      | otherwise = do
          x <- get
          go (i - 1) (x:xs)
-- | Load the set of persisted kernel keys from the database file @db@.
--
-- The database holds an 'Int' entry count followed by that many encoded
-- 'KernelKey's. If the file does not exist, a database with count 0 is
-- written and an empty table is returned. When the @flush_cache@ debug flag
-- is set, the whole cache directory is deleted and recreated first.
--
-- The file is read strictly. With a lazy read, the handle stays semi-closed
-- after the last key is decoded (the trailing end-of-file chunk is never
-- demanded) until the GC finalises it. GHC's per-file reader/writer locking
-- can then make the later 'ReadWriteMode' open in 'persist' fail with
-- "resource busy (file is locked)".
restore :: FilePath -> IO PersistentCache
restore !db = do
  D.when D.flush_cache $ do
    message "deleting persistent cache"
    cacheDir <- cacheDirectory
    removeDirectoryRecursive cacheDir
    createDirectoryIfMissing True cacheDir
  exists <- doesFileExist db
  pt <- case exists of
    False -> encodeFile db (0::Int) >> HT.new
    True  -> do
      -- Read eagerly so the file handle is closed before we return.
      contents <- BS.readFile db
      let store     = BL.fromChunks [contents]
          (n, rest) = setup (runGetIncremental get) store
          -- Feed chunks to the decoder until the entry count is decoded,
          -- returning it together with the unconsumed input. 'BL.chunk'
          -- (rather than the raw 'BL.Chunk') drops an empty leftover chunk,
          -- preserving the lazy ByteString non-empty-chunk invariant.
          setup (Done s _ r) lbs = (r, BL.chunk s lbs)
          setup (Partial k)  lbs = setup (k (takeHeadChunk lbs)) (dropHeadChunk lbs)
          setup (Fail _ p msg) _ = $internalError "restore" $ show p ++ ": " ++ msg
          takeHeadChunk (BL.Chunk h _) = Just h
          takeHeadChunk _              = Nothing
          dropHeadChunk (BL.Chunk _ t) = t
          dropHeadChunk _              = BL.empty
      pt <- HT.newSized n
      let go []      = return ()
          go (!k:xs) = HT.insert pt k () >> go xs
      message $ "persist/restore: " ++ shows n " entries"
      go (runGet (getMany n) rest)
      evaluate pt
  newMVar pt
-- | Record a freshly compiled kernel in the on-disk cache.
--
-- The binary at @cubin@ is moved into the cache directory at
-- 'cacheFilePath' @key@. The key is then appended to @persistent.db@, and the
-- entry count at the start of that file is incremented. The
-- persistent-cache 'MVar' is held throughout, which serialises these updates
-- (and excludes 'lookup''s persistent check) within this process. The
-- in-memory persistent table itself is not modified here.
persist :: KernelTable -> FilePath -> KernelKey -> IO ()
persist (KT !_ !pt_ref) !cubin !key = withMVar pt_ref $ \_ -> do
  cacheDir <- cacheDirectory
  let db = cacheDir </> "persistent.db"
      cacheFile = cacheDir </> cacheFilePath key
  message $ "persist/save: " ++ cacheFile
  createDirectoryIfMissing True (dropFileName cacheFile)
  -- If the rename fails for any IO reason (presumably e.g. source and cache on
  -- different filesystems), fall back to copy-then-delete.
  renameFile cubin cacheFile
    `catchIOError` \_ -> do
      copyFile cubin cacheFile
      removeFile cubin
  withBinaryFile db ReadWriteMode $ \h -> do
    -- Read the entry count from the first 8 bytes. This assumes 'encode' of
    -- an 'Int' is 8 bytes wide, since the count is rewritten in place below.
    n <- runGet (get :: Get Int) `fmap` BL.hGet h 8
    hSeek h AbsoluteSeek 0
    BL.hPut h (encode (n+1))
    -- Append the new key after all existing entries.
    hSeek h SeekFromEnd 0
    BL.hPut h (encode key)
-- | Emit a debug message for this module. It goes through 'trace', so it is
-- shown under the @dump_cc@ flag with a @cc: @ prefix.
message :: MonadIO m => String -> m ()
message = flip trace (return ())
-- | Emit @msg@ (prefixed with @cc: @) under the @dump_cc@ debug flag, then
-- run @next@ and return its result.
trace :: MonadIO m => String -> m a -> m a
trace msg next = do
  D.message D.dump_cc ("cc: " ++ msg)
  next