{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ConstraintKinds #-} module Torch.FFI.TestsNN where import Foreign import Foreign.C.Types import Test.Hspec type NNNum r = (Fractional r, Show r, Eq r) data NNTestSuite state ten real accreal gen = NNTestSuite { _newWithSize1d :: state -> CLLong -> IO ten , _newWithSize2d :: state -> CLLong -> CLLong -> IO ten , _newGen :: IO gen , _normal :: Either (state -> ten -> gen -> CDouble -> CDouble -> IO ()) (state -> ten -> CDouble -> CDouble -> IO ()) , _fill :: state -> ten -> real -> IO () , _sumall :: state -> ten -> IO accreal , _free :: state -> ten -> IO () , _nnAbsUpdateOutput :: state -> ten -> ten -> IO () , _nnHSUpdateOutput :: Maybe (state -> ten -> ten -> CDouble -> IO ()) , _nnL1UpdateOutput :: state -> ten -> ten -> IO () , _nnRReLUUpdateOutput :: state -> ten -> ten -> ten -> CDouble -> CDouble -> CBool -> CBool -> gen -> IO () } testSuite :: (NNNum real, NNNum accreal) => state -> NNTestSuite state ten real accreal gen -> Spec testSuite s fs = do it "Abs test" $ do t1 <- newWithSize2d s 2 2 fill s t1 (-3) nnAbsUpdateOutput s t1 t1 sumall s t1 >>= (`shouldBe` 12.0) free s t1 case mnnHSUpdateOutput of Nothing -> pure () Just nnHSUpdateOutput -> it "HardShrink test" $ do t1 <- newWithSize2d s 2 2 t2 <- newWithSize2d s 2 2 fill s t2 4 fill s t1 4 nnHSUpdateOutput s t1 t1 100.0 sumall s t1 >>= (`shouldBe` 0.0) nnHSUpdateOutput s t2 t2 1.0 sumall s t2 >>= (`shouldBe` 16.0) free s t1 free s t2 it "L1Cost_updateOutput" $ do t1 <- newWithSize1d s 1 fill s t1 3 nnL1UpdateOutput s t1 t1 sumall s t1 >>= (`shouldBe` 3.0) free s t1 it "RReLU_updateOutput" $ do t1 <- newWithSize1d s 100 t2 <- newWithSize1d s 100 fill s t2 0.5 g <- newGen case enormal of (Left normal) -> normal s t1 g 0 1 (Right normal) -> normal s t1 0 1 nnRReLUUpdateOutput s t2 t2 t1 0.0 15.0 1 1 g sumall s t2 >>= (`shouldBe` 50.0) free s t1 free s t2 where newWithSize1d = _newWithSize1d fs newWithSize2d = _newWithSize2d fs newGen = _newGen fs enormal = _normal fs fill = _fill fs sumall = _sumall fs free = _free fs nnAbsUpdateOutput = _nnAbsUpdateOutput fs mnnHSUpdateOutput = _nnHSUpdateOutput fs nnL1UpdateOutput = _nnL1UpdateOutput fs nnRReLUUpdateOutput = _nnRReLUUpdateOutput fs