{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Test.NoFib.Issues.Issue185 (
test_issue185
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Test.NoFib.Base
import Test.Tasty
import Test.Tasty.HUnit
import Prelude as P
-- | Test group for issue 185.
--
-- Runs each of the programs 'acc1' .. 'acc6' with the supplied runner and
-- compares the result, at element type 'Int', with the matching reference
-- vector 'ref1' .. 'ref6'.
test_issue185 :: RunN -> TestTree
test_issue185 runN =
  testGroup "185"
    [ testCase "A" (check ref1 acc1)
    , testCase "B" (check ref2 acc2)
    , testCase "C" (check ref3 acc3)
    , testCase "D" (check ref4 acc4)
    , testCase "E" (check ref5 acc5)
    , testCase "F" (check ref6 acc6)
    ]
  where
    -- The signature pins both the reference and the program to 'Int',
    -- resolving the polymorphic element type of the definitions below.
    check :: Vector Int -> Acc (Vector Int) -> Assertion
    check expected program = expected @=? runN program
-- | Expected result of 'acc1' (test case "A"): a nine-element vector.
ref1 :: (Elt a, P.Num a) => Vector a
ref1 = fromList extent expected
  where
    extent   = Z :. 9
    expected = [0, 1, 4, 9, 0, 4, 0, 6, 2]
-- | Unconditional scatter where the value vector (seven elements) is longer
-- than the index vector (six elements). The defaults are nine zeros.
acc1 :: (Elt a, P.Num a) => Acc (Vector a)
acc1 = A.scatter dstIx initial payload
  where
    -- Destination index for each scattered element.
    dstIx   = use [1, 3, 7, 2, 5, 8]
    -- Default contents of the result.
    initial = use [0, 0, 0, 0, 0, 0, 0, 0, 0]
    -- Values to scatter; one more than there are destination indices.
    payload = use [1, 9, 6, 4, 4, 2, 5]
-- | Expected result of 'acc2' (test case "B"): a nine-element vector.
ref2 :: (Elt a, P.Num a) => Vector a
ref2 = fromList extent expected
  where
    extent   = Z :. 9
    expected = [0, 1, 0, 9, 0, 0, 0, 6, 0]
-- | Unconditional scatter where the value vector (three elements) is shorter
-- than the index vector (six elements). The defaults are nine zeros.
acc2 :: (Elt a, P.Num a) => Acc (Vector a)
acc2 = A.scatter dstIx initial payload
  where
    -- Destination index for each scattered element.
    dstIx   = use [1, 3, 7, 2, 5, 8]
    -- Default contents of the result.
    initial = use [0, 0, 0, 0, 0, 0, 0, 0, 0]
    -- Values to scatter; fewer than there are destination indices.
    payload = use [1, 9, 6]
-- | Expected result of 'acc3' (test case "C"): a nine-element vector.
ref3 :: (Elt a, P.Num a) => Vector a
ref3 = fromList extent expected
  where
    extent   = Z :. 9
    expected = [0, 0, 0, 0, 0, 4, 0, 6, 2]
-- | Conditional scatter (via 'scatterIf') where the value vector (seven
-- elements) is longer than the index and mask vectors (six elements each).
-- An element is written only where its mask entry is greater than 4.
acc3 :: (Elt a, P.Num a) => Acc (Vector a)
acc3 = scatterIf dstIx flags selected initial payload
  where
    -- Destination index for each scattered element.
    dstIx   = use [1, 3, 7, 2, 5, 8]

    -- Mask values tested by 'selected', one per destination index.
    flags :: Acc (Vector Int32)
    flags   = use [3, 4, 9, 2, 7, 5]

    -- Predicate on a mask value: keep entries strictly greater than 4.
    selected m = m A.> 4

    -- Default contents of the result.
    initial = use [0, 0, 0, 0, 0, 0, 0, 0, 0]

    -- Values to scatter; one more than there are destination indices.
    payload = use [1, 9, 6, 4, 4, 2, 5]
-- | Expected result of 'acc4' (test case "D"): a nine-element vector.
ref4 :: (Elt a, P.Num a) => Vector a
ref4 = fromList extent expected
  where
    extent   = Z :. 9
    expected = [0, 0, 0, 0, 0, 0, 0, 6, 0]
-- | Conditional scatter (via 'scatterIf') where the value vector (three
-- elements) is shorter than the index and mask vectors (six elements each).
-- An element is written only where its mask entry is greater than 4.
acc4 :: (Elt a, P.Num a) => Acc (Vector a)
acc4 = scatterIf dstIx flags selected initial payload
  where
    -- Destination index for each scattered element.
    dstIx   = use [1, 3, 7, 2, 5, 8]

    -- Mask values tested by 'selected', one per destination index.
    flags :: Acc (Vector Int32)
    flags   = use [3, 4, 9, 2, 7, 5]

    -- Predicate on a mask value: keep entries strictly greater than 4.
    selected m = m A.> 4

    -- Default contents of the result.
    initial = use [0, 0, 0, 0, 0, 0, 0, 0, 0]

    -- Values to scatter; fewer than there are destination indices.
    payload = use [1, 9, 6]
-- | Expected result of 'acc5' (test case "E"): a six-element vector.
ref5 :: (Elt a, P.Num a) => Vector a
ref5 = fromList extent expected
  where
    extent   = Z :. 6
    expected = [9, 4, 1, 6, 2, 4]
-- | Unconditional gather: reads six elements out of a nine-element source
-- vector at the given source indices.
acc5 :: (Elt a, P.Num a) => Acc (Vector a)
acc5 = A.gather srcIx source
  where
    -- Source index to read for each output position.
    srcIx  = use [1, 3, 7, 2, 5, 3]
    -- Vector being read from.
    source = use [1, 9, 6, 4, 4, 2, 0, 1, 2]
-- | Expected result of 'acc6' (test case "F"): a six-element vector.
ref6 :: (Elt a, P.Num a) => Vector a
ref6 = fromList extent expected
  where
    extent   = Z :. 6
    expected = [6, 6, 1, 6, 2, 4]
-- | Conditional gather (via 'gatherIf'): an output position takes the
-- gathered element where its mask entry is greater than 4, and the default
-- value 6 otherwise.
acc6 :: (Elt a, P.Num a) => Acc (Vector a)
acc6 = gatherIf srcIx flags selected fallback source
  where
    -- Source index to read for each output position.
    srcIx    = use [1, 3, 7, 2, 5, 3]

    -- Mask values tested by 'selected', one per output position.
    flags :: Acc (Vector Int32)
    flags    = use [3, 4, 9, 2, 7, 5]

    -- Predicate on a mask value: keep entries strictly greater than 4.
    selected m = m A.> 4

    -- Value used at positions rejected by the predicate.
    fallback = use [6, 6, 6, 6, 6, 6]

    -- Vector being read from.
    source   = use [1, 9, 6, 4, 4, 2, 0, 1, 2]
-- | Conditional gather.
--
-- For each position, the element of @input@ selected by @from@ is used when
-- the predicate @p@ holds for the corresponding element of @maskV@;
-- otherwise the corresponding element of @def@ is used. The three vectors
-- are combined element-wise, so the result covers only the positions common
-- to the mask, the gathered vector and the defaults.
gatherIf
  :: (Elt a, Elt b)
  => Acc (Vector Int)       -- ^ source indices
  -> Acc (Vector a)         -- ^ mask
  -> (Exp a -> Exp Bool)    -- ^ predicate applied to each mask element
  -> Acc (Vector b)         -- ^ default values
  -> Acc (Vector b)         -- ^ vector to gather from
  -> Acc (Vector b)
gatherIf from maskV p def input =
  A.zipWith3 choose (A.map p maskV) (gather from input) def
  where
    -- Keep the gathered element when the predicate held, else the default.
    choose keep picked fallback = keep ? (picked, fallback)
-- | Conditional scatter.
--
-- Starting from the defaults @def@, element @ix@ of @input@ is written to
-- position @to ! ix@ of the result when the predicate @p@ holds for
-- @maskV ! ix@; otherwise that element is dropped. Where several elements
-- target the same position they are combined with 'const'.
--
-- Only positions present in all of @to@, @maskV@ and @input@ are considered,
-- so a length mismatch between any of the three truncates the operation
-- rather than indexing past the end of the shorter vector.
scatterIf
  :: (Elt e, Elt e')
  => Acc (Vector Int)       -- ^ destination indices
  -> Acc (Vector e)         -- ^ mask
  -> (Exp e -> Exp Bool)    -- ^ predicate applied to each mask element
  -> Acc (Vector e')        -- ^ default values
  -> Acc (Vector e')        -- ^ values to scatter
  -> Acc (Vector e')
scatterIf to maskV p def input = permute const def pf input'
  where
    -- Destination of source element @ix@, or 'ignore' when the predicate
    -- rejects its mask entry.
    pf ix  = p (maskV ! ix) ? ( index1 (to ! ix), ignore )

    -- Restrict the source to the common index range, so that 'pf' never
    -- reads @to@ or @maskV@ out of bounds.
    input' = backpermute bounds id input

    -- The mask is included in the intersection because 'pf' indexes it too.
    bounds = shape to `intersect` shape maskV `intersect` shape input