{-# LANGUAGE BangPatterns #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Prelude.Permute -- Copyright : [2009..2017] Trevor L. McDonell -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- module Data.Array.Accelerate.Test.NoFib.Prelude.Permute ( test_permute ) where import Control.Monad import Data.Proxy import Data.Typeable import System.IO.Unsafe import Prelude as P import qualified Data.Set as Set import Data.Array.Accelerate as A import Data.Array.Accelerate.Array.Sugar as S import Data.Array.Accelerate.Array.Data import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range import Test.Tasty import Test.Tasty.Hedgehog test_permute :: RunN -> TestTree test_permute runN = testGroup "permute" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestHalf) $ testElt f16 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (Similar a, P.Num a, A.Num a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, Slice sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "scatter->DIM1" $ test_scatter runN sh dim1 e , testProperty "scatter->DIM2" $ test_scatter runN sh dim2 e , testProperty "scatter->DIM3" $ test_scatter runN sh dim3 e , testProperty "accumulate->DIM1" $ test_accumulate runN sh dim1 e , testProperty "accumulate->DIM2" $ test_accumulate runN sh dim2 e , testProperty "accumulate->DIM3" $ test_accumulate runN sh dim3 e ] test_scatter :: forall sh sh' e. (Shape sh, Shape sh', P.Eq sh', Similar e, Elt e) => RunN -> Gen sh -> Gen sh' -> Gen e -> Property test_scatter runN dim dim' e = property $ do sh <- forAll dim sh' <- forAll (dim' `except` \v -> S.size v P.== 0) let n = S.size sh n' = S.size sh' -- shfl seen i | i P.>= n = return [] | otherwise = do t <- Gen.choice [ return (-1) , Gen.int (Range.linear 0 (n'-1)) ] ts <- shfl (Set.insert t seen) (i+1) -- case Set.member t seen of True -> return (S.ignore : ts) False -> return (S.fromIndex sh' t : ts) -- def <- forAll (array sh' e) new <- forAll (array sh e) ix <- forAll (fromList sh <$> shfl (Set.singleton (-1)) 0) -- let !go = runN $ \i d v -> A.permute const d (i A.!) v go ix def new ~~~ permuteRef const def (ix S.!) new test_accumulate :: (Shape sh, Shape sh', P.Eq sh', Similar e, P.Num e, A.Num e) => RunN -> Gen sh -> Gen sh' -> Gen e -> Property test_accumulate runN dim dim' e = property $ do sh <- forAll dim sh' <- forAll (dim' `except` \v -> S.size v P.== 0) let n' = S.size sh' def = S.fromFunction sh' (const 0) -- xs <- forAll (array sh e) ix <- forAll (array sh (Gen.choice [ return S.ignore , S.fromIndex sh' <$> Gen.int (Range.linear 0 (n'-1)) ])) let !go = runN $ \i d v -> A.permute (+) d (i A.!) v go ix def xs ~~~ permuteRef (+) def (ix S.!) xs permuteRef :: (Shape sh, Shape sh', P.Eq sh', Elt e) => (e -> e -> e) -> Array sh' e -> (sh -> sh') -> Array sh e -> Array sh' e permuteRef f def@(Array _ aold) p arr@(Array _ anew) = unsafePerformIO $ do let sh = S.shape arr sh' = S.shape def n = S.size sh -- go !i | i P.>= n = return () | otherwise = do let ix = S.fromIndex sh i ix' = p ix -- unless (ix' P.== S.ignore) $ do let i' = S.toIndex sh' ix' x <- toElt <$> unsafeReadArrayData anew i x' <- toElt <$> unsafeReadArrayData aold i' unsafeWriteArrayData aold i' (fromElt (f x x')) -- go (i+1) -- go 0 return def