-- Disable full-laziness to keep GHC from optimizing most of the benchmark away.
{-# OPTIONS_GHC -fno-full-laziness #-}

import Control.DeepSeq (NFData(rnf))
import Control.Exception (evaluate)
import Control.Monad.IO.Class (liftIO)
import Criterion.Main (defaultMain, bgroup, bench)
import Criterion.Types (Benchmarkable(..))
import qualified Data.Vector.Storable as S
import qualified TensorFlow.Core as TF
import qualified TensorFlow.Ops as TF

-- | Create 'Benchmarkable' for 'TF.Session'.
--
-- The entire benchmark will be run in a single tensorflow session. The
-- 'TF.Session' argument will be run once and then its result will be run N
-- times.
--
-- Every result is forced to normal form (via 'rnf') before the next
-- iteration starts, so the work of producing it is part of the measurement.
nfSession
    :: NFData b
    => TF.Session (a -> TF.Session b)  -- ^ Setup action; yields the step to time.
    -> a                               -- ^ Argument handed to the step on every iteration.
    -> Benchmarkable
nfSession setup x = Benchmarkable $ \iterations -> TF.runSession $ do
    step <- setup
    -- Can't use replicateM because the iteration count is Int64.
    let go n
            | n <= 0 = return ()
            | otherwise = do
                result <- step x
                -- 'evaluate' sequences the forcing inside the monad, so the
                -- result is fully evaluated before the loop continues.
                liftIO (evaluate (rnf result))
                go (n - 1)
    go iterations

-- | Benchmark feeding and fetching a vector.
--
-- The setup phase builds a placeholder of shape @[-1]@ and renders an
-- identity op over it. The returned step encodes the given vector as tensor
-- data with a one-dimensional shape equal to its length, feeds it to the
-- placeholder and fetches the identity output.
feedFetchBenchmark :: TF.Session (S.Vector Float -> TF.Session (S.Vector Float))
feedFetchBenchmark = do
    input <- TF.build (TF.placeholder (TF.Shape [-1]))
    output <- TF.build (TF.render (TF.identity input))
    return $ \v -> do
        let shape = TF.Shape [fromIntegral (S.length v)]
            inputData = TF.encodeTensorData shape v
            feeds = [TF.feed input inputData]
        TF.runWithFeeds feeds output

-- | Run the feed/fetch benchmark for vectors of 1, 1024 and 1024 * 1024
-- 'Float' elements.
main :: IO ()
main = defaultMain
    [ bgroup "feedFetch"
        [ bench "4 byte" $ nfSession feedFetchBenchmark (S.replicate 1 0)
        , bench "4 KiB" $ nfSession feedFetchBenchmark (S.replicate 1024 0)
        , bench "4 MiB" $ nfSession feedFetchBenchmark (S.replicate (1024 * 1024) 0)
        ]
    ]