-- | Shape-polymorphic arrays. An 'Array' is either manifest (elements held
--   in a flat unlifted array, addressed through the shape's 'toIndex') or
--   delayed (elements computed on demand from an index function).
--   Most operations here build delayed arrays; 'force' makes them manifest.
module Data.Array.Repa
( module Data.Array.Repa.Shape
, module Data.Array.Repa.Index
, module Data.Array.Repa.Slice
-- * Array type and construction
, Array (..)
, fromUArray
, fromFunction
, unit
-- * Projections and conversion
, extent
, delay
, toUArray
, index, (!:)
, toScalar
-- * Evaluation
, force
, isManifest
, deepSeqArray
-- * Lists
, fromList
, toList
-- * Index space transformations
, reshape
, append, (+:+)
, transpose
, replicate
, slice
, backpermute
, backpermuteDft
-- * Structure-preserving operations
, map
, zipWith
-- * Reductions
, fold
, sum
, sumAll
-- * Generic traversal
, traverse
, traverse2
-- * Testing
, arbitrarySmallArray
, props_DataArrayRepa)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Shape
import Data.Array.Repa.QuickCheck
import qualified Data.Array.Repa.Shape as S
import "dph-prim-par" Data.Array.Parallel.Unlifted (Elt)
import qualified "dph-prim-par" Data.Array.Parallel.Unlifted as U
import qualified "dph-prim-seq" Data.Array.Parallel.Unlifted.Sequential as USeq
import Test.QuickCheck
import Prelude hiding (sum, map, zipWith, replicate)
import qualified Prelude as P
-- | Module name used as a prefix for error messages raised by this module.
stage :: String
stage = "Data.Array.Repa"
-- | A shape-polymorphic array with elements of type @a@ and index/shape
--   type @sh@. Both constructors carry the array's extent.
data Array sh a
  = Manifest sh (U.Array a)
    -- ^ Elements stored in a flat unlifted array; a multi-dimensional
    --   index is mapped to a position with the shape's 'S.toIndex'.
  | Delayed sh (sh -> a)
    -- ^ Elements produced on demand by a function from index to value.
-- | Wrap a flat unlifted array as a manifest array with the given extent.
--   The extent is deep-seq'd and the unlifted array forced to WHNF before
--   the 'Manifest' constructor is returned.
--
--   NOTE(review): the array length is not checked against @S.size sh@.
fromUArray :: Shape sh => sh -> U.Array a -> Array sh a
fromUArray sh uarr = S.deepSeq sh (seq uarr (Manifest sh uarr))
-- | Build a delayed array from an extent and a function that yields the
--   element at each index. Only the extent is forced (deeply).
fromFunction :: Shape sh => sh -> (sh -> a) -> Array sh a
fromFunction sh fnElems = S.deepSeq sh (Delayed sh fnElems)
-- | A rank-zero (scalar) array holding a single value.
unit :: Elt a => a -> Array Z a
unit x = Delayed Z (\_ -> x)
-- | The extent (shape) of an array, regardless of its representation.
extent :: Array sh a -> sh
extent (Manifest sh _) = sh
extent (Delayed  sh _) = sh
-- | View an array in delayed form: its extent paired with an element
--   lookup function. Manifest arrays are read through 'S.toIndex'.
delay :: (Shape sh, Elt a) => Array sh a -> (sh, sh -> a)
delay (Delayed sh fn) = (sh, fn)
delay (Manifest sh uarr) = (sh, lookupElem)
  where
    lookupElem ix = uarr U.!: S.toIndex sh ix
-- | Force an array and return its flat unlifted element array.
--
--   The fallback branch should be unreachable because 'force' always
--   returns 'Manifest'. It raises an error naming this function; it
--   previously said ".toList", which pointed at the wrong function.
toUArray
  :: (Shape sh, Elt a)
  => Array sh a
  -> U.Array a
toUArray arr
  = case force arr of
      Manifest _ uarr -> uarr
      _ -> error $ stage ++ ".toUArray: force failed"
-- | Retrieve the element at the given index. Delayed arrays apply their
--   element function; manifest arrays read the flat array at the position
--   computed by 'S.toIndex'. No bounds check is performed here.
index, (!:)
  :: forall sh a
   . (Shape sh, Elt a)
  => Array sh a
  -> sh
  -> a
index (Delayed _ fn) ix = fn ix
index (Manifest sh uarr) ix = uarr U.!: S.toIndex sh ix

-- | Operator synonym for 'index'.
arr !: ix = index arr ix
-- | Extract the single element of a rank-zero array.
toScalar :: Elt a => Array Z a -> a
toScalar (Delayed _ fn) = fn Z
toScalar (Manifest _ uarr) = uarr U.!: 0
-- | Convert an array to manifest form.
--
--   A delayed array is materialised by mapping its element function over
--   every linear position @0 .. size-1@, converted back to a shape index
--   with 'S.fromIndex'. In both cases the extent is deep-seq'd and the
--   unlifted array forced to WHNF before returning.
--
--   Fix: the upper bound of the enumeration is @S.size sh - 1@; the
--   previous text @S.size sh 1@ was missing the subtraction and does not
--   typecheck.
force :: (Shape sh, Elt a)
  => Array sh a -> Array sh a
force arr
  = case arr of
      Manifest sh uarr ->
        sh `S.deepSeq` uarr `seq` Manifest sh uarr
      Delayed sh fn ->
        let uarr = U.map (fn . S.fromIndex sh)
                     $! U.enumFromTo (0 :: Int) (S.size sh - 1)
        in  sh `S.deepSeq` uarr `seq` Manifest sh uarr
-- | Return the array unchanged if it is manifest; raise an error otherwise.
--   Useful as an assertion inside pipelines.
isManifest :: Array sh a -> Array sh a
isManifest arr@Manifest{} = arr
isManifest _ = error "not manifest"
infixr 0 `deepSeqArray`

-- | Force an array's structure before returning the second argument.
--   The extent is always deep-seq'd; for a manifest array the unlifted
--   element array is additionally forced to WHNF. The element function of
--   a delayed array is not evaluated.
deepSeqArray
  :: Shape sh
  => Array sh a
  -> b -> b
deepSeqArray (Delayed sh _) x = S.deepSeq sh x
deepSeqArray (Manifest sh uarr) x = S.deepSeq sh (uarr `seq` x)
-- | Build a manifest array of the given extent from a list of elements.
--   Raises an error if the list length differs from the extent's size.
fromList
  :: (Shape sh, Elt a)
  => sh
  -> [a]
  -> Array sh a
fromList sh xx
  | lenList /= lenShape
  = error $ unlines
      [ stage ++ ".fromList: size of array shape does not match size of list"
      , " size of shape = " ++ show lenShape ++ "\n"
      , " size of list = " ++ show lenList ++ "\n" ]
  | otherwise
  = Manifest sh uarr
  where
    uarr     = U.fromList xx
    lenList  = U.length uarr
    lenShape = S.size sh
-- | Force an array and return its elements as a list, in the order of the
--   underlying flat array.
toList :: (Shape sh, Elt a)
  => Array sh a
  -> [a]
toList arr = elemsOf (force arr)
  where
    elemsOf (Manifest _ uarr) = U.toList uarr
    elemsOf _ = error $ stage ++ ".toList: force failed"
-- | Arrays are shown as the list of their elements; the extent is not shown.
instance (Shape sh, Elt a, Show a) => Show (Array sh a) where
  show = show . toList
-- | Element-wise equality.
--
--   Fix: the previous version zipped the two arrays (which uses the
--   intersected extent) and then reshaped by the size of @arr1@ alone.
--   As a result:
--
--   * if @arr2@ was larger, only the overlap was compared, giving false
--     positives;
--   * if @arr2@ was smaller, 'reshape' raised a size-mismatch error.
--
--   Both arrays must now have the same size, and so must their zipped
--   (intersected) extent, before the elements are compared.
--
--   NOTE(review): extents are compared by size only, because no 'Eq' on
--   shapes is visible here. Two empty arrays with different extents
--   therefore compare equal.
instance (Shape sh, Elt a, Eq a) => Eq (Array sh a) where
  arr1 == arr2
    =  n1 == n2
    && S.size (extent pointwise) == n1
    && toScalar (fold (&&) True (reshape pointwise (Z :. n1)))
    where
      n1        = S.size (extent arr1)
      n2        = S.size (extent arr2)
      pointwise = zipWith (==) arr1 arr2

  a1 /= a2 = not (a1 == a2)
-- | Element-wise arithmetic. Binary operators zip over the intersection of
--   the operands' extents; unary operators map over every element.
--
--   Fix: the subtraction method appeared as @() = zipWith ()@, which is
--   not valid Haskell; it is restored to @(-) = zipWith (-)@.
--
--   'fromInteger' yields a delayed array whose extent is an error thunk.
--   Its elements can be read, but asking for its extent fails.
instance (Shape sh, Elt a, Num a) => Num (Array sh a) where
  (+)    = zipWith (+)
  (-)    = zipWith (-)
  (*)    = zipWith (*)
  negate = map negate
  abs    = map abs
  signum = map signum
  fromInteger n = Delayed failShape (\_ -> fromInteger n)
    where
      failShape = error $ stage ++ ".fromInteger: Constructed array has no shape."
-- | Reinterpret an array with a new extent of the same total size. Each new
--   index is mapped to a linear position ('S.toIndex') and then back to an
--   index of the original extent ('S.fromIndex'). The result is delayed.
reshape :: (Shape sh, Shape sh', Elt a)
  => Array sh a
  -> sh'
  -> Array sh' a
reshape arr newExtent
  | S.size newExtent /= S.size oldExtent
  = error $ stage ++ ".reshape: reshaped array will not match size of the original"
  | otherwise
  = Delayed newExtent (\ix -> arr !: S.fromIndex oldExtent (S.toIndex newExtent ix))
  where
    oldExtent = extent arr
-- | Concatenate two arrays along their innermost dimension. Indices below
--   the first array's innermost size @n@ read from the first array; the
--   rest read from the second, shifted down by @n@.
--
--   Fix: the shift was written @(i n)@, which applies an 'Int' as a
--   function; it is @(i - n)@.
--
--   NOTE(review): the outer dimensions of the result come from the first
--   array, and their agreement with the second array is not checked.
append, (+:+)
  :: (Shape sh, Elt a)
  => Array (sh :. Int) a
  -> Array (sh :. Int) a
  -> Array (sh :. Int) a
append arr1 arr2
  = traverse2 arr1 arr2 fnExtent fnElem
  where
    (_ :. n) = extent arr1

    fnExtent (sh :. i) (_ :. j)
      = sh :. (i + j)

    fnElem f1 f2 (sh :. i)
      | i < n     = f1 (sh :. i)
      | otherwise = f2 (sh :. (i - n))

-- | Operator synonym for 'append'.
(+:+) arr1 arr2 = append arr1 arr2
-- | Swap the two innermost dimensions of an array (delayed).
transpose
  :: (Shape sh, Elt a)
  => Array (sh :. Int :. Int) a
  -> Array (sh :. Int :. Int) a
transpose arr = traverse arr swapExtent swapElem
  where
    swapExtent (sh :. m :. n)     = sh :. n :. m
    swapElem get (sh :. i :. j)   = get (sh :. j :. i)
-- | Expand an array to the full shape described by a slice specification.
--   Each full index is projected back to the source with 'sliceOfFull'.
replicate
  :: ( Slice sl
     , Shape (FullShape sl)
     , Shape (SliceShape sl)
     , Elt e)
  => sl
  -> Array (SliceShape sl) e
  -> Array (FullShape sl) e
replicate sl arr
  = backpermute fullExtent (sliceOfFull sl) arr
  where
    fullExtent = fullOfSlice sl (extent arr)
-- | Extract the sub-array selected by a slice specification. Each slice
--   index is mapped into the source with 'fullOfSlice'.
slice :: ( Slice sl
         , Shape (FullShape sl)
         , Shape (SliceShape sl)
         , Elt e)
      => Array (FullShape sl) e
      -> sl
      -> Array (SliceShape sl) e
slice arr sl
  = backpermute slicedExtent (fullOfSlice sl) arr
  where
    slicedExtent = sliceOfFull sl (extent arr)
-- | Build an array of extent @newExtent@ whose element at @ix@ is the
--   source element at @perm ix@.
backpermute
  :: forall sh sh' a
   . (Shape sh, Shape sh', Elt a)
  => sh'
  -> (sh' -> sh)
  -> Array sh a
  -> Array sh' a
backpermute newExtent perm arr
  = traverse arr (\_ -> newExtent) (\get ix -> get (perm ix))
-- | Like 'backpermute', but with a default array. The result has the
--   default array's extent. Where @fnIndex@ yields @Just ix'@, the element
--   is read from the source at @ix'@; otherwise it is read from the default
--   array at the same index.
backpermuteDft
  :: forall sh sh' a
   . (Shape sh, Shape sh', Elt a)
  => Array sh' a
  -> (sh' -> Maybe sh)
  -> Array sh a
  -> Array sh' a
backpermuteDft arrDft fnIndex arrSrc
  = Delayed (extent arrDft) pick
  where
    pick ix = maybe (arrDft !: ix) (arrSrc !:) (fnIndex ix)
-- | Apply a function to every element, producing a delayed array with the
--   same extent.
map :: (Shape sh, Elt a, Elt b)
    => (a -> b)
    -> Array sh a
    -> Array sh b
map f arr = Delayed (extent arr) (\ix -> f (arr !: ix))
-- | Combine two arrays element-wise. The result's extent is
--   @S.intersectDim@ of the two input extents. Both inputs are forced with
--   'deepSeqArray' before the delayed result is built.
zipWith :: (Shape sh, Elt a, Elt b, Elt c)
        => (a -> b -> c)
        -> Array sh a
        -> Array sh b
        -> Array sh c
zipWith f arr1 arr2
  = deepSeqArray arr1
  $ deepSeqArray arr2
  $ Delayed common (\ix -> f (arr1 !: ix) (arr2 !: ix))
  where
    common = S.intersectDim (extent arr1) (extent arr2)
-- | Reduce the innermost dimension with a binary operator and a starting
--   value. Each result element is a sequential left-to-right fold (via
--   @USeq.foldU@) over inner indices @0 .. n-1@. The seed is forced to
--   WHNF and the array structure with 'deepSeqArray'.
--
--   Fix: the inner upper bound was written @(n 1)@, which applies an 'Int'
--   as a function; it is @(n - 1)@.
fold :: (Shape sh, Elt a)
     => (a -> a -> a)
     -> a
     -> Array (sh :. Int) a
     -> Array sh a
fold f x arr
  = x `seq` arr `deepSeqArray`
    let sh' :. n = extent arr
        elemFn i = USeq.foldU f x
                 $ USeq.mapU (\ix -> arr !: (i :. ix))
                             (USeq.enumFromToU 0 (n - 1))
    in  Delayed sh' elemFn
-- | Sum along the innermost dimension; a specialisation of 'fold'.
sum :: (Shape sh, Elt a, Num a)
    => Array (sh :. Int) a
    -> Array sh a
sum = fold (+) 0
-- | Sum every element of an array. Each linear position
--   @0 .. size-1@ is converted to an index with 'S.fromIndex' and read.
--   An empty array sums to 0.
--
--   Fix: the enumeration's upper bound was missing its subtraction
--   (@(S.size ...) 1@), which does not typecheck; it is @S.size sh - 1@.
sumAll :: (Shape sh, Elt a, Num a)
       => Array sh a
       -> a
sumAll arr
  = USeq.foldU (+) 0
  $ USeq.mapU ((arr !:) . S.fromIndex sh)
  $ USeq.enumFromToU 0 (S.size sh - 1)
  where
    sh = extent arr
-- | General single-array traversal. @transExtent@ computes the new extent
--   from the old one. @newElem@ receives an element lookup for the source
--   and yields the new element function. The source is forced with
--   'deepSeqArray'.
traverse
  :: forall sh sh' a b
   . (Shape sh, Shape sh', Elt a)
  => Array sh a
  -> (sh -> sh')
  -> ((sh -> a) -> sh' -> b)
  -> Array sh' b
traverse arr transExtent newElem
  = deepSeqArray arr (Delayed resultExtent resultElem)
  where
    resultExtent = transExtent (extent arr)
    resultElem   = newElem (arr !:)
-- | Two-array variant of 'traverse'. The new extent is computed from both
--   source extents, and @newElem@ receives lookups for both sources. Both
--   sources are forced with 'deepSeqArray'.
traverse2
  :: forall sh sh' sh'' a b c
   . ( Shape sh, Shape sh', Shape sh''
     , Elt a, Elt b, Elt c)
  => Array sh a
  -> Array sh' b
  -> (sh -> sh' -> sh'')
  -> ((sh -> a) -> (sh' -> b)
      -> (sh'' -> c))
  -> Array sh'' c
traverse2 arrA arrB transExtent newElem
  = deepSeqArray arrA
  $ deepSeqArray arrB
  $ Delayed resultExtent (newElem (arrA !:) (arrB !:))
  where
    resultExtent = transExtent (extent arrA) (extent arrB)
-- | QuickCheck generator for a manifest array. It picks a shape with
--   'arbitrarySmallShape', then exactly that many arbitrary elements.
arbitrarySmallArray
  :: (Shape sh, Elt a, Arbitrary sh, Arbitrary a)
  => Int
  -> Gen (Array (sh :. Int) a)
arbitrarySmallArray maxDim
  = arbitrarySmallShape maxDim >>= \sh ->
    fmap (fromList sh) (arbitraryListOfLength (S.size sh))
-- | All QuickCheck properties for this module. The index properties are
--   included unchanged; local ones are prefixed with the module name.
props_DataArrayRepa :: [(String, Property)]
props_DataArrayRepa
  = props_DataArrayRepaIndex ++ P.map qualify localProps
  where
    qualify (name, test) = (stage ++ "." ++ name, test)
    localProps =
      [ ("id_force/DIM5",              property prop_id_force_DIM5)
      , ("id_toScalarUnit",            property prop_id_toScalarUnit)
      , ("id_toListFromList/DIM3",     property prop_id_toListFromList_DIM3)
      , ("id_transpose/DIM4",          property prop_id_transpose_DIM4)
      , ("reshapeTransposeSize/DIM3",  property prop_reshapeTranspose_DIM3)
      , ("appendIsAppend/DIM3",        property prop_appendIsAppend_DIM3)
      , ("sumAllIsSum/DIM3",           property prop_sumAllIsSum_DIM3) ]
-- | Forcing an array does not change its contents.
prop_id_force_DIM5 :: Property
prop_id_force_DIM5
  = forAll (arbitrarySmallArray 10) $ \(arr :: Array DIM5 Int) ->
    arr == force arr
-- | 'toScalar' recovers the value wrapped by 'unit'.
prop_id_toScalarUnit :: Int -> Bool
prop_id_toScalarUnit x
  = toScalar (unit x) == x
-- | 'toList' after 'fromList' returns the original list.
prop_id_toListFromList_DIM3 :: Property
prop_id_toListFromList_DIM3
  = forAll (arbitrarySmallShape 10) $ \(sh :: DIM3) ->
    forAll (arbitraryListOfLength (S.size sh)) $ \(xx :: [Int]) ->
    toList (fromList sh xx) == xx
-- | Transposing twice is the identity.
--
--   NOTE(review): the name says DIM4 but the property is checked at DIM3;
--   confirm which was intended.
prop_id_transpose_DIM4 :: Property
prop_id_transpose_DIM4
  = forAll (arbitrarySmallArray 20) $ \(arr :: Array DIM3 Int) ->
    transpose (transpose arr) == arr
-- | Reshaping an array to the transposed extent preserves its size, and
--   transposition preserves the sum of its elements.
prop_reshapeTranspose_DIM3 :: Property
prop_reshapeTranspose_DIM3
  = forAll (arbitrarySmallArray 20) $ \(arr :: Array DIM3 Int) ->
    let arr' = transpose arr
        sh'  = extent arr'
    in  (S.size $ extent arr) == S.size (extent (reshape arr sh'))
     && (sumAll arr == sumAll arr')
-- | Appending an array to itself doubles its element sum.
prop_appendIsAppend_DIM3 :: Property
prop_appendIsAppend_DIM3
  = forAll (arbitrarySmallArray 20) $ \(arr1 :: Array DIM3 Int) ->
    sumAll (append arr1 arr1) == (2 * sumAll arr1)
-- | 'sumAll' of an array built from a list equals the list's sum.
--
--   NOTE(review): the name says DIM3 but the shape used is DIM2; confirm
--   which was intended.
prop_sumAllIsSum_DIM3 :: Property
prop_sumAllIsSum_DIM3
  = forAll (arbitrarySmallShape 100) $ \(sh :: DIM2) ->
    forAll (arbitraryListOfLength (S.size sh)) $ \(xx :: [Int]) ->
    sumAll (fromList sh xx) == P.sum xx