{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Test.NoFib.Spectral.SMVM (
test_smvm,
) where
import Data.Proxy
import Data.Typeable
import Prelude as P
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Sugar as S
import Data.Array.Accelerate.Test.NoFib.Base
import Data.Array.Accelerate.Test.NoFib.Config
import Data.Array.Accelerate.Test.Similar
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog
-- | The "smvm" test group: one property per element generator ('f16', 'f32'
-- and 'f64'), each wrapped by 'at' with the corresponding proxy.
test_smvm :: RunN -> TestTree
test_smvm runN = testGroup "smvm" perType
  where
    perType =
      [ at (Proxy :: Proxy TestHalf)   (eltTest f16)
      , at (Proxy :: Proxy TestFloat)  (eltTest f32)
      , at (Proxy :: Proxy TestDouble) (eltTest f64)
      ]

    -- Build the property for a single element type. The property is named
    -- after the shown type of the elements produced by the generator.
    eltTest
        :: forall a. (P.Num a, P.Ord a, A.Num a, A.Ord a, Similar a)
        => Gen a
        -> TestTree
    eltTest gen = testProperty tyName (test_smvm' runN gen)
      where
        tyName = show (typeOf (undefined :: a))
-- | Property: applying 'smvm' through the supplied runner to a generated
-- sparse matrix and a generated vector is similar (per '~~~') to the result
-- of the reference implementation 'smvmRef' on the same inputs.
--
-- The vector is generated with as many elements as the column count returned
-- by 'sparseMatrix'.
test_smvm' :: (A.Num e, P.Num e, Similar e) => RunN -> Gen e -> Property
test_smvm' runN e = property $ do
  (matrix, width) <- forAll (sparseMatrix e)
  xs              <- forAll (array (Z :. width) e)
  -- The bang pattern forces the runner-applied function before it is used.
  let !multiply = runN smvm
  multiply matrix xs ~~~ smvmRef matrix xs
-- | Generate a sparse matrix together with its column count.
--
-- Row and column counts are each drawn from 1..256. The segment array has one
-- entry per row, each drawn from 0..cols; the entry vector has as many
-- elements as the sum of those segment lengths, each pairing a column index
-- drawn from 0..cols-1 with a value from the supplied element generator.
sparseMatrix :: Elt e => Gen e -> Gen (SparseMatrix e, Int)
sparseMatrix e = do
  numRows <- genExtent
  numCols <- genExtent
  rowLens <- array (Z :. numRows) (Gen.int (Range.linear 0 numCols))
  let total  = P.sum (S.toList rowLens)
      column = Gen.int (Range.linear 0 (numCols - 1))
      entry  = (,) <$> column <*> e
  entries <- array (Z :. total) entry
  return ((rowLens, entries), numCols)
  where
    -- Shared generator for the number of rows and the number of columns.
    genExtent :: Gen Int
    genExtent = Gen.int (Range.linear 1 256)
-- | A vector of (index, value) pairs. 'smvm' uses the 'Int' component as the
-- gather index into the dense vector it multiplies against.
type SparseVector e = Vector (Int, e)
-- | A sparse matrix: a segment array paired with the 'SparseVector' holding
-- its entries. 'smvm' hands the segments to 'foldSeg' to reduce the entries.
type SparseMatrix e = (Segments Int, SparseVector e)
-- | Sparse-matrix vector multiply as an Accelerate computation.
--
-- The matrix entries are unzipped into indices and values; the dense vector
-- is gathered at those indices, multiplied element-wise with the values, and
-- the products are summed per segment with 'foldSeg' (starting from 0).
smvm :: A.Num a => Acc (SparseMatrix a) -> Acc (Vector a) -> Acc (Vector a)
smvm smat vec = foldSeg (+) 0 terms rowLens
  where
    (rowLens, entries) = unlift smat
    (cols, coeffs)     = A.unzip entries
    -- Pair each stored value with the dense-vector element at its index.
    terms              = A.zipWith (*) (A.gather cols vec) coeffs
-- | List-based reference for 'smvm'.
--
-- The matrix entries are split into chunks using the segment array as chunk
-- lengths ('splitPlaces'); each chunk is reduced to the sum of
-- @value * vec[index]@, and the sums are packed into an array with the same
-- shape as the segment array.
smvmRef :: (Elt a, P.Num a) => SparseMatrix a -> Vector a -> Vector a
smvmRef (segd, smat) vec = fromList (S.shape segd) (P.map dotRow rows)
  where
    rows          = splitPlaces (toList segd) (toList smat)
    dotRow row    = P.sum (P.map term row)
    term (i, val) = val * indexArray vec (Z :. i)