{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Test.NoFib.Prelude.Permute (
test_permute
) where
import Control.Monad
import Data.Proxy
import Data.Typeable
import System.IO.Unsafe
import Prelude as P
import qualified Data.Set as Set
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Sugar as S
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Test.NoFib.Base
import Data.Array.Accelerate.Test.NoFib.Config
import Data.Array.Accelerate.Test.Similar
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog
-- | Property tests for 'A.permute'.
--
-- For every element type listed below (each wrapped in 'at' with its
-- matching @Test*@ proxy), and for every source rank DIM1..DIM3, this checks
-- both scattering ('test_scatter') and accumulation ('test_accumulate') into
-- destinations of rank DIM1..DIM3.
test_permute :: RunN -> TestTree
test_permute runN =
  testGroup "permute"
    [ at (Proxy::Proxy TestInt8)   $ testElt i8
    , at (Proxy::Proxy TestInt16)  $ testElt i16
    , at (Proxy::Proxy TestInt32)  $ testElt i32
    , at (Proxy::Proxy TestInt64)  $ testElt i64
    , at (Proxy::Proxy TestWord8)  $ testElt w8
    , at (Proxy::Proxy TestWord16) $ testElt w16
    , at (Proxy::Proxy TestWord32) $ testElt w32
    , at (Proxy::Proxy TestWord64) $ testElt w64
    , at (Proxy::Proxy TestHalf)   $ testElt f16
    , at (Proxy::Proxy TestFloat)  $ testElt f32
    , at (Proxy::Proxy TestDouble) $ testElt f64
    ]
  where
    -- All permute tests for one element type, grouped by source rank.
    testElt
        :: forall a. (Similar a, P.Num a, A.Num a)
        => Gen a
        -> TestTree
    testElt e =
      testGroup (show (typeOf (undefined :: a)))
        [ testDim dim1, testDim dim2, testDim dim3 ]
      where
        -- All permute tests whose source array has shape @sh:.Int@, one per
        -- (operation, destination rank) pair.
        testDim
            :: forall sh. (Shape sh, Slice sh, P.Eq sh)
            => Gen (sh:.Int)
            -> TestTree
        testDim sh =
          testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int))))
            [ scatter    "scatter->DIM1"    dim1
            , scatter    "scatter->DIM2"    dim2
            , scatter    "scatter->DIM3"    dim3
            , accumulate "accumulate->DIM1" dim1
            , accumulate "accumulate->DIM2" dim2
            , accumulate "accumulate->DIM3" dim3
            ]
          where
            -- Build one named property for a destination-shape generator;
            -- the source shape and element generators are fixed above.
            scatter, accumulate
                :: (Shape sh', P.Eq sh')
                => String
                -> Gen sh'
                -> TestTree
            scatter    name dst = testProperty name (test_scatter    runN sh dst e)
            accumulate name dst = testProperty name (test_accumulate runN sh dst e)
-- | Scatter property: 'A.permute' with 'const' as the combining function,
-- checked against 'permuteRef'.
--
-- The index mapping is built so that each linear destination index is
-- targeted by at most one source element. Walking the source in increasing
-- linear order, each element either draws @-1@ or a linear destination
-- index. If that index has already been claimed by an earlier element, the
-- element is mapped to 'S.ignore' instead. @-1@ is pre-seeded as claimed, so
-- drawing it always yields 'S.ignore'. Without collisions, the 'const'
-- result does not depend on the order in which updates are applied.
--
-- NOTE(review): the destination shape is filtered with
-- @`except` (size == 0)@, presumably so that @Range.linear 0 (dstSize-1)@ is
-- never an empty range. Confirm against 'except' in NoFib.Base.
test_scatter
    :: forall sh sh' e. (Shape sh, Shape sh', P.Eq sh', Similar e, Elt e)
    => RunN
    -> Gen sh
    -> Gen sh'
    -> Gen e
    -> Property
