{-# LANGUAGE OverloadedLists #-}

-- | Tests for "TensorFlow.Variable": initialization, assignment,
-- control dependencies, re-reading after an assignment, and 'assignAdd'.
module Main (main) where

import Control.Monad.IO.Class (liftIO)
import qualified Data.Vector.Storable as V
import TensorFlow.Core
    ( unScalar
    , render
    , run_
    , runSession
    , run
    , withControlDependencies
    )
import qualified TensorFlow.Ops as Ops
import TensorFlow.Variable
    ( readValue
    , initializedVariable
    , assign
    , assignAdd
    , variable
    )
import Test.Framework (defaultMain, Test)
import Test.Framework.Providers.HUnit (testCase)
import Test.HUnit ((@=?))

-- | Runs every variable test in this module.
main :: IO ()
main = defaultMain
    [ testInitializedVariable
    , testInitializedVariableShape
    , testDependency
    , testRereadRef
    , testAssignAdd
    ]

-- | A variable initialized to 42 feeds the formula @1 + readValue v@.
-- The formula must yield 43 at first, and 25 once the separately built
-- assignment of 24 has been run.
testInitializedVariable :: Test
testInitializedVariable =
    testCase "testInitializedVariable" $ runSession $ do
        (formula, reset) <- do
            v <- initializedVariable 42
            r <- assign v 24
            return (1 + readValue v, r)
        result <- run formula
        liftIO $ 43 @=? (unScalar result :: Float)
        run_ reset  -- Updates v to a different value
        -- The same formula is re-run and must now observe the new value.
        rerunResult <- run formula
        liftIO $ 25 @=? (unScalar rerunResult :: Float)

-- | A variable initialized from a one-element constant must read back as
-- the one-element vector @[42]@ rather than as a scalar.
testInitializedVariableShape :: Test
testInitializedVariableShape =
    testCase "testInitializedVariableShape" $ runSession $ do
        vector <- initializedVariable (Ops.constant [1] [42 :: Float])
        result <- run (readValue vector)
        liftIO $ [42] @=? (result :: V.Vector Float)

-- | A read rendered under a control dependency on an assignment must see
-- the assigned value: the variable is created with 'variable' (no initial
-- value is supplied), assigned 24, and @readValue v + 18@ must give 42.
testDependency :: Test
testDependency =
    testCase "testDependency" $ runSession $ do
        v <- variable []
        a <- assign v 24
        r <- withControlDependencies a $ render (readValue v + 18)
        result <- run r
        liftIO $ (42 :: Float) @=? unScalar result

-- | See https://github.com/tensorflow/haskell/issues/92.
-- Even though we're not explicitly evaluating `f0` until the end,
-- it should hold the earlier value of the variable.
testRereadRef :: Test
testRereadRef =
    -- The label matches the binding name, like every other test here, so a
    -- failure report can be traced straight back to this definition.
    testCase "testRereadRef" $ runSession $ do
        w <- initializedVariable 0
        f0 <- run (readValue w)
        run_ =<< assign w (Ops.scalar (0.1 :: Float))
        f1 <- run (readValue w)
        liftIO $ (0.0, 0.1) @=? (unScalar f0, unScalar f1)

-- | 'assignAdd' of 17 onto a variable initialized to 42 must leave the
-- variable reading @42 + 17@.
testAssignAdd :: Test
testAssignAdd =
    testCase "testAssignAdd" $ runSession $ do
        w <- initializedVariable 42
        run_ =<< assignAdd w 17
        f1 <- run (readValue w)
        liftIO $ (42 + 17 :: Float) @=? unScalar f1