module Data.Array.Repa.Algorithms.Matrix
(
row
, col
, mmultP, mmultS
, transpose2P, transpose2S)
where
import Data.Array.Repa as R
-- | Extract the first (outer) component of a two-dimensional index,
--   i.e. the row number of a matrix element.
row :: DIM2 -> Int
row sh =
  case sh of
    Z :. r :. _ -> r
-- | Extract the second (inner) component of a two-dimensional index,
--   i.e. the column number of a matrix element.
col :: DIM2 -> Int
col sh =
  case sh of
    Z :. _ :. c -> c
-- | Matrix-matrix multiply, computed in parallel.
--
--   Given an @h1 x w1@ matrix and a @w1 x w2@ matrix, produce their
--   @h1 x w2@ product.
--
--   The second operand is transposed first (with 'transpose2P'). Each
--   element @(i, j)@ of the result is then the sum of the element-wise
--   product of row @i@ of the first operand and row @j@ of the transposed
--   second operand.
--
--   Both operands are forced with 'deepSeqArray' before the parallel
--   computation starts.
--
--   The inner dimensions are assumed to agree; this is not checked.
mmultP  :: Array U DIM2 Double
        -> Array U DIM2 Double
        -> Array U DIM2 Double
mmultP arr brr
        = arr `deepSeqArray` (trr `deepSeqArray`
            computeP (fromFunction (Z :. h1 :. w2) dotAt))
        where
          -- brr transposed, so its columns are available as row slices.
          trr             = transpose2P brr

          -- The result has arr's row count and brr's column count. Using
          -- @extent arr@ here is only right when brr happens to be square.
          (Z :. h1 :. _)  = extent arr
          (Z :. _  :. w2) = extent brr

          -- Dot product of row (row ix) of arr with column (col ix) of brr.
          dotAt ix        = R.sumAllS
                          $ R.zipWith (*)
                              (slice arr (Any :. row ix :. All))
                              (slice trr (Any :. col ix :. All))
-- | Matrix-matrix multiply, computed sequentially.
--
--   Given an @h1 x w1@ matrix and a @w1 x w2@ matrix, produce their
--   @h1 x w2@ product.
--
--   The second operand is transposed first (with 'transpose2S'). Each
--   element @(i, j)@ of the result is then the sum of the element-wise
--   product of row @i@ of the first operand and row @j@ of the transposed
--   second operand.
--
--   The inner dimensions are assumed to agree; this is not checked.
mmultS  :: Array U DIM2 Double
        -> Array U DIM2 Double
        -> Array U DIM2 Double
mmultS arr brr
        = arr `deepSeqArray` (trr `deepSeqArray`
            computeS (fromFunction (Z :. h1 :. w2) dotAt))
        where
          -- brr transposed, so its columns are available as row slices.
          trr             = transpose2S brr

          -- The result has arr's row count and brr's column count. Using
          -- @extent arr@ here is only right when brr happens to be square.
          (Z :. h1 :. _)  = extent arr
          (Z :. _  :. w2) = extent brr

          -- Dot product of row (row ix) of arr with column (col ix) of brr.
          dotAt ix        = R.sumAllS
                          $ R.zipWith (*)
                              (slice arr (Any :. row ix :. All))
                              (slice trr (Any :. col ix :. All))
-- | Transpose a matrix, computing the result in parallel.
--
--   The input is forced with 'deepSeqArray' first. The result is a freshly
--   computed unboxed array whose extent is the input's with the two
--   dimensions exchanged; element @(j, i)@ of the result is element
--   @(i, j)@ of the input.
transpose2P
        :: Array U DIM2 Double
        -> Array U DIM2 Double
transpose2P arr
        = arr `deepSeqArray` computeUnboxedP transposed
        where
          -- Swapping the coordinates of any index inside the transposed
          -- extent lands inside the original extent, which is why the
          -- unchecked backpermute is safe here.
          transposed           = unsafeBackpermute flippedExtent swapDims arr
          flippedExtent        = swapDims (extent arr)
          swapDims (Z :. r :. c) = Z :. c :. r
-- | Transpose a matrix, computing the result sequentially.
--
--   The input is forced with 'deepSeqArray' first. The result is a freshly
--   computed unboxed array whose extent is the input's with the two
--   dimensions exchanged; element @(j, i)@ of the result is element
--   @(i, j)@ of the input.
transpose2S
        :: Array U DIM2 Double
        -> Array U DIM2 Double
transpose2S arr
        = arr `deepSeqArray` computeUnboxedS transposed
        where
          -- Swapping the coordinates of any index inside the transposed
          -- extent lands inside the original extent, which is why the
          -- unchecked backpermute is safe here.
          transposed           = unsafeBackpermute flippedExtent swapDims arr
          flippedExtent        = swapDims (extent arr)
          swapDims (Z :. r :. c) = Z :. c :. r