test_scatter runN dim dim' e =
  property $ do
    srcSh <- forAll dim
    dstSh <- forAll (dim' `except` \v -> S.size v P.== 0)
    let
        srcSize = S.size srcSh
        dstSize = S.size dstSh
        -- Destinations for source elements k, k+1, .. given the set of
        -- linear destination indices claimed by elements 0 .. k-1.
        -- The draw for element k happens before the recursive call, so the
        -- random draws occur in source order.
        targets k claimed
          | k P.>= srcSize = return []
          | otherwise      = do
              t    <- Gen.choice [ return (-1)
                                 , Gen.int (Range.linear 0 (dstSize-1))
                                 ]
              rest <- targets (k+1) (Set.insert t claimed)
              let here
                    | t `Set.member` claimed = S.ignore
                    | otherwise              = S.fromIndex dstSh t
              return (here : rest)
    def <- forAll (array dstSh e)
    new <- forAll (array srcSh e)
    ix  <- forAll (fromList srcSh <$> targets 0 (Set.singleton (-1)))
    let !backend = runN $ \i d v -> A.permute const d (i A.!) v
    backend ix def new ~~~ permuteRef const def (ix S.!) new
-- | Accumulation property: 'A.permute' with '(+)' as the combining
-- function, checked against 'permuteRef'.
--
-- Unlike 'test_scatter', several source elements may target the same
-- destination cell, so colliding updates must be summed. Each source element
-- is either dropped ('S.ignore') or sent to a linear destination index drawn
-- from @[0, size - 1]@. The default array is all zeros, so the expected
-- result is just the per-cell sum.
test_accumulate
    :: (Shape sh, Shape sh', P.Eq sh', Similar e, P.Num e, A.Num e)
    => RunN
    -> Gen sh
    -> Gen sh'
    -> Gen e
    -> Property
test_accumulate runN dim dim' e =
  property $ do
    srcSh <- forAll dim
    dstSh <- forAll (dim' `except` \v -> S.size v P.== 0)
    let
        lastIx = S.size dstSh - 1
        zeros  = S.fromFunction dstSh (const 0)
        -- Per-element destination: dropped, or a valid destination index.
        target = Gen.choice [ return S.ignore
                            , S.fromIndex dstSh <$> Gen.int (Range.linear 0 lastIx)
                            ]
    xs <- forAll (array srcSh e)
    ix <- forAll (array srcSh target)
    let !backend = runN $ \i d v -> A.permute (+) d (i A.!) v
    backend ix zeros xs ~~~ permuteRef (+) zeros (ix S.!) xs
-- | Sequential reference implementation of 'A.permute', used as the oracle
-- for the properties above.
--
-- Source elements are visited in increasing linear-index order. For each
-- index @ix@ of @arr@ with @p ix@ different from 'S.ignore', the source
-- element is combined with the current value at @p ix@ as @f new old@, and
-- the result is written back to that cell.
--
-- The result is accumulated in a fresh copy of @def@, and @def@ itself is
-- never modified. The test properties pass the same @def@ to the backend
-- under test. Updating it in place would let the backend see already-updated
-- defaults (depending on evaluation order), and would also change what
-- Hedgehog prints for @def@ in a failure report.
permuteRef
    :: (Shape sh, Shape sh', P.Eq sh', Elt e)
    => (e -> e -> e)     -- ^ combining function, applied as @f new old@
    -> Array sh' e       -- ^ default values of the result (not modified)
    -> (sh -> sh')       -- ^ index mapping; 'S.ignore' drops the element
    -> Array sh e        -- ^ source values
    -> Array sh' e
permuteRef f def p arr@(Array _ anew) =
  unsafePerformIO $
    -- Private copy of the defaults that the loop below may update in place.
    case S.fromFunction sh' (def S.!) of
      res@(Array _ aout) -> do
        let
            go !i
              | i P.>= n  = return ()
              | otherwise = do
                  let ix  = S.fromIndex sh i
                      ix' = p ix
                  unless (ix' P.== S.ignore) $ do
                    let i' = S.toIndex sh' ix'
                    x  <- toElt <$> unsafeReadArrayData anew i
                    x' <- toElt <$> unsafeReadArrayData aout i'
                    unsafeWriteArrayData aout i' (fromElt (f x x'))
                  go (i+1)
        go 0
        return res
  where
    sh  = S.shape arr
    sh' = S.shape def
    n   = S.size sh