{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Test.NoFib.Sharing (
test_sharing
) where
import Data.Array.Accelerate as A hiding ( exp )
import Data.Array.Accelerate.Trafo.Sharing
import Data.Array.Accelerate.Data.Bits as A
import Control.DeepSeq
import Control.Exception
import Test.Tasty
import Test.Tasty.ExpectedFailure
import Test.Tasty.HUnit
import Prelude as P hiding ( exp )
-- | Sharing-recovery regression tests.
--
-- Every case converts an Accelerate term with 'convertAcc' or 'convertExp'
-- and forces the converted term to normal form; the terms are never
-- executed. A case fails when conversion (or forcing its result) throws a
-- synchronous exception.
test_sharing :: TestTree
test_sharing =
  testGroup "sharing"
    [ testCase "simple"   $ sharingAcc test_simple
    , testCase "ordering" $ sharingAcc test_ordering
    , testCase "sort"     $ sharingAcc test_sort
    , testCase "blowup"   $ sharingAcc (test_blowup 20)
    , testCase "bfs"      $ sharingAcc test_bfs
    , testGroup "same level"
        [ testCase "1" $ sharingAcc test_two_lets_same_level_1
        , testCase "2" $ sharingAcc test_two_lets_same_level_2
        ]
    , testGroup "empty top"
        [ testCase "1" $ sharingAcc test_no_let_at_top_1
        , testCase "2" $ sharingAcc test_no_let_at_top_2
        ]
    , testCase "pipe"            $ sharingAcc test_pipe
    , testCase "bound variables" $ sharingAcc test_bound_variables
    , testCase "big tuple"       $ sharingExp test_big_tuple
    , testGroup "iteration"
        [ testCase "simple"             $ sharingAcc test_iteration_simple
        , testCase "outside"            $ sharingAcc test_iteration_outside
        , testCase "body and condition" $ sharingAcc test_iteration_body_condition
        , testCase "awhile"             $ sharingAcc test_awhile
        , testCase "iterate"            $ sharingAcc test_iterate
        , testCase "nested"             $ sharingExp test_nested_iteration
        , testCase "unused"             $ sharingExp test_unused_iteration
        ]
    , testGroup "nested data-parallelism"
        -- marked 'expectFail': this case is currently expected to fail
        [ expectFail $ testCase "mvm" $ sharingAcc test_nested_data_praallelism
        ]
    ]
  where
    -- Convert an array-level term with all four conversion flags enabled
    -- (NOTE(review): confirm the meaning of each flag in Trafo.Sharing).
    sharingAcc :: Arrays a => Acc a -> Assertion
    sharingAcc acc = assertConverts (convertAcc True True True True acc)

    -- Convert a scalar-level term.
    sharingExp :: Elt e => Exp e -> Assertion
    sharingExp exp = assertConverts (convertExp True exp)

    -- Force a converted term to normal form inside IO. Any synchronous
    -- exception raised while doing so becomes an HUnit failure carrying the
    -- exception text. Asynchronous exceptions (e.g. a tasty timeout or a user
    -- interrupt) are re-thrown unchanged so they are not misreported as an
    -- ordinary test failure.
    assertConverts :: NFData t => t -> Assertion
    assertConverts term = evaluate (rnf term) `catch` handler
      where
        handler :: SomeException -> Assertion
        handler e =
          case fromException e of
            Just (async :: SomeAsyncException) -> throwIO async
            Nothing                            -> assertFailure (show e)
-- | Embed a one-element vector holding the given value with 'use'.
mkArray :: Int -> Acc (Array DIM1 Int)
mkArray = use . fromList (Z :. 1) . (: [])
-- | A term whose unshared unfolding grows exponentially in @n@.
--
-- Level @n@ reads the level @n-1@ array twice through the single binding
-- 'newArr', so without sharing recovery the term tree would contain 2^n
-- copies of the base case. Index 1 lies outside the one-element arrays built
-- by 'mkArray', but 'test_sharing' only converts this term and never runs it.
-- @n@ must be non-negative: a negative argument never reaches the @0@ case.
test_blowup :: Int -> Acc (Array DIM1 Int)
test_blowup 0 = (mkArray 0)
test_blowup n = A.map (\_ -> newArr ! (lift (Z:.(0::Int))) +
                             newArr ! (lift (Z:.(1::Int)))) (mkArray n)
  where
    -- bound once, referenced twice in the lambda above
    newArr = test_blowup (n-1)
-- | Lift a Haskell 'Int' into a one-dimensional index expression.
idx :: Int -> Exp DIM1
idx = lift . (Z :.)
-- | Sharing across nesting levels.
--
-- 'map2' is referenced twice: directly in the outer lambda and inside the
-- lambda of 'map1'. All three maps also share 'arr'. The indices used exceed
-- the one-element input; the term is only converted, never executed.
test_bfs :: Acc (Array DIM1 Int)
test_bfs = A.map (\x -> (map2 ! (idx 1)) + (map1 ! (idx 2)) + x) arr
  where
    map1 :: Acc (Array DIM1 Int)
    map1 = A.map (\y -> (map2 ! (idx 3)) + y) arr
    map2 :: Acc (Array DIM1 Int)
    map2 = A.map (\z -> z + 1) arr
    arr :: Acc (Array DIM1 Int)
    arr = mkArray 666
-- | Two let-bound arrays, each read twice from the same lambda.
--
-- The bindings are introduced in the order @arr1@ then @arr2@;
-- 'test_two_lets_same_level_2' is the same term with the order swapped.
test_two_lets_same_level_1 :: Acc (Array DIM1 Int)
test_two_lets_same_level_1 =
  let arr1 = mkArray 1
  in let arr2 = mkArray 2
     in A.map (\_ -> arr1!(idx 1) + arr1!(idx 2) + arr2!(idx 3) + arr2!(idx 4)) (mkArray 3)
-- | Same term as 'test_two_lets_same_level_1', but with the two let-bindings
-- introduced in the opposite order (@arr2@ before @arr1@).
test_two_lets_same_level_2 :: Acc (Array DIM1 Int)
test_two_lets_same_level_2 =
  let arr2 = mkArray 2
  in let arr1 = mkArray 1
     in A.map (\_ -> arr1!(idx 1) + arr1!(idx 2) + arr2!(idx 3) + arr2!(idx 4)) (mkArray 3)
-- | Wrap 'test_bfs' in one extra map, so the outermost node is a plain map
-- rather than a binding.
test_no_let_at_top_1 :: Acc (Array DIM1 Int)
test_no_let_at_top_1 = A.map (+ 1) test_bfs
-- | Wrap 'test_bfs' in two successive maps (add 1, then add 2).
test_no_let_at_top_2 :: Acc (Array DIM1 Int)
test_no_let_at_top_2 = A.map (+ 2) incremented
  where
    incremented = A.map (+ 1) test_bfs
-- | Basic sharing: the source vector and its incremented copy are each used
-- more than once (zipped together, and the increment is also mapped over).
test_simple :: Acc (Array DIM1 (Int,Int))
test_simple = A.map (\_ -> pairs ! idx 1) incremented
  where
    source :: Acc (Vector Int)
    source = use (A.fromList (Z :. 3) [1 ..])
    incremented :: Acc (Vector Int)
    incremented = A.map (+ 1) source
    pairs :: Acc (Vector (Int, Int))
    pairs = A.zip incremented source
-- | Stable radix sort of a vector by an 'Int' key, one key bit per pass.
--
-- Passes run over bits 0 to 31, least significant first, so only the low 32
-- bits of each key take part in the ordering (NOTE(review): keys are
-- presumably expected to be non-negative and below 2^32 — confirm).
--
-- Each pass is a stable split: elements whose current bit is 0 move to the
-- front in their original relative order, followed by the elements whose bit
-- is 1. Each pass reuses its intermediate arrays ('bits', 'falses', ...)
-- several times.
sortKey :: (Elt e)
        => (Exp e -> Exp Int)
        -> Acc (Vector e)
        -> Acc (Vector e)
sortKey keyFun arr = foldl sortOneBit arr (P.map lift ([0..31] :: [Int]))
  where
    -- one split pass on the key bit selected by bitNum
    sortOneBit inArr bitNum = outArr
      where
        keys = A.map keyFun inArr
        -- 1 where the current key bit is set, 0 where it is clear
        bits = A.map (\a -> (A.testBit a bitNum) ? (1, 0)) keys
        -- complement of bits
        bitsInv = A.map (\b -> (b A.== 0) ? (1, 0)) bits
        -- falses: exclusive prefix sum of bitsInv, i.e. the target slot of
        -- each 0-bit element; numZeroes: total number of 0-bit elements
        (falses, numZeroes) = unlift (A.scanl' (+) 0 bitsInv)
        -- target slot of each 1-bit element: after all 0-bit elements,
        -- offset by the number of 1-bit elements before it (index - falses)
        trues = A.map (\x -> (A.the numZeroes) + (A.fst x) - (A.snd x))
              $ A.zip ixs falses
        -- choose the 0-bit or 1-bit target slot for every element
        dstIxs = A.map (\x -> let (b, t, f) = unlift x in (b A.== (constant (0::Int))) ? (f, t))
               $ A.zip3 bits trues falses
        -- NOTE(review): presumably scatter <targets> <defaults> <values> — confirm
        outArr = scatter dstIxs inArr inArr
    -- index vector 0 .. n-1 matching the input's shape
    ixs = enumeratedArray (shape arr)
    enumeratedArray :: Exp DIM1 -> Acc (Array DIM1 Int)
    enumeratedArray sh = A.generate sh unindex1
-- | Radix-sort (by identity key) a ten-element vector given in descending
-- order.
test_sort :: Acc (Vector Int)
test_sort = sortKey id descending
  where
    descending :: Acc (Vector Int)
    descending = use (fromList (Z :. 10) [9, 8 .. 0])
-- | Bindings with a dependency chain that must be emitted in order.
--
-- 'map3' depends on 'arr'; 'map2' depends on 'map3'; 'map1' depends on both
-- 'map3' and 'map2'; the result reads 'map1' and 'map2'.
test_ordering :: Acc (Array DIM1 Int)
test_ordering = A.map (\_ -> map1 ! (idx 1) + map2 ! (idx 1)) arr
  where
    map1 = A.map (\_ -> map3 ! (idx 1) + map2 ! (idx 2)) arr
    map2 = A.map (\_ -> map3 ! (idx 3)) arr
    map3 = A.map (+1) arr
    arr = mkArray 42
-- | Two array functions composed with the pipeline operator '>->'.
--
-- The first stage closes over the scalar 'z' bound outside the pipeline. The
-- second stage introduces its own binding 'arr2', which it references twice.
test_pipe :: Acc (Vector Int)
test_pipe = (acc1 >-> acc2) xs
  where
    z :: Acc (Scalar Int)
    z = unit 0
    xs :: Acc (Vector Int)
    xs = use $ fromList (Z:.10) [0..]
    -- replaces every element with the value held in z
    acc1 :: Acc (Vector Int) -> Acc (Vector Int)
    acc1 = A.map (\_ -> the z)
    -- arr2 appears both in the zipped input and inside the mapped function;
    -- the zipped element itself is ignored
    acc2 :: Acc (Vector Int) -> Acc (Vector Int)
    acc2 arr = let arr2 = use $ fromList (Z:.10) [10..]
               in A.map (\_ -> arr2!constant (Z:.(0::Int))) (A.zip arr arr2)
-- | Three 'stencil2' computations over the same pair of 10x10 arrays, all
-- with clamped boundaries, returned as one lifted triple.
--
-- Each 3x3 stencil function reads the centre element of only the Int stencil
-- ('first'), of both stencils ('both'), or of only the Float stencil
-- ('second'). The remaining stencil arguments are ignored.
test_bound_variables :: Acc (Array DIM2 Int, Array DIM2 Float, Array DIM2 Float)
test_bound_variables = lift (first, both, second)
  where
    is :: Array DIM2 Int
    is = fromList (Z:.10:.10) [0..]
    fs :: Array DIM2 Float
    fs = fromList (Z:.10:.10) [0..]
    first = stencil2 centre clamp (use fs) clamp (use is)
      where
        -- ignores the Float stencil entirely
        centre :: Stencil3x3 Float -> Stencil3x3 Int -> Exp Int
        centre _ (_,(_,y,_),_) = y
    both = stencil2 centre clamp (use fs) clamp (use is)
      where
        -- uses the centre of both stencils
        centre :: Stencil3x3 Float -> Stencil3x3 Int -> Exp Float
        centre (_,(_,x,_),_) (_,(_,y,_),_) = x + A.fromIntegral y
    second = stencil2 centre clamp (use fs) clamp (use is)
      where
        -- ignores the Int stencil entirely
        centre :: Stencil3x3 Float -> Stencil3x3 Int -> Exp Float
        centre (_,(_,x,_),_) _ = x
-- | Lift a pair made of an 8-tuple and a 9-tuple of zero constants into a
-- single scalar expression.
test_big_tuple :: Exp ((Int,Int,Int,Int,Int,Int,Int,Int), (Int,Int,Int,Int,Int,Int,Int,Int,Int))
test_big_tuple = lift (zeros8, zeros9)
  where
    zeros8 :: Exp (Int, Int, Int, Int, Int, Int, Int, Int)
    zeros8 = A.constant (0, 0, 0, 0, 0, 0, 0, 0)
    zeros9 :: Exp (Int, Int, Int, Int, Int, Int, Int, Int, Int)
    zeros9 = A.constant (0, 0, 0, 0, 0, 0, 0, 0, 0)
-- | Input vector shared by the iteration tests: the values 0 to 9.
v1 :: Acc (Vector Float)
v1 = use floats
  where
    floats = fromList (Z :. 10) [0 ..]
-- | Per element @x@ of 'v1', a scalar while-loop that counts up from 0 in
-- steps of 1 while the counter is below @x@.
test_iteration_simple :: Acc (Vector Float)
test_iteration_simple = A.map countUpTo v1
  where
    countUpTo :: Exp Float -> Exp Float
    countUpTo limit = A.while (A.< limit) (+ 1) 0
-- | A let-bound scalar used both outside a while-loop and inside its body.
test_iteration_outside :: Acc (Vector Float)
test_iteration_outside = A.map body v1
  where
    body :: Exp Float -> Exp Float
    body x =
      let twoPi = 2 * pi
      in  twoPi + A.while (A.< 10) (+ twoPi) x
-- | A while-loop whose condition and body both capture the enclosing
-- lambda's variable.
test_iteration_body_condition :: Acc (Vector Float)
test_iteration_body_condition =
  A.map (\bound -> A.while (\acc -> acc A.< bound) (\acc -> acc + bound) 0) v1
-- | Array-level loop: increment every element of 'v1' while the vector's sum
-- is below 200.
test_awhile :: Acc (Vector Float)
test_awhile = A.awhile keepGoing step v1
  where
    keepGoing :: Acc (Vector Float) -> Acc (Scalar Bool)
    keepGoing xs = A.unit (the (A.sum xs) A.< 200)
    step :: Acc (Vector Float) -> Acc (Vector Float)
    step = A.map (+ 1)
-- | A let-bound scalar used both outside a fixed-count 'A.iterate' loop and
-- inside its body.
test_iterate :: Acc (Vector Float)
test_iterate = A.map body v1
  where
    body :: Exp Float -> Exp Float
    body x =
      let doubled = 2 * x
      in  doubled + A.iterate (constant 10) (\acc -> doubled + acc + 10) x
-- | Scalar for-loop built on 'A.iterate'.
--
-- Runs @n@ iterations over a (counter, value) pair starting at @(0, seed)@.
-- Each step passes the current counter and value to @f@ and increments the
-- counter. Returns the final value.
test_for :: Elt a => Exp Int -> (Exp Int -> Exp a -> Exp a) -> Exp a -> Exp a
test_for n f seed = A.snd (A.iterate n step (lift (constant 0, seed)))
  where
    step state =
      let (i, x) = unlift state
      in  lift (i + 1, f i x)
-- | Two nested 'test_for' loops of 64 iterations each. The inner loop
-- captures the outer loop's counter, and the outer loop ignores its
-- accumulated value.
test_nested_iteration :: Exp Int
test_nested_iteration = test_for 64 outer 0
  where
    outer i _ = test_for 64 (\j acc -> i + j + acc) 0
-- | A while-loop whose body ignores the loop variable and always yields 10,
-- starting from 5 with the condition @== 10@.
test_unused_iteration :: Exp Int
test_unused_iteration = A.while (\v -> v A.== 10) (\_ -> 10) 5
-- | Matrix-vector multiply written with nested data parallelism.
--
-- The element function of 'generate' runs an array computation ('dotp' over
-- 'takeRow'). 'test_sharing' marks this case with 'expectFail'.
-- (The misspelled name is referenced from 'test_sharing', so it is kept.)
test_nested_data_praallelism :: Acc (Vector Float)
test_nested_data_praallelism =
  mvm (use $ fromList (Z:.10:.10) [0..]) (use $ fromList (Z:.10) [0..])
  where
    -- dot product: sum of the element-wise products
    dotp :: A.Num e => Acc (Vector e) -> Acc (Vector e) -> Acc (Scalar e)
    dotp xs ys = A.fold (+) 0 $ A.zipWith (*) xs ys
    -- row n of a matrix, extracted as a vector of length cols
    takeRow :: Elt e => Exp Int -> Acc (Array DIM2 e) -> Acc (Vector e)
    takeRow n mat =
      let Z :. _ :. cols = unlift (shape mat) :: Z:. Exp Int :. Exp Int
      in backpermute (index1 cols)
                     (\ix -> index2 n (unindex1 ix))
                     mat
    -- one dot product per matrix row
    mvm :: A.Num e => Acc (Array DIM2 e) -> Acc (Vector e) -> Acc (Vector e)
    mvm mat vec =
      let Z :. rows :. _ = unlift (shape mat) :: Z :. Exp Int :. Exp Int
      in generate (index1 rows)
                  (\ix -> the (vec `dotp` takeRow (unindex1 ix) mat))