module Data.Array.Accelerate.Arithmetic.Sparse where
import qualified Data.Array.Accelerate.Arithmetic.LinearAlgebra as LinAlg
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Utility.Lift.Exp (atom, )
import Data.Array.Accelerate.Arithmetic.LinearAlgebra
(Matrix, Vector, matrixShape, )
import Data.Array.Accelerate
(Exp, Any(Any), All(All), (:.)((:.)), )
-- | Sparse matrix stored column-wise as a dense array of
-- @(index, value)@ entries.
--
-- Judging from 'multiplyColumnMatrixVector', the 'Int' of every entry
-- addresses a position in a dimension of extent 'numRows', i.e. it acts
-- as the row index of that entry.
data ColumnMatrix ix a =
   ColumnMatrix
      { numRows :: Exp Int
        -- ^ Extent of the dimension that the stored indices address.
      , columnMatrix :: Matrix ix (Int, a)
        -- ^ Entries, each a pair of stored index and coefficient.
      }
-- | Expand the index stored in every entry to a complete array index.
--
-- For each position of @m@ the two innermost dimensions of that position
-- are dropped and the 'Int' stored in the entry is attached instead
-- (via 'Exp.indexCons').
realIndex ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix (Int, a) ->
   Matrix ix (ix :. Int)
realIndex m =
   A.zipWith Exp.indexCons outerIndices storedIndices
  where
   -- position of every entry, stripped of its two innermost dimensions
   outerIndices =
      A.generate (A.shape m) (\pos -> A.indexTail (A.indexTail pos))
   -- first component of every entry
   storedIndices = A.map A.fst m
-- | Product of a 'ColumnMatrix' with a vector.
--
-- Every coefficient is multiplied with the element of @v@ that shares
-- its innermost position (@v@ is replicated along the row dimension of
-- the entry array). The products are then combined with @(+)@ by
-- 'Arrange.scatter' at the indices computed by 'realIndex', starting
-- from an array of zeros whose innermost extent is @rows@.
multiplyColumnMatrixVector ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a ->
   Vector ix a ->
   Vector ix a
multiplyColumnMatrixVector (ColumnMatrix rows m) v =
   Arrange.scatter (+) (realIndex m) zeros products
  where
   -- accumulator: leading shape of the entry array, innermost extent 'rows'
   zeros =
      case matrixShape m of
         sh :. _rows :. _cols -> A.fill (A.lift (sh :. rows)) 0
   -- coefficient times the matching vector element, one per stored entry
   products =
      A.zipWith (*) (A.map A.snd m) $
      A.replicate (A.lift (Any :. LinAlg.numRows m :. All)) v
-- | Reinterpret a 'ColumnMatrix' as the 'RowMatrix' of its transpose.
--
-- The entry array is transposed with 'LinAlg.transpose'; the stored
-- extent is carried over unchanged and becomes 'numCols' of the result.
transposeColumnMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a ->
   RowMatrix ix a
transposeColumnMatrix (ColumnMatrix extent entries) =
   RowMatrix {numCols = extent, rowMatrix = LinAlg.transpose entries}
-- | Sparse matrix stored row-wise as a dense array of
-- @(index, value)@ entries.
--
-- Judging from 'multiplyRowMatrixVector', the 'Int' of every entry is
-- used to select an element of the vector operand, i.e. it acts as the
-- column index of that entry.
data RowMatrix ix a =
   RowMatrix
      { numCols :: Exp Int
        -- ^ Extent of the dimension that the stored indices address
        -- (presumably; it is not consulted by 'multiplyRowMatrixVector').
      , rowMatrix :: Matrix ix (Int, a)
        -- ^ Entries, each a pair of stored index and coefficient.
      }
-- | Product of a 'RowMatrix' with a vector.
--
-- For every stored entry the element of @v@ addressed by 'realIndex' is
-- fetched with 'Arrange.gather' and multiplied with the entry's
-- coefficient; the products are then summed along the innermost
-- dimension of the entry array.
--
-- The sum is a 'A.fold' starting from @0@ rather than a 'A.fold1':
-- 'A.fold1' requires a non-empty innermost dimension, so a matrix that
-- stores no entries per row would fail at run time instead of yielding
-- a zero vector. This also matches 'multiplyColumnMatrixVector', which
-- accumulates into a zero-filled array.
--
-- The stored column count is not needed for the computation and is
-- ignored.
multiplyRowMatrixVector ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   RowMatrix ix a ->
   Vector ix a ->
   Vector ix a
multiplyRowMatrixVector (RowMatrix _cols m) v =
   A.fold (+) 0 $
   A.zipWith (*) (A.map A.snd m) $
   Arrange.gather (realIndex m) v
-- | Reinterpret a 'RowMatrix' as the 'ColumnMatrix' of its transpose.
--
-- The entry array is transposed with 'LinAlg.transpose'; the stored
-- extent is carried over unchanged and becomes 'numRows' of the result.
transposeRowMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   RowMatrix ix a ->
   ColumnMatrix ix a
transposeRowMatrix (RowMatrix extent entries) =
   ColumnMatrix {numRows = extent, columnMatrix = LinAlg.transpose entries}
-- | Product of a 'ColumnMatrix' with a 'RowMatrix', returned as a dense
-- matrix.
--
-- 'matchMatrices' pairs the entries of both operands, yielding for each
-- pair the two stored indices and the product of the coefficients.
-- These products are combined with @(+)@ by 'Arrange.scatter' into a
-- zero-filled array of extent @rows@ by @cols@ (below the shared leading
-- shape), at the position given by the two stored indices.
multiplyMatrixMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a ->
   RowMatrix ix a ->
   Matrix ix a
multiplyMatrixMatrix (ColumnMatrix rows x) (RowMatrix cols y) =
   Arrange.scatter (+) targets zeros (A.map A.snd paired)
  where
   paired = matchMatrices x y
   -- strip the three innermost dimensions, leaving the leading shape
   leading = A.indexTail . A.indexTail . A.indexTail
   -- destination of every product: leading shape plus both stored indices
   targets =
      Arrange.mapWithIndex
         (\pos rowCol ->
            A.lift (leading pos :. A.fst rowCol :. A.snd rowCol))
         (A.map A.fst paired)
   zeros = A.fill (A.lift (leading (A.shape paired) :. rows :. cols)) 0
-- | Pair up the entries of two entry arrays.
--
-- @x@ is replicated along a new innermost dimension whose extent is the
-- column count of @y@, and @y@ is replicated along a new dimension whose
-- extent is the row count of @x@, so that both have the same rank.
-- Corresponding entries are then combined: the two stored indices are
-- paired and the two coefficients are multiplied.
matchMatrices ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   Matrix ix (Int, a) ->
   Matrix ix (Int, a) ->
   Matrix (ix :. Int) ((Int, Int), a)
matchMatrices x y =
   case matrixShape x of
      _ :. xRows :. _xCols ->
         case matrixShape y of
            _ :. _yRows :. yCols ->
               let xWide = A.replicate (A.lift (Any :. All :. All :. yCols)) x
                   yWide = A.replicate (A.lift (Any :. xRows :. All :. All)) y
               in  A.zipWith
                      (Exp.modify2 (atom,atom) (atom,atom) $
                       \(xIndex, xValue) (yIndex, yValue) ->
                          ((xIndex, yIndex), xValue * yValue))
                      xWide
                      yWide
-- | Scale the coefficients of a 'RowMatrix' by the elements of a vector.
--
-- 'LinAlg.zipScalarVectorWith' combines each element of @s@ with the
-- entries it is matched against (presumably one element per row, as the
-- name suggests); only the coefficient of an entry is multiplied, the
-- stored index and the column count are left untouched.
scaleRowRows ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.IsNum a) =>
   Vector ix a -> RowMatrix ix a -> RowMatrix ix a
scaleRowRows s (RowMatrix n x) =
   RowMatrix {numCols = n, rowMatrix = LinAlg.zipScalarVectorWith scaleEntry s x}
  where
   -- multiply the coefficient (second component) of one entry
   scaleEntry factor entry = Exp.mapSnd (factor *) entry