{-# LANGUAGE TypeOperators #-}
{- |
Sparse matrices with a definite number of stored (non-zero) entries
per column ('ColumnMatrix') or per row ('RowMatrix').

Both representations keep a dense Accelerate 'Matrix' of @(index, value)@
pairs. For a 'ColumnMatrix' the index is the row of the sparse matrix
the value belongs to; for a 'RowMatrix' it is the column.
-}
module Data.Array.Accelerate.Arithmetic.Sparse where

import qualified Data.Array.Accelerate.Arithmetic.LinearAlgebra as LinAlg
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Utility.Lift.Exp (expr, )
import Data.Array.Accelerate.Arithmetic.LinearAlgebra (Matrix, Vector, matrixShape, )
import Data.Array.Accelerate (Exp, Any(Any), All(All), (:.)((:.)), )


{- |
Sparse matrix with a definite number of non-zero entries per column.

In 'columnMatrix' the last dimension runs over the columns of the sparse
matrix, and the second-to-last dimension enumerates the stored entries
of each column. Every element is a pair @(row index, value)@.
-}
data ColumnMatrix ix a =
   ColumnMatrix {
      numRows :: Exp Int
         -- ^ number of rows of the sparse matrix
    , columnMatrix :: Matrix ix (Int, a)
         -- ^ stored @(row index, value)@ pairs
   }

{- |
Turn the stored @(index, value)@ pairs into full array indices.

The outer dimensions @ix@ of each element's position are kept, and the
stored 'Int' index is appended as the innermost index. The value part
is dropped.
-}
realIndex ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix (Int, a) -> Matrix ix (ix :. Int)
realIndex m =
   A.zipWith Exp.indexCons
      (A.generate (A.shape m) (A.indexTail . A.indexTail))
      (A.map A.fst m)

{- |
Multiply a 'ColumnMatrix' by a dense vector.

Every stored value of column @j@ is multiplied by element @j@ of the
vector. The product is accumulated with '+' into the result element
named by the stored row index. The result has 'numRows' elements per
outer index, and rows without any stored entry stay zero.
-}
multiplyColumnMatrixVector ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a -> Vector ix a -> Vector ix a
multiplyColumnMatrixVector (ColumnMatrix rows m) v =
   Arrange.scatter (+) (realIndex m) zeros products
   where
      -- accumulator: one zero per row of the sparse matrix
      zeros =
         case matrixShape m of
            sh :. _entries :. _cols -> A.fill (A.lift $ sh :. rows) 0
      -- copy the vector once per stored entry, so that every stored value
      -- meets the vector element of its own column
      products =
         A.zipWith (*) (A.map A.snd m)
            (A.replicate (A.lift $ Any :. LinAlg.numRows m :. All) v)

{- |
Transpose a 'ColumnMatrix' into a 'RowMatrix'.

The dense storage is transposed. The number of rows of the input becomes
the number of columns of the result.
-}
transposeColumnMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a -> RowMatrix ix a
transposeColumnMatrix (ColumnMatrix n x) =
   RowMatrix n $ LinAlg.transpose x


{- |
Sparse matrix with a definite number of non-zero entries per row.

In 'rowMatrix' the second-to-last dimension runs over the rows of the
sparse matrix, and the last dimension enumerates the stored entries of
each row. Every element is a pair @(column index, value)@.
-}
data RowMatrix ix a =
   RowMatrix {
      numCols :: Exp Int
         -- ^ number of columns of the sparse matrix
    , rowMatrix :: Matrix ix (Int, a)
         -- ^ stored @(column index, value)@ pairs
   }

{- |
Multiply a 'RowMatrix' by a dense vector.

For every row, the vector elements addressed by the stored column indices
are multiplied by the stored values and summed. A row without stored
entries yields zero.
-}
multiplyRowMatrixVector ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   RowMatrix ix a -> Vector ix a -> Vector ix a
multiplyRowMatrixVector (RowMatrix _cols m) v =
   -- 'A.fold' with neutral element 0 instead of 'A.fold1':
   -- 'A.fold1' requires a non-empty innermost dimension and would fail
   -- for a matrix that stores zero entries per row.
   A.fold (+) 0 $
   A.zipWith (*) (A.map A.snd m) $
   Arrange.gather (realIndex m) v

{- |
Transpose a 'RowMatrix' into a 'ColumnMatrix'.

The dense storage is transposed. The number of columns of the input
becomes the number of rows of the result.
-}
transposeRowMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   RowMatrix ix a -> ColumnMatrix ix a
transposeRowMatrix (RowMatrix n x) =
   ColumnMatrix n $ LinAlg.transpose x

{- |
Multiply a 'ColumnMatrix' by a 'RowMatrix', giving a dense matrix.

The result has 'numRows' of the first factor as rows and 'numCols' of the
second factor as columns. Contributions to the same result position are
accumulated with '+'.

Precondition (not checked): the number of columns of the first factor
equals the number of rows of the second one.
-}
multiplyMatrixMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a -> RowMatrix ix a -> Matrix ix a
multiplyMatrixMatrix (ColumnMatrix rows x) (RowMatrix cols y) =
   Arrange.scatter (+) targets zeros (A.map A.snd m)
   where
      m = matchMatrices x y
      -- strip the three trailing dimensions of 'm',
      -- keeping only the outer index ix
      global = A.indexTail . A.indexTail . A.indexTail
      -- destination of every product: outer index, row, column
      targets =
         Arrange.mapWithIndex
            (\mix tix -> A.lift $ global mix :. A.fst tix :. A.snd tix) $
         A.map A.fst m
      zeros = A.fill (A.lift $ global (A.shape m) :. rows :. cols) 0

{- |
Pair up the stored entries of a 'ColumnMatrix' storage @x@ and a
'RowMatrix' storage @y@ for multiplication.

For every stored entry @(n, xi)@ of sparse column @j@ of @x@ and every
stored entry @(m, yi)@ of sparse row @j@ of @y@, the result holds
@((n, m), xi*yi)@: the target position in the product together with the
contribution to it.

Precondition (not checked): the number of columns of @x@ must equal the
number of rows of @y@. Otherwise 'A.zipWith' silently restricts both
operands to their common extent.
-}
matchMatrices ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   Matrix ix (Int, a) -> Matrix ix (Int, a) ->
   Matrix (ix :. Int) ((Int, Int), a)
matchMatrices x y =
   case (matrixShape x, matrixShape y) of
      (_ :. xRows :. _xCols, _ :. _yRows :. yCols) ->
         -- it must be xCols == yRows
         A.zipWith
            (Exp.modify2 (expr,expr) (expr,expr) $
               \(n,xi) (m,yi) -> ((n, m), xi*yi))
            (A.replicate (A.lift $ Any :. All :. All :. yCols) x)
            (A.replicate (A.lift $ Any :. xRows :. All :. All) y)

{- |
Scale the rows of a 'RowMatrix' by the elements of a vector.

Each stored value is multiplied by an element of @s@, and the stored
column indices are left unchanged.
NOTE(review): assumes 'LinAlg.zipScalarVectorWith' pairs element @i@ of
@s@ with row @i@ of the storage, as the name suggests. Confirm in
"Data.Array.Accelerate.Arithmetic.LinearAlgebra".
-}
scaleRowRows ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.IsNum a) =>
   Vector ix a -> RowMatrix ix a -> RowMatrix ix a
scaleRowRows s (RowMatrix n x) =
   RowMatrix n $
   LinAlg.zipScalarVectorWith (\si xi -> Exp.mapSnd (si *) xi) s x