module Data.Array.Repa.Algorithms.Matrix
(
row
, col
, mmultP, mmultS
, transpose2P, transpose2S)
where
import Data.Array.Repa as R
import Data.Array.Repa.Eval as R
import Data.Array.Repa.Unsafe as R
import Control.Monad.ST.Strict
-- | Take the row component (the first index after 'Z') of a 2D index.
row :: DIM2 -> Int
row ix
  = case ix of
      Z :. r :. _ -> r
-- | Take the column component (the last index) of a 2D index.
col :: DIM2 -> Int
col ix
  = case ix of
      Z :. _ :. c -> c
-- | Matrix-matrix multiply, in parallel.
--
--   Given @arr@ of extent @Z :. h1 :. w1@ and @brr@ of extent
--   @Z :. h2 :. w2@, returns their product, of extent @Z :. h1 :. w2@.
--
--   The inner dimensions must agree (@w1 == h2@). If they do not, this calls
--   'error'. Without the check, 'R.zipWith' on slices of different lengths
--   would use only the intersection of their extents, silently producing a
--   wrong result.
--
--   @brr@ is transposed first, so each result element is the dot product of
--   a row of @arr@ with a row of the transposed @brr@.
mmultP  :: Monad m
        => Array U DIM2 Double
        -> Array U DIM2 Double
        -> m (Array U DIM2 Double)
mmultP arr brr
  | w1 /= h2
  = error $ "Data.Array.Repa.Algorithms.Matrix.mmultP: "
         ++ "inner dimensions do not match ("
         ++ show h1 ++ "x" ++ show w1 ++ " * "
         ++ show h2 ++ "x" ++ show w2 ++ ")"
  | otherwise
  = [arr, brr] `deepSeqArrays`
    do  trr <- transpose2P brr
        computeP
          $ fromFunction (Z :. h1 :. w2)
          $ \ix -> R.sumAllS
                 $ R.zipWith (*)
                     (unsafeSlice arr (Any :. row ix :. All))
                     (unsafeSlice trr (Any :. col ix :. All))
  where
    (Z :. h1 :. w1) = extent arr
    (Z :. h2 :. w2) = extent brr
-- | Matrix-matrix multiply, sequentially.
--
--   Given @arr@ of extent @Z :. h1 :. w1@ and @brr@ of extent
--   @Z :. h2 :. w2@, returns their product, of extent @Z :. h1 :. w2@.
--
--   The inner dimensions must agree (@w1 == h2@). If they do not, this calls
--   'error'. Without the check, 'R.zipWith' on slices of different lengths
--   would use only the intersection of their extents, silently producing a
--   wrong result.
mmultS  :: Array U DIM2 Double
        -> Array U DIM2 Double
        -> Array U DIM2 Double
mmultS arr brr
  | w1 /= h2
  = error $ "Data.Array.Repa.Algorithms.Matrix.mmultS: "
         ++ "inner dimensions do not match ("
         ++ show h1 ++ "x" ++ show w1 ++ " * "
         ++ show h2 ++ "x" ++ show w2 ++ ")"
  | otherwise
  = [arr, brr] `deepSeqArrays` runST multiply
  where
    (Z :. h1 :. w1) = extent arr
    (Z :. h2 :. w2) = extent brr

    -- Transpose brr eagerly ('R.now'), so each result element is the dot
    -- product of a row of arr with a row of the transposed brr.
    multiply :: ST s (Array U DIM2 Double)
    multiply
      = do  trr <- R.now $ transpose2S brr
            return $ computeS
              $ fromFunction (Z :. h1 :. w2)
              $ \ix -> R.sumAllS
                     $ R.zipWith (*)
                         (unsafeSlice arr (Any :. row ix :. All))
                         (unsafeSlice trr (Any :. col ix :. All))
-- | Transpose a matrix, in parallel.
--
--   The result's extent is the input's extent with its two components
--   swapped. The element at @Z :. j :. i@ of the result is the element at
--   @Z :. i :. j@ of the input.
transpose2P
        :: Monad m
        => Array U DIM2 Double
        -> m (Array U DIM2 Double)
transpose2P arr
  = arr `deepSeqArray` computeUnboxedP (unsafeBackpermute flipped flipIx arr)
  where
    flipIx (Z :. r :. c) = Z :. c :. r
    flipped              = flipIx (extent arr)
-- | Transpose a matrix, sequentially.
--
--   The result's extent is the input's extent with its two components
--   swapped. The element at @Z :. j :. i@ of the result is the element at
--   @Z :. i :. j@ of the input.
transpose2S
        :: Array U DIM2 Double
        -> Array U DIM2 Double
transpose2S arr
  = arr `deepSeqArray` computeUnboxedS (unsafeBackpermute flipped flipIx arr)
  where
    flipIx (Z :. r :. c) = Z :. c :. r
    flipped              = flipIx (extent arr)