{-# LANGUAGE BangPatterns #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} -- | -- Module : Data.Array.Accelerate.Test.NoFib.Prelude.Scan -- Copyright : [2009..2017] Trevor L. McDonell -- License : BSD3 -- -- Maintainer : Trevor L. McDonell -- Stability : experimental -- Portability : non-portable (GHC extensions) -- module Data.Array.Accelerate.Test.NoFib.Prelude.Scan ( test_scanl, test_scanlSeg, test_scanl1, test_scanl1Seg, test_scanl', test_scanl'Seg, test_scanr, test_scanrSeg, test_scanr1, test_scanr1Seg, test_scanr', test_scanr'Seg, ) where import Data.Proxy import Data.Typeable import Prelude as P import Data.Array.Accelerate as A import Data.Array.Accelerate.Array.Sugar as S import Data.Array.Accelerate.Test.NoFib.Base import Data.Array.Accelerate.Test.NoFib.Config import Data.Array.Accelerate.Test.Similar import Hedgehog import qualified Hedgehog.Gen as Gen import qualified Hedgehog.Range as Range import Test.Tasty import Test.Tasty.Hedgehog test_scanl :: RunN -> TestTree test_scanl runN = testGroup "scanl" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestHalf) $ testElt f16 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanl_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanl_sum runN sh e e , testProperty "non-commutative" $ test_scanl_interval runN sh e ] test_scanl1 :: RunN -> TestTree test_scanl1 runN = testGroup "scanl1" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanl1_sum runN sh e , testProperty "non-commutative" $ test_scanl1_interval runN sh e ] test_scanl' :: RunN -> TestTree test_scanl' runN = testGroup "scanl'" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanl'_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanl'_sum runN sh e e , testProperty "non-commutative" $ test_scanl'_interval runN sh e ] test_scanr :: RunN -> TestTree test_scanr runN = testGroup "scanr" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanr_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanr_sum runN sh e e , testProperty "non-commutative" $ test_scanr_interval runN sh e ] test_scanr1 :: RunN -> TestTree test_scanr1 runN = testGroup "scanr1" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanr1_sum runN sh e , testProperty "non-commutative" $ test_scanr1_interval runN sh e ] test_scanr' :: RunN -> TestTree test_scanr' runN = testGroup "scanr'" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanr'_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanr'_sum runN sh e e , testProperty "non-commutative" $ test_scanr'_interval runN sh e ] test_scanlSeg :: RunN -> TestTree test_scanlSeg runN = testGroup "scanlSeg" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, Slice sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanlSeg_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanlSeg_sum runN sh e e ] test_scanl1Seg :: RunN -> TestTree test_scanl1Seg runN = testGroup "scanl1Seg" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, Slice sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanl1Seg_sum runN sh e ] test_scanl'Seg :: RunN -> TestTree test_scanl'Seg runN = testGroup "scanl'Seg" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, Slice sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanl'Seg_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanl'Seg_sum runN sh e e ] test_scanrSeg :: RunN -> TestTree test_scanrSeg runN = testGroup "scanrSeg" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, Slice sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanrSeg_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanrSeg_sum runN sh e e ] test_scanr1Seg :: RunN -> TestTree test_scanr1Seg runN = testGroup "scanr1Seg" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, Slice sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanr1Seg_sum runN sh e ] test_scanr'Seg :: RunN -> TestTree test_scanr'Seg runN = testGroup "scanr'Seg" [ at (Proxy::Proxy TestInt8) $ testElt i8 , at (Proxy::Proxy TestInt16) $ testElt i16 , at (Proxy::Proxy TestInt32) $ testElt i32 , at (Proxy::Proxy TestInt64) $ testElt i64 , at (Proxy::Proxy TestWord8) $ testElt w8 , at (Proxy::Proxy TestWord16) $ testElt w16 , at (Proxy::Proxy TestWord32) $ testElt w32 , at (Proxy::Proxy TestWord64) $ testElt w64 , at (Proxy::Proxy TestFloat) $ testElt f32 , at (Proxy::Proxy TestDouble) $ testElt f64 ] where testElt :: forall a. (P.Num a, P.Ord a , A.Num a, A.Ord a , Similar a) => Gen a -> TestTree testElt e = testGroup (show (typeOf (undefined :: a))) [ testDim dim1 , testDim dim2 , testDim dim3 ] where testDim :: forall sh. (Shape sh, Slice sh, P.Eq sh) => Gen (sh:.Int) -> TestTree testDim sh = testGroup ("DIM" P.++ show (rank (undefined::(sh:.Int)))) [ testProperty "sum" $ test_scanr'Seg_sum runN sh (return 0) e , testProperty "non-neutral sum" $ test_scanr'Seg_sum runN sh e e ] scalar :: Elt e => e -> Scalar e scalar x = fromFunction Z (const x) test_scanl_sum :: (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanl_sum runN dim z e = property $ do x <- forAll z sh <- forAll dim arr <- forAll (array sh e) let !go = runN (\v -> A.scanl (+) (the v)) in go (scalar x) arr ~~~ scanlRef (+) x arr test_scanl1_sum :: (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanl1_sum runN dim e = property $ do sh <- forAll (dim `except` \(_:.n) -> n P.== 0) arr <- forAll (array sh e) let !go = runN (A.scanl1 (+)) in go arr ~~~ scanl1Ref (+) arr test_scanl'_sum :: (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanl'_sum runN dim z e = property $ do x <- forAll z sh <- forAll dim arr <- forAll (array sh e) let !go = runN (\v -> A.scanl' (+) (the v)) in go (scalar x) arr ~~~ scanl'Ref (+) x arr test_scanr_sum :: (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanr_sum runN dim z e = property $ do x <- forAll z sh <- forAll dim arr <- forAll (array sh e) let !go = runN (\v -> A.scanr (+) (the v)) in go (scalar x) arr ~~~ scanrRef (+) x arr test_scanr1_sum :: (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanr1_sum runN dim e = property $ do sh <- forAll (dim `except` \(_:.n) -> n P.== 0) arr <- forAll (array sh e) let !go = runN (A.scanr1 (+)) in go arr ~~~ scanr1Ref (+) arr test_scanr'_sum :: (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanr'_sum runN dim z e = property $ do x <- forAll z sh <- forAll dim arr <- forAll (array sh e) let !go = runN (\v -> A.scanr' (+) (the v)) in go (scalar x) arr ~~~ scanr'Ref (+) x arr test_scanl_interval :: (Shape sh, Similar e, P.Eq sh, P.Eq e, P.Num e, A.Eq e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanl_interval runN dim e = property $ do sh :. n <- forAll (dim `except` \(_:.n) -> n P.== 0) let arr = intervalArray sh n e let !go = runN (A.scanl iappend (constant one)) in go arr ~~~ scanlRef iappendRef one arr test_scanl1_interval :: (Shape sh, Similar e, P.Eq sh, P.Eq e, P.Num e, A.Eq e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanl1_interval runN dim e = property $ do sh :. n <- forAll (dim `except` \v -> S.size v P.== 0) let arr = intervalArray sh n e let !go = runN (A.scanl1 iappend) in go arr ~~~ scanl1Ref iappendRef arr test_scanl'_interval :: (Shape sh, Similar e, P.Eq sh, P.Eq e, P.Num e, A.Eq e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanl'_interval runN dim e = property $ do sh :. n <- forAll (dim `except` \(_:.n) -> n P.== 0) let arr = intervalArray sh n e let !go = runN (A.scanl' iappend (constant one)) in go arr ~~~ scanl'Ref iappendRef one arr test_scanr_interval :: (Shape sh, Similar e, P.Eq sh, P.Eq e, P.Num e, A.Eq e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanr_interval runN dim e = property $ do sh :. n <- forAll (dim `except` \(_:.n) -> n P.== 0) let arr = intervalArray sh n e let !go = runN (A.scanr iappend (constant one)) in go arr ~~~ scanrRef iappendRef one arr test_scanr1_interval :: (Shape sh, Similar e, P.Eq sh, P.Eq e, P.Num e, A.Eq e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanr1_interval runN dim e = property $ do sh :. n <- forAll (dim `except` \(_:.n) -> n P.== 0) let arr = intervalArray sh n e let !go = runN (A.scanr1 iappend) in go arr ~~~ scanr1Ref iappendRef arr test_scanr'_interval :: (Shape sh, Similar e, P.Eq sh, P.Eq e, P.Num e, A.Eq e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanr'_interval runN dim e = property $ do sh :. n <- forAll (dim `except` \(_:.n) -> n P.== 0) let arr = intervalArray sh n e let !go = runN (A.scanr' iappend (constant one)) in go arr ~~~ scanr'Ref iappendRef one arr test_scanlSeg_sum :: forall sh e. (Shape sh, Slice sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanlSeg_sum runN dim z e = property $ do x <- forAll z sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- return (P.min n1 n2) -- don't generate too many segments seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank (undefined::sh)))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanlSeg (+) (the v)) in go (scalar x) arr seg ~~~ scanlSegRef (+) x arr seg test_scanl1Seg_sum :: forall sh e. (Shape sh, Slice sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanl1Seg_sum runN dim e = property $ do sh:.n1 <- forAll (dim `except` \v -> S.size v P.== 0) n2 <- forAll (Gen.int (Range.linear 1 64)) n <- return (P.min n1 n2) -- don't generate too many segments seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank (undefined::sh)))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.scanl1Seg (+)) in go arr seg ~~~ scanl1SegRef (+) arr seg test_scanl'Seg_sum :: forall sh e. (Shape sh, Slice sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanl'Seg_sum runN dim z e = property $ do x <- forAll z sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- return (P.min n1 n2) -- don't generate too many segments seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank (undefined::sh)))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanl'Seg (+) (the v)) in go (scalar x) arr seg ~~~ scanl'SegRef (+) x arr seg test_scanrSeg_sum :: forall sh e. (Shape sh, Slice sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanrSeg_sum runN dim z e = property $ do x <- forAll z sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- return (P.min n1 n2) -- don't generate too many segments seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank (undefined::sh)))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanrSeg (+) (the v)) in go (scalar x) arr seg ~~~ scanrSegRef (+) x arr seg test_scanr1Seg_sum :: forall sh e. (Shape sh, Slice sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Property test_scanr1Seg_sum runN dim e = property $ do sh:.n1 <- forAll (dim `except` \v -> S.size v P.== 0) n2 <- forAll (Gen.int (Range.linear 1 64)) n <- return (P.min n1 n2) -- don't generate too many segments seg <- forAll (array (Z:.n) (Gen.int (Range.linear 1 (128 `quot` 2 P.^ (rank (undefined::sh)))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (A.scanr1Seg (+)) in go arr seg ~~~ scanr1SegRef (+) arr seg test_scanr'Seg_sum :: forall sh e. (Shape sh, Slice sh, Similar e, P.Eq sh, P.Num e, A.Num e) => RunN -> Gen (sh:.Int) -> Gen e -> Gen e -> Property test_scanr'Seg_sum runN dim z e = property $ do x <- forAll z sh:.n1 <- forAll dim n2 <- forAll (Gen.int (Range.linear 0 64)) n <- return (P.min n1 n2) -- don't generate too many segments seg <- forAll (array (Z:.n) (Gen.int (Range.linear 0 (128 `quot` 2 P.^ (rank (undefined::sh)))))) arr <- forAll (array (sh:.P.sum (toList seg)) e) let !go = runN (\v -> A.scanr'Seg (+) (the v)) in go (scalar x) arr seg ~~~ scanr'SegRef (+) x arr seg -- Interval of summations monoid -- one, top :: P.Num e => (e,e) one = (-1,-1) top = (-2,-2) iappendRef :: (P.Num e, P.Eq e) => (e,e) -> (e,e) -> (e,e) iappendRef x y | x P.== one = y | y P.== one = x | x P.== top P.|| y P.== top = top iappendRef (x1,x2) (y1,y2) | x2 + 1 P.== y1 = (x1,y2) | otherwise = top iappend :: forall e. (A.Eq e, A.Num e, P.Num e) => Exp (e,e) -> Exp (e,e) -> Exp (e,e) iappend x y = x A.== constant one ? ( y , y A.== constant one ? ( x , x A.== constant top ? ( constant top -- A.|| y A.== constant top; see AccelerateHS/accelerate#364 , let (x1,x2) = unlift x :: (Exp e, Exp e) (y1,y2) = unlift y :: (Exp e, Exp e) in x2 + 1 A.== y1 ? ( lift (x1,y2) , constant top ) ))) intervalArray :: (Shape sh, Elt e, P.Num e) => sh -> Int -> proxy e -> Array (sh:.Int) (e,e) intervalArray sh n _ = fromFunction (sh:.n) (\(_:.i) -> let x = P.fromIntegral i in (x,x)) -- Reference implementation -- ------------------------ scanlRef :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> Array (sh:.Int) e scanlRef f z arr = let sz :. n = arrayShape arr arr' = [ P.scanl f z sub | sub <- splitEvery n (toList arr) ] in A.fromList (sz :. n+1) (concat arr') scanl'Ref :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> (Array (sh:.Int) e, Array sh e) scanl'Ref f z arr = let sz :. n = arrayShape arr (arr',sums) = P.unzip [ P.splitAt n (P.scanl f z sub) | sub <- splitEvery n (toList arr) ] in ( A.fromList (sz:.n) (concat arr'), A.fromList sz (concat sums) ) scanl1Ref :: (Shape sh, Elt e) => (e -> e -> e) -> Array (sh:.Int) e -> Array (sh:.Int) e scanl1Ref f arr = let sz :. n = arrayShape arr arr' = [ P.scanl1 f sub | sub <- splitEvery n (toList arr) ] in A.fromList (sz:.n) (concat arr') scanrRef :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> Array (sh:.Int) e scanrRef f z arr = let sz :. n = arrayShape arr arr' = [ P.scanr f z sub | sub <- splitEvery n (toList arr) ] in A.fromList (sz :. n+1) (concat arr') scanr'Ref :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> (Array (sh:.Int) e, Array sh e) scanr'Ref f z arr = let sz :. n = arrayShape arr (sums,arr') = P.unzip [ P.splitAt 1 (P.scanr f z sub) | sub <- splitEvery n (toList arr) ] in ( A.fromList (sz:.n) (concat arr'), A.fromList sz (concat sums) ) scanr1Ref :: (Shape sh, Elt e) => (e -> e -> e) -> Array (sh:.Int) e -> Array (sh:.Int) e scanr1Ref f arr = let sz :. n = arrayShape arr arr' = [ P.scanr1 f sub | sub <- splitEvery n (toList arr) ] in A.fromList (sz:.n) (concat arr') -- segmented operations -- scanlSegRef :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> Segments Int -> Array (sh:.Int) e scanlSegRef f z arr seg = let sz :. n = arrayShape arr seg' = toList seg n' = P.sum $ P.map (\x -> P.fromIntegral x + 1) seg' arr' = [ P.scanl f z sec | sub <- splitEvery n (toList arr) , sec <- splitPlaces seg' sub ] in A.fromList (sz:.n') (concat arr') scanl1SegRef :: (Shape sh, Elt e) => (e -> e -> e) -> Array (sh:.Int) e -> Segments Int -> Array (sh:.Int) e scanl1SegRef f arr seg = let sz :. n = arrayShape arr seg' = toList seg n' = P.fromIntegral (P.sum seg') arr' = [ P.scanl1 f sec | sub <- splitEvery n (toList arr) , sec <- splitPlaces seg' sub ] in A.fromList (sz:.n') (concat arr') scanl'SegRef :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> Segments Int -> (Array (sh:.Int) e, Array (sh:.Int) e) scanl'SegRef f z arr seg = let sz :. n = arrayShape arr Z :. s = arrayShape seg scanl'_ v = P.splitAt (P.length v) (P.scanl f z v) (arr',sums) = P.unzip [ scanl'_ sec | sub <- splitEvery n (toList arr) , sec <- splitPlaces (toList seg) sub ] in ( A.fromList (sz:.n) (concat arr'), A.fromList (sz:.s) (concat sums) ) scanrSegRef :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> Segments Int -> Array (sh:.Int) e scanrSegRef f z arr seg = let sz :. n = arrayShape arr seg' = toList seg n' = P.sum $ P.map (\x -> P.fromIntegral x + 1) seg' arr' = [ P.scanr f z sec | sub <- splitEvery n (toList arr) , sec <- splitPlaces seg' sub ] in A.fromList (sz:.n') (concat arr') scanr1SegRef :: (Shape sh, Elt e) => (e -> e -> e) -> Array (sh:.Int) e -> Segments Int -> Array (sh:.Int) e scanr1SegRef f arr seg = let sz :. n = arrayShape arr seg' = toList seg n' = P.fromIntegral (P.sum seg') arr' = [ P.scanr1 f sec | sub <- splitEvery n (toList arr) , sec <- splitPlaces seg' sub ] in A.fromList (sz:.n') (concat arr') scanr'SegRef :: (Shape sh, Elt e) => (e -> e -> e) -> e -> Array (sh:.Int) e -> Segments Int -> (Array (sh:.Int) e, Array (sh:.Int) e) scanr'SegRef f z arr seg = let sz :. n = arrayShape arr Z :. s = arrayShape seg (sums, arr') = P.unzip [ P.splitAt 1 (P.scanr f z sec) | sub <- splitEvery n (toList arr) , sec <- splitPlaces (toList seg) sub ] in ( A.fromList (sz:.n) (concat arr'), A.fromList (sz:.s) (concat sums) )