{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RankNTypes #-}

-- | Generic hspec suites for tensor FFI bindings.
--
-- A backend supplies a 'TestFunctions' record holding its raw functions and
-- runs 'signedSuite' against it; the same expectations are then checked for
-- every backend.
module Torch.FFI.Tests where

import Control.Exception (bracket)
import Control.Monad (forM_)
import Foreign
import Foreign.C.Types
import Test.Hspec

-- | The backend functions exercised by 'signedSuite'.
--
-- @state@ is passed as the first argument of every call, @real@ is the
-- element type read and written by the set/get/fill functions, and
-- @accreal@ is the result type of the reductions ('_sumall', '_prodall',
-- '_dot').
data TestFunctions state tensor real accreal = TestFunctions
  { _new :: state -> IO tensor
    -- ^ Allocate a tensor; the suite expects it to have 0 dimensions.
  , _newWithSize1d :: state -> CLLong -> IO tensor
  , _newWithSize2d :: state -> CLLong -> CLLong -> IO tensor
  , _newWithSize3d :: state -> CLLong -> CLLong -> CLLong -> IO tensor
  , _newWithSize4d
      :: state -> CLLong -> CLLong -> CLLong -> CLLong -> IO tensor
    -- ^ The @_newWithSizeNd@ functions allocate a tensor of the given sizes.
  , _nDimension :: state -> tensor -> IO CInt
    -- ^ Number of dimensions.
  , _set1d :: state -> tensor -> CLLong -> real -> IO ()
  , _get1d :: state -> tensor -> CLLong -> IO real
  , _set2d :: state -> tensor -> CLLong -> CLLong -> real -> IO ()
  , _get2d :: state -> tensor -> CLLong -> CLLong -> IO real
  , _set3d :: state -> tensor -> CLLong -> CLLong -> CLLong -> real -> IO ()
  , _get3d :: state -> tensor -> CLLong -> CLLong -> CLLong -> IO real
  , _set4d
      :: state -> tensor -> CLLong -> CLLong -> CLLong -> CLLong -> real
      -> IO ()
  , _get4d
      :: state -> tensor -> CLLong -> CLLong -> CLLong -> CLLong -> IO real
    -- ^ The set/get functions write/read one element by 0-based indices.
  , _size :: state -> tensor -> CInt -> IO CLLong
    -- ^ Size along a 0-based dimension.
  , _fill :: state -> tensor -> real -> IO ()
    -- ^ Expected to set every element to the given value.
  , _free :: state -> tensor -> IO ()
    -- ^ Release a tensor; the suite calls it once per allocated tensor.
  , _sumall :: state -> tensor -> IO accreal
  , _prodall :: state -> tensor -> IO accreal
  , _zero :: state -> tensor -> IO ()
    -- ^ Expected to set every element to 0.
  , _dot :: Maybe (state -> tensor -> tensor -> IO accreal)
    -- ^ Dot product of two same-shaped tensors; 'Nothing' skips the
    -- dot-product tests.
  , _abs :: Maybe (state -> tensor -> tensor -> IO ())
    -- ^ Element-wise absolute value. The suite only calls it in place
    -- (@abs s t t@), so the argument order is not exercised. 'Nothing'
    -- skips the test.
  }

-- | Constraints the suites need on element/accumulator types: numeric
-- literals, plus 'Show' and 'Eq' for 'shouldBe'.
type RealConstr n = (Num n, Show n, Eq n)

-- | Full test suite for a backend whose element type is signed (the abs test
-- fills tensors with negative values).
--
-- Every tensor is allocated through 'bracket', so it is released with
-- '_free' even when an expectation fails (a failing 'shouldBe' throws).
signedSuite
  :: (RealConstr real, RealConstr accreal)
  => state
  -> TestFunctions state tensor real accreal
  -> Spec
signedSuite s fs = do
  it "initializes empty tensor with 0 dimension" $
    withTensor (newEmpty s) $ \t ->
      nDimension s t >>= (`shouldBe` 0)

  it "1D tensor has correct dimensions and sizes" $
    withTensor (newWithSize1d s 10) $ \t -> do
      nDimension s t >>= (`shouldBe` 1)
      size s t 0 >>= (`shouldBe` 10)

  it "2D tensor has correct dimensions and sizes" $
    withTensor (newWithSize2d s 10 25) $ \t -> do
      nDimension s t >>= (`shouldBe` 2)
      size s t 0 >>= (`shouldBe` 10)
      size s t 1 >>= (`shouldBe` 25)

  it "3D tensor has correct dimensions and sizes" $
    withTensor (newWithSize3d s 10 25 5) $ \t -> do
      nDimension s t >>= (`shouldBe` 3)
      size s t 0 >>= (`shouldBe` 10)
      size s t 1 >>= (`shouldBe` 25)
      size s t 2 >>= (`shouldBe` 5)

  it "4D tensor has correct dimensions and sizes" $
    withTensor (newWithSize4d s 10 25 5 62) $ \t -> do
      nDimension s t >>= (`shouldBe` 4)
      size s t 0 >>= (`shouldBe` 10)
      size s t 1 >>= (`shouldBe` 25)
      size s t 2 >>= (`shouldBe` 5)
      size s t 3 >>= (`shouldBe` 62)

  it "Can assign and retrieve correct 1D vector values" $
    withTensor (newWithSize1d s 10) $ \t -> do
      set1d s t 0 20
      set1d s t 1 1
      set1d s t 9 3
      get1d s t 0 >>= (`shouldBe` 20)
      get1d s t 1 >>= (`shouldBe` 1)
      get1d s t 9 >>= (`shouldBe` 3)

  it "Can assign and retrieve correct 2D vector values" $
    withTensor (newWithSize2d s 10 15) $ \t -> do
      set2d s t 0 0 20
      set2d s t 1 5 1
      set2d s t 9 9 3
      get2d s t 0 0 >>= (`shouldBe` 20)
      get2d s t 1 5 >>= (`shouldBe` 1)
      get2d s t 9 9 >>= (`shouldBe` 3)

  it "Can assign and retrieve correct 3D vector values" $
    withTensor (newWithSize3d s 10 15 10) $ \t -> do
      set3d s t 0 0 0 20
      set3d s t 1 5 3 1
      set3d s t 9 9 9 3
      get3d s t 0 0 0 >>= (`shouldBe` 20)
      get3d s t 1 5 3 >>= (`shouldBe` 1)
      get3d s t 9 9 9 >>= (`shouldBe` 3)

  it "Can assign and retrieve correct 4D vector values" $
    withTensor (newWithSize4d s 10 15 10 20) $ \t -> do
      set4d s t 0 0 0 0 20
      set4d s t 1 5 3 2 1
      set4d s t 9 9 9 9 3
      get4d s t 0 0 0 0 >>= (`shouldBe` 20)
      get4d s t 1 5 3 2 >>= (`shouldBe` 1)
      get4d s t 9 9 9 9 >>= (`shouldBe` 3)

  it "Can initialize values with the fill method" $
    withTensor (newWithSize2d s 2 2) $ \t -> do
      fill s t 3
      get2d s t 0 0 >>= (`shouldBe` 3)

  it "Can compute sum of all values" $
    withTensor (newWithSize3d s 2 2 4) $ \t -> do
      fill s t 2
      -- 16 elements * 2
      sumall s t >>= (`shouldBe` 32)

  it "Can compute product of all values" $
    withTensor (newWithSize2d s 2 2) $ \t -> do
      fill s t 2
      -- 2 ^ 4 elements
      prodall s t >>= (`shouldBe` 16)

  forM_ (_dot fs) $ \dot ->
    describe "tests that rely on dot products" $ dotSpec s fs dot

  forM_ (_abs fs) $ \absFn ->
    it "Can take abs of tensor values" $
      withTensor (newWithSize2d s 2 2) $ \t -> do
        fill s t (-2)
        -- 4 elements of -2 before abs, 4 elements of 2 after.
        sumall s t >>= (`shouldBe` (-8))
        absFn s t t
        sumall s t >>= (`shouldBe` 8)
  where
    -- Allocate, run the check, and always release with the backend's '_free'.
    withTensor alloc = bracket alloc (_free fs s)
    newEmpty = _new fs
    newWithSize1d = _newWithSize1d fs
    newWithSize2d = _newWithSize2d fs
    newWithSize3d = _newWithSize3d fs
    newWithSize4d = _newWithSize4d fs
    nDimension = _nDimension fs
    set1d = _set1d fs
    get1d = _get1d fs
    set2d = _set2d fs
    get2d = _get2d fs
    set3d = _set3d fs
    get3d = _get3d fs
    set4d = _set4d fs
    get4d = _get4d fs
    size = _size fs
    fill = _fill fs
    sumall = _sumall fs
    prodall = _prodall fs

-- | Tests that need a dot-product implementation; 'signedSuite' runs them
-- only when '_dot' is 'Just'. The dot function is passed separately so the
-- caller does not have to unwrap the 'Maybe' again here.
dotSpec
  :: (Num real, RealConstr accreal)
  => state
  -> TestFunctions state tensor real accreal
  -> (state -> tensor -> tensor -> IO accreal)
  -> Spec
dotSpec s fs dot = do
  it "Can compute correct dot product between 1D vectors" $
    dotOfFilled (newWithSize1d s 3) >>= (`shouldBe` 36)

  it "Can compute correct dot product between 2D tensors" $
    dotOfFilled (newWithSize2d s 2 2) >>= (`shouldBe` 48)

  it "Can zero out values" $
    withTensor (newWithSize4d s 2 2 4 3) $ \t -> do
      fill s t 3
      -- 48 elements, each contributing 3 * 3.
      dot s t t >>= (`shouldBe` 432)
      zero s t
      dot s t t >>= (`shouldBe` 0)

  it "Can compute correct dot product between 3D tensors" $
    dotOfFilled (newWithSize3d s 2 2 4) >>= (`shouldBe` 192)

  it "Can compute correct dot product between 4D tensors" $
    dotOfFilled (newWithSize4d s 2 2 2 1) >>= (`shouldBe` 96)
  where
    -- Allocate, run the check, and always release with the backend's '_free'.
    withTensor alloc = bracket alloc (_free fs s)
    -- Dot product of two tensors made by the same allocator and filled with
    -- 3 and 4, so the expected result is 12 * (number of elements). Both
    -- tensors are freed before the result is compared.
    dotOfFilled alloc =
      withTensor alloc $ \t1 ->
        withTensor alloc $ \t2 -> do
          fill s t1 3
          fill s t2 4
          dot s t1 t2
    newWithSize1d = _newWithSize1d fs
    newWithSize2d = _newWithSize2d fs
    newWithSize3d = _newWithSize3d fs
    newWithSize4d = _newWithSize4d fs
    fill = _fill fs
    zero = _zero fs