{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Test.NoFib.Prelude.Fold (
test_fold,
test_foldSeg,
) where
import Data.List (foldl', foldl1')
import Data.Proxy
import Data.Typeable
import Prelude as P
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Test.NoFib.Base
import Data.Array.Accelerate.Test.NoFib.Config
import Data.Array.Accelerate.Test.Similar
import Hedgehog
import qualified Hedgehog.Gen as Gen
import qualified Hedgehog.Range as Range
import Test.Tasty
import Test.Tasty.Hedgehog
-- | Test tree for the unsegmented reductions.
--
-- One subgroup is built per element-type proxy handed to 'at'; inside each,
-- the same set of properties is repeated for the @dim1@, @dim2@ and @dim3@
-- shape generators. Every element type is given two generators: the first
-- feeds all properties except \"non-commutative\", which draws from the
-- second one instead (for the floating-point types both are the same).
test_fold :: RunN -> TestTree
test_fold runN =
  testGroup "fold"
    [ at (Proxy::Proxy TestInt8)   $ eltGroup i8  (Gen.int8   (Range.linearFrom 0 (-1) 1))
    , at (Proxy::Proxy TestInt16)  $ eltGroup i16 (Gen.int16  (Range.linearFrom 0 (-10) 10))
    , at (Proxy::Proxy TestInt32)  $ eltGroup i32 (Gen.int32  (Range.linearFrom 0 (-1000) 1000))
    , at (Proxy::Proxy TestInt64)  $ eltGroup i64 (Gen.int64  (Range.linearFrom 0 (-10000) 10000))
    , at (Proxy::Proxy TestWord8)  $ eltGroup w8  (Gen.word8  (Range.linear 0 1))
    , at (Proxy::Proxy TestWord16) $ eltGroup w16 (Gen.word16 (Range.linear 0 10))
    , at (Proxy::Proxy TestWord32) $ eltGroup w32 (Gen.word32 (Range.linear 0 1000))
    , at (Proxy::Proxy TestWord64) $ eltGroup w64 (Gen.word64 (Range.linear 0 10000))
    , at (Proxy::Proxy TestHalf)   $ eltGroup f16 f16
    , at (Proxy::Proxy TestFloat)  $ eltGroup f32 f32
    , at (Proxy::Proxy TestDouble) $ eltGroup f64 f64
    ]
  where
    -- All properties for a single element type, named after that type.
    eltGroup
        :: forall a. (P.Num a, P.Ord a, A.Num a, A.Ord a, Similar a)
        => Gen a      -- general element generator
        -> Gen a      -- generator used by the "non-commutative" property
        -> TestTree
    eltGroup gen genSmall =
      testGroup eltName [ dimGroup dim1, dimGroup dim2, dimGroup dim3 ]
      where
        eltName = show (typeOf (undefined :: a))

        -- All properties for a single array rank, labelled "DIM<rank>".
        dimGroup
            :: forall sh. (Shape sh, P.Eq sh)
            => Gen (sh:.Int)
            -> TestTree
        dimGroup genSh =
          testGroup dimName
            [ testProperty "sum"             $ test_sum runN genSh (return 0) gen
            , testProperty "non-neutral sum" $ test_sum runN genSh gen gen
            , testProperty "non-commutative" $ test_mss runN genSh genSmall
            , testProperty "minimum"         $ test_minimum runN genSh gen
            , testProperty "maximum"         $ test_maximum runN genSh gen
            ]
          where
            dimName = "DIM" P.++ show (rank (undefined::(sh:.Int)))
-- | Test tree for the segmented reductions.
--
-- One subgroup is built per element-type proxy handed to 'at'; inside each,
-- the same set of properties is repeated for the @dim1@, @dim2@ and @dim3@
-- shape generators.
--
-- NOTE(review): 'test_fold' also covers @TestHalf@ but this tree does not;
-- confirm whether that omission is intentional.
test_foldSeg :: RunN -> TestTree
test_foldSeg runN =
  testGroup "foldSeg"
    [ at (Proxy::Proxy TestInt8)   $ eltGroup i8
    , at (Proxy::Proxy TestInt16)  $ eltGroup i16
    , at (Proxy::Proxy TestInt32)  $ eltGroup i32
    , at (Proxy::Proxy TestInt64)  $ eltGroup i64
    , at (Proxy::Proxy TestWord8)  $ eltGroup w8
    , at (Proxy::Proxy TestWord16) $ eltGroup w16
    , at (Proxy::Proxy TestWord32) $ eltGroup w32
    , at (Proxy::Proxy TestWord64) $ eltGroup w64
    , at (Proxy::Proxy TestFloat)  $ eltGroup f32
    , at (Proxy::Proxy TestDouble) $ eltGroup f64
    ]
  where
    -- All properties for a single element type, named after that type.
    eltGroup
        :: forall a. (P.Num a, P.Ord a, A.Num a, A.Ord a, Similar a)
        => Gen a
        -> TestTree
    eltGroup gen =
      testGroup eltName [ dimGroup dim1, dimGroup dim2, dimGroup dim3 ]
      where
        eltName = show (typeOf (undefined :: a))

        -- All properties for a single array rank, labelled "DIM<rank>".
        dimGroup
            :: forall sh. (Shape sh, P.Eq sh)
            => Gen (sh:.Int)
            -> TestTree
        dimGroup genSh =
          testGroup dimName
            [ testProperty "sum"             $ test_segmented_sum runN genSh (return 0) gen
            , testProperty "non-neutral sum" $ test_segmented_sum runN genSh gen gen
            , testProperty "minimum"         $ test_segmented_minimum runN genSh gen
            , testProperty "maximum"         $ test_segmented_maximum runN genSh gen
            ]
          where
            dimName = "DIM" P.++ show (rank (undefined::(sh:.Int)))
-- | Wrap a single value in a rank-0 array: every index of the 'Z'-shaped
-- result maps to the given element.
scalar :: Elt e => e -> Scalar e
scalar x = fromFunction Z (\_ -> x)
-- | Property: 'A.fold' with @(+)@ and an initial value drawn from @z@ agrees
-- (up to '~~~') with the list-based reference 'foldRef'.
--
-- The initial value is passed to the Accelerate program as a scalar array and
-- read back with 'the'.
test_sum
    :: (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e)
    => RunN
    -> Gen (sh:.Int)    -- ^ shape generator
    -> Gen e            -- ^ generator for the initial value
    -> Gen e            -- ^ generator for the array elements
    -> Property
test_sum runN dim z e = property $ do
  seed   <- forAll z
  extent <- forAll dim
  input  <- forAll (array extent e)
  -- The bang forces the result of 'runN' before it is applied to the inputs.
  let !reduce = runN (\v -> A.fold (+) (the v))
  reduce (scalar seed) input ~~~ foldRef (+) seed input
-- | Property: 'maximumSegmentSum' agrees (up to '~~~') with the list-based
-- reference 'maximumSegmentSumRef'.
--
-- Shapes are drawn through 'except' with a predicate matching an innermost
-- extent of zero (presumably rejecting such shapes, since the Accelerate
-- program is built on 'A.fold1').
test_mss
    :: (Shape sh, Similar e, P.Eq sh, P.Num e, P.Ord e, A.Num e, A.Ord e)
    => RunN
    -> Gen (sh:.Int)    -- ^ shape generator
    -> Gen e            -- ^ generator for the array elements
    -> Property
test_mss runN dim e = property $ do
  extent <- forAll (dim `except` \(_ :. inner) -> inner P.== 0)
  input  <- forAll (array extent e)
  -- The bang forces the result of 'runN' before it is applied to the input.
  let !mss = runN maximumSegmentSum
  mss input ~~~ maximumSegmentSumRef input
-- | Property: 'A.minimum' agrees (up to '~~~') with folding 'P.min' over
-- each innermost row via 'fold1Ref'.
--
-- Shapes are drawn through 'except' with a predicate matching an innermost
-- extent of zero (presumably rejecting such shapes); the reference uses a
-- partial fold that has no result for an empty row.
test_minimum
    :: (Shape sh, Similar e, P.Eq sh, P.Num e, P.Ord e, A.Num e, A.Ord e)
    => RunN
    -> Gen (sh:.Int)    -- ^ shape generator
    -> Gen e            -- ^ generator for the array elements
    -> Property
test_minimum runN dim e = property $ do
  extent <- forAll (dim `except` \(_ :. inner) -> inner P.== 0)
  input  <- forAll (array extent e)
  -- The bang forces the result of 'runN' before it is applied to the input.
  let !smallest = runN A.minimum
  smallest input ~~~ fold1Ref P.min input
-- | Property: 'A.maximum' agrees (up to '~~~') with folding 'P.max' over
-- each innermost row via 'fold1Ref'.
--
-- Shapes are drawn through 'except' with a predicate matching an innermost
-- extent of zero (presumably rejecting such shapes); the reference uses a
-- partial fold that has no result for an empty row.
test_maximum
    :: (Shape sh, Similar e, P.Eq sh, P.Num e, P.Ord e, A.Num e, A.Ord e)
    => RunN
    -> Gen (sh:.Int)    -- ^ shape generator
    -> Gen e            -- ^ generator for the array elements
    -> Property
test_maximum runN dim e = property $ do
  extent <- forAll (dim `except` \(_ :. inner) -> inner P.== 0)
  input  <- forAll (array extent e)
  -- The bang forces the result of 'runN' before it is applied to the input.
  let !largest = runN A.maximum
  largest input ~~~ fold1Ref P.max input
-- | Property: 'A.foldSeg' with @(+)@ and an initial value drawn from @z@
-- agrees (up to '~~~') with the list-based reference 'foldSegRef'.
--
-- Only the outer part of the generated shape is used directly. The number of
-- segments is the smaller of the generated innermost extent and a second
-- draw from @[0, 64]@; the innermost extent of the data array is then the sum
-- of the generated segment lengths. Segment lengths start at zero, so empty
-- segments can occur.
test_segmented_sum
    :: forall sh e. (Shape sh, Similar e, P.Eq sh, P.Num e, A.Num e)
    => RunN
    -> Gen (sh:.Int)    -- ^ shape generator
    -> Gen e            -- ^ generator for the initial value
    -> Gen e            -- ^ generator for the array elements
    -> Property
test_segmented_sum runN dim z e = property $ do
  seed        <- forAll z
  outer :. n1 <- forAll dim
  n2          <- forAll (Gen.int (Range.linear 0 64))
  let segCount  = P.min n1 n2
      -- Upper bound on a segment length; halves for each extra outer dimension.
      maxSegLen = 128 `quot` 2 P.^ (rank (undefined::sh))
  seg   <- forAll (array (Z :. segCount) (Gen.int (Range.linear 0 maxSegLen)))
  input <- forAll (array (outer :. P.sum (toList seg)) e)
  -- The bang forces the result of 'runN' before it is applied to the inputs.
  let !reduce = runN (\v -> A.foldSeg (+) (the v))
  reduce (scalar seed) input seg ~~~ foldSegRef (+) seed input seg
-- | Property: 'A.fold1Seg' with 'A.min' agrees (up to '~~~') with the
-- list-based reference 'fold1SegRef' using 'P.min'.
--
-- Only the outer part of the generated shape is used directly. The number of
-- segments is the smaller of the generated innermost extent and a second
-- draw from @[0, 64]@; the innermost extent of the data array is then the sum
-- of the generated segment lengths. The segment-length range starts at one
-- rather than zero, since the reference fold is partial on an empty segment.
test_segmented_minimum
    :: forall sh e. (Shape sh, Similar e, P.Eq sh, P.Num e, P.Ord e, A.Num e, A.Ord e)
    => RunN
    -> Gen (sh:.Int)    -- ^ shape generator
    -> Gen e            -- ^ generator for the array elements
    -> Property
test_segmented_minimum runN dim e = property $ do
  outer :. n1 <- forAll dim
  n2          <- forAll (Gen.int (Range.linear 0 64))
  let segCount  = P.min n1 n2
      -- Upper bound on a segment length; halves for each extra outer dimension.
      maxSegLen = 128 `quot` 2 P.^ (rank (undefined::sh))
  seg   <- forAll (array (Z :. segCount) (Gen.int (Range.linear 1 maxSegLen)))
  input <- forAll (array (outer :. P.sum (toList seg)) e)
  -- The bang forces the result of 'runN' before it is applied to the inputs.
  let !reduce = runN (A.fold1Seg A.min)
  reduce input seg ~~~ fold1SegRef P.min input seg
-- | Property: 'A.fold1Seg' with 'A.max' agrees (up to '~~~') with the
-- list-based reference 'fold1SegRef' using 'P.max'.
--
-- Only the outer part of the generated shape is used directly. The number of
-- segments is the smaller of the generated innermost extent and a second
-- draw from @[0, 64]@; the innermost extent of the data array is then the sum
-- of the generated segment lengths. The segment-length range starts at one
-- rather than zero, since the reference fold is partial on an empty segment.
test_segmented_maximum
    :: forall sh e. (Shape sh, Similar e, P.Eq sh, P.Num e, P.Ord e, A.Num e, A.Ord e)
    => RunN
    -> Gen (sh:.Int)    -- ^ shape generator
    -> Gen e            -- ^ generator for the array elements
    -> Property
test_segmented_maximum runN dim e = property $ do
  outer :. n1 <- forAll dim
  n2          <- forAll (Gen.int (Range.linear 0 64))
  let segCount  = P.min n1 n2
      -- Upper bound on a segment length; halves for each extra outer dimension.
      maxSegLen = 128 `quot` 2 P.^ (rank (undefined::sh))
  seg   <- forAll (array (Z :. segCount) (Gen.int (Range.linear 1 maxSegLen)))
  input <- forAll (array (outer :. P.sum (toList seg)) e)
  -- The bang forces the result of 'runN' before it is applied to the inputs.
  let !reduce = runN (A.fold1Seg A.max)
  reduce input seg ~~~ fold1SegRef P.max input seg
-- | List-based reference for a fold along the innermost dimension.
--
-- The elements of the array (as returned by 'toList') are cut into
-- consecutive chunks whose length is the innermost extent, and each chunk is
-- reduced from the left with @f@ starting at @z@. The results are packed into
-- an array of the remaining outer shape.
--
-- The strict left fold 'foldl'' is used so that no chain of suspended
-- applications is built up per chunk.
foldRef
    :: (Shape sh, Elt e)
    => (e -> e -> e)          -- ^ combining function
    -> e                      -- ^ initial value
    -> Array (sh :. Int) e    -- ^ input array
    -> Array sh e
foldRef f z arr =
  fromList sh [ foldl' f z row | row <- splitEvery n (toList arr) ]
  where
    sh :. n = arrayShape arr
-- | List-based reference for a fold without an initial value along the
-- innermost dimension.
--
-- The elements of the array (as returned by 'toList') are cut into
-- consecutive chunks whose length is the innermost extent, and each chunk is
-- reduced from the left with @f@. The results are packed into an array of the
-- remaining outer shape.
--
-- The strict 'foldl1'' is used so that no chain of suspended applications is
-- built up per chunk. Like 'foldl1', it is partial: reducing an empty chunk
-- is an error, so callers are expected to supply a non-zero innermost extent.
fold1Ref
    :: (Shape sh, Elt e)
    => (e -> e -> e)          -- ^ combining function
    -> Array (sh :. Int) e    -- ^ input array
    -> Array sh e
fold1Ref f arr =
  fromList sh [ foldl1' f row | row <- splitEvery n (toList arr) ]
  where
    sh :. n = arrayShape arr
-- | List-based reference for a segmented fold along the innermost dimension.
--
-- The elements of the array (as returned by 'toList') are cut into
-- consecutive chunks whose length is the innermost extent. Each chunk is then
-- divided with 'splitPlaces' according to the segment lengths, and every
-- piece is reduced from the left with @f@ starting at @z@. The innermost
-- extent of the result is the number of segments.
--
-- The strict left fold 'foldl'' is used so that no chain of suspended
-- applications is built up per segment.
foldSegRef
    :: (Shape sh, Elt e)
    => (e -> e -> e)          -- ^ combining function
    -> e                      -- ^ initial value
    -> Array (sh :. Int) e    -- ^ input array
    -> Segments Int           -- ^ segment lengths
    -> Array (sh :. Int) e
foldSegRef f z arr seg = fromList (sh :. sz) folded
  where
    sh :. n  = arrayShape arr
    Z  :. sz = arrayShape seg
    lens     = toList seg
    folded   =
      [ foldl' f z part
      | row  <- splitEvery n (toList arr)
      , part <- splitPlaces lens row
      ]
-- | List-based reference for a segmented fold without an initial value along
-- the innermost dimension.
--
-- The elements of the array (as returned by 'toList') are cut into
-- consecutive chunks whose length is the innermost extent. Each chunk is then
-- divided with 'splitPlaces' according to the segment lengths, and every
-- piece is reduced from the left with @f@. The innermost extent of the result
-- is the number of segments.
--
-- The strict 'foldl1'' is used so that no chain of suspended applications is
-- built up per segment. Like 'foldl1', it is partial: reducing an empty piece
-- is an error, so callers are expected to supply non-empty segments.
fold1SegRef
    :: (Shape sh, Elt e)
    => (e -> e -> e)          -- ^ combining function
    -> Array (sh :. Int) e    -- ^ input array
    -> Segments Int           -- ^ segment lengths
    -> Array (sh :. Int) e
fold1SegRef f arr seg = fromList (sh :. sz) folded
  where
    sh :. n  = arrayShape arr
    Z  :. sz = arrayShape seg
    lens     = toList seg
    folded   =
      [ foldl1' f part
      | row  <- splitEvery n (toList arr)
      , part <- splitPlaces lens row
      ]
-- | Accelerate program used as the \"non-commutative\" fold test.
--
-- Every element @x@ is first expanded to the 4-tuple @(y, y, y, x)@ with
-- @y = max x 0@. The tuples of each innermost row are then combined with
-- 'A.fold1', and finally the first component is projected out.
--
-- When combining a left and a right tuple:
--
--   * the first component is the maximum of both first components and of
--     (left third + right second);
--   * the second is the maximum of the left second and (left fourth + right
--     second);
--   * the third is the maximum of the right third and (left third + right
--     fourth);
--   * the fourth is the sum of both fourth components.
maximumSegmentSum
    :: forall sh e. (Shape sh, A.Num e, A.Ord e)
    => Acc (Array (sh :. Int) e)
    -> Acc (Array sh e)
maximumSegmentSum
  = A.map firstOf
  . A.fold1 combine
  . A.map inject
  where
    -- Project the first component out of the combined tuple.
    firstOf :: Exp (e,e,e,e) -> Exp e
    firstOf t =
      let (best,_,_,_) = unlift t :: (Exp e, Exp e, Exp e, Exp e)
      in  best

    -- Merge the summaries of a left and a right neighbour.
    combine :: (A.Num a, A.Ord a) => Exp (a,a,a,a) -> Exp (a,a,a,a) -> Exp (a,a,a,a)
    combine l r =
      let (bestL, preL, sufL, totL) = unlift l
          (bestR, preR, sufR, totR) = unlift r
      in
      lift ( bestL `A.max` (bestR `A.max` (sufL+preR))
           , preL `A.max` (totL+preR)
           , sufR `A.max` (sufL+totR)
           , totL+totR
           )

    -- Summary of a single element.
    inject :: (A.Num a, A.Ord a) => Exp a -> Exp (a,a,a,a)
    inject x =
      let clamped = A.max x 0
      in  lift (clamped,clamped,clamped,x)
-- | List-based reference for 'maximumSegmentSum'.
--
-- The elements of the array (as returned by 'toList') are cut into
-- consecutive chunks whose length is the innermost extent. Each chunk is
-- scanned once from the left while tracking two values, both starting at
-- zero: the best sum of a run ending at the current element (clamped below at
-- zero) and the best such value seen so far. The latter is the result for
-- that chunk.
maximumSegmentSumRef
    :: (P.Num e, P.Ord e, Shape sh, Elt e)
    => Array (sh :. Int) e
    -> Array sh e
maximumSegmentSumRef arr =
  fromList sh [ go 0 0 row | row <- splitEvery n (toList arr) ]
  where
    sh :. n = arrayShape arr

    -- Both accumulators are forced on every step so that no chain of
    -- suspended additions / maxima builds up along the row.
    go _     best []     = best
    go !here !best (x:xs) =
      let here' = 0 `P.max` (here + x)
          best' = best `P.max` here'
      in
      go here' best' xs