module Data.Array.Repa.Algorithms.Matrix
(multiplyMM)
where
import Data.Array.Repa as A
-- | Dense matrix-matrix product.
--
--   Given @arr@ with extent @Z :. m :. k@ and @brr@ with extent
--   @Z :. k :. n@, returns the @Z :. m :. n@ array whose element @(i, j)@
--   is @sum over l of arr(i, l) * brr(l, j)@.
--
--   Both arguments must already be manifest: the equation below only
--   matches the 'GenManifest' representation, so any other representation
--   is a pattern-match failure.
--
--   NOTE(review): the inner dimensions are not checked against each other.
--   'A.zipWith' works over the intersection of its arguments' extents, so a
--   mismatch in @k@ presumably uses only the overlapping prefix rather than
--   failing. TODO: confirm and consider rejecting such inputs.
multiplyMM
        :: Array DIM2 Double
        -> Array DIM2 Double
        -> Array DIM2 Double
multiplyMM arr@(Array _ [Region RangeAll (GenManifest _)])
           brr@(Array _ [Region RangeAll (GenManifest _)])
  = [arr, brr] `deepSeqArrays`
    A.force $ A.sum (A.zipWith (*) arrRepl brrRepl)
  where
    -- Transpose B once up front so both replicated operands can be indexed
    -- as (i, j, l). The dot product then runs along the innermost axis,
    -- which is the axis 'A.sum' reduces.
    trr@(Array _ [Region RangeAll (GenManifest _)])
            = force $ transpose2D brr

    -- arrRepl ! (i, j, l) == arr ! (i, l): every row of A repeated once per
    -- column j of B.
    arrRepl = trr `deepSeqArray` A.extend (Z :. All :. colsB :. All) arr

    -- brrRepl ! (i, j, l) == trr ! (j, l) == brr ! (l, j): every column of B
    -- repeated once per row i of A.
    brrRepl = trr `deepSeqArray` A.extend (Z :. rowsA :. All :. All) trr

    -- Extents are Z :. rows :. columns. rowsA must be the OUTER extent of A
    -- and colsB the INNER extent of B. Taking the other component only
    -- coincides with these for square matrices; otherwise the intersection
    -- taken by 'A.zipWith' silently truncates the result.
    (Z :. rowsA :. _    ) = extent arr
    (Z :. _     :. colsB) = extent brr
-- | Transpose a two-dimensional array.
--
--   The result has the input's extent with its two components exchanged.
--   Its element at index @(c, r)@ is read from index @(r, c)@ of the input
--   via 'backpermute'. Callers that need a manifest result force it
--   themselves.
transpose2D :: Elt e => Array DIM2 e -> Array DIM2 e
transpose2D src
  = backpermute (flipIx (extent src)) flipIx src
  where
    -- Exchange the outer and inner components of a rank-2 index. The same
    -- function maps the input extent to the output extent and each output
    -- index back to its source index.
    flipIx :: DIM2 -> DIM2
    flipIx (Z :. r :. c) = Z :. c :. r