{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.Arithmetic.Sparse where

import qualified Data.Array.Accelerate.Arithmetic.LinearAlgebra as LinAlg
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate.Utility.Lift.Exp (atom, )

import Data.Array.Accelerate.Arithmetic.LinearAlgebra
          (Matrix, Vector, matrixShape, )
import Data.Array.Accelerate
          (Exp, Any(Any), All(All), (:.)((:.)), )


{- |
Sparse matrix with a definite number of stored (nominally non-zero) entries
per column.

Every stored entry is a pair of a row index and a value.
The column of an entry is given by its position
in the innermost dimension of 'columnMatrix'.
-}
data ColumnMatrix ix a =
   ColumnMatrix
      { numRows :: Exp Int
           -- ^ number of rows of the represented matrix
      , columnMatrix :: Matrix ix (Int, a)
           -- ^ stored entries as pairs of row index and value
      }

{- |
Complete the index stored in every entry to a full array index.

For the element at position @ix :. i :. j@ that holds @(k, _)@
the result holds @ix :. k@ at the same position.
-}
realIndex ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   Matrix ix (Int, a) ->
   Matrix ix (ix :. Int)
realIndex m =
   let -- position of every element with the two innermost dimensions removed
       outerIndices =
          A.generate (A.shape m) (\pos -> A.indexTail (A.indexTail pos))
       -- the index that is stored next to every value
       storedIndices = A.map A.fst m
   in  A.zipWith Exp.indexCons outerIndices storedIndices

{- |
Multiply a column-sparse matrix with a vector.

Every stored value is multiplied with the vector element of its column.
The products are then scattered with @(+)@ into a zero-filled vector
of length 'numRows', at the row index stored next to the value.
-}
multiplyColumnMatrixVector ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a ->
   Vector ix a ->
   Vector ix a
multiplyColumnMatrixVector (ColumnMatrix rows m) v =
   let -- repeat the vector once for every stored entry of a column
       spread = A.replicate (A.lift $ Any :. LinAlg.numRows m :. All) v
       products = A.zipWith (*) (A.map A.snd m) spread
       -- accumulator: outer shape of 'm', extended by the full row count
       zeros =
          case matrixShape m of
             outer :. _ :. _ -> A.fill (A.lift $ outer :. rows) 0
   in  Arrange.scatter (+) (realIndex m) zeros products

{- |
Reinterpret a column-sparse matrix as the row-sparse representation
of its transpose.

The entry array is passed through 'LinAlg.transpose'
and the row count of the input becomes the column count of the result.
-}
transposeColumnMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a ->
   RowMatrix ix a
transposeColumnMatrix (ColumnMatrix rowCount entries) =
   RowMatrix rowCount (LinAlg.transpose entries)


{- |
Sparse matrix with a definite number of stored (nominally non-zero) entries
per row.

Every stored entry is a pair of a column index and a value.
The innermost dimension of 'rowMatrix' enumerates the entries of one row.
-}
data RowMatrix ix a =
   RowMatrix
      { numCols :: Exp Int
           -- ^ number of columns of the represented matrix
      , rowMatrix :: Matrix ix (Int, a)
           -- ^ stored entries as pairs of column index and value
      }

{- |
Multiply a row-sparse matrix with a vector.

For every stored entry the vector element at the stored column index
is fetched and multiplied with the stored value.
The products of one row (the innermost dimension of the entry array)
are then summed up.

A matrix that stores no entries per row (empty innermost dimension)
yields a vector of zeros,
consistent with 'multiplyColumnMatrixVector' for an empty column matrix.
-}
multiplyRowMatrixVector ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   RowMatrix ix a ->
   Vector ix a ->
   Vector ix a
multiplyRowMatrixVector (RowMatrix _cols m) v =
   -- 'A.fold' with the neutral element 0 instead of 'A.fold1':
   -- 'A.fold1' requires a non-empty innermost dimension
   -- and would thus reject a matrix without stored entries.
   A.fold (+) 0 $
   A.zipWith (*) (A.map A.snd m) $
   Arrange.gather (realIndex m) v

{- |
Reinterpret a row-sparse matrix as the column-sparse representation
of its transpose.

The entry array is passed through 'LinAlg.transpose'
and the column count of the input becomes the row count of the result.
-}
transposeRowMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   RowMatrix ix a ->
   ColumnMatrix ix a
transposeRowMatrix (RowMatrix colCount entries) =
   ColumnMatrix colCount (LinAlg.transpose entries)

{- |
Multiply a column-sparse matrix with a row-sparse matrix
and return the product as a dense matrix.

'matchMatrices' provides all partial products
together with their target (row, column) pair.
The partial products are scattered with @(+)@
into a zero-filled matrix with 'numRows' rows and 'numCols' columns.
-}
multiplyMatrixMatrix ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   ColumnMatrix ix a ->
   RowMatrix ix a ->
   Matrix ix a
multiplyMatrixMatrix
      (ColumnMatrix rows x) (RowMatrix cols y) =
   let matched = matchMatrices x y
       -- strip the three innermost dimensions of an index into 'matched'
       outer = A.indexTail . A.indexTail . A.indexTail
       -- target position in the dense result for every partial product
       targets =
          Arrange.mapWithIndex
             (\pos rowCol ->
                A.lift $ outer pos :. A.fst rowCol :. A.snd rowCol)
             (A.map A.fst matched)
       zeros = A.fill (A.lift $ outer (A.shape matched) :. rows :. cols) 0
   in  Arrange.scatter (+) targets zeros (A.map A.snd matched)

{- |
Form all partial products of two entry arrays.

The element at position @ix :. i :. j :. k@ of the result combines
@(n, xi)@ from @x@ at @ix :. i :. j@ with
@(m, yi)@ from @y@ at @ix :. j :. k@ to @((n, m), xi*yi)@.

Precondition: the innermost extent of @x@ must equal
the second innermost extent of @y@.
This is not checked here.
-}
matchMatrices ::
   (A.Shape ix, A.Slice ix, A.IsNum a, A.Elt a) =>
   Matrix ix (Int, a) ->
   Matrix ix (Int, a) ->
   Matrix (ix :. Int) ((Int, Int), a)
matchMatrices x y =
   case (matrixShape x, matrixShape y) of
      (_ :. xOuter :. _, _ :. _ :. yInner) ->
         let -- x gets a new innermost dimension, y a new third innermost one
             xWide = A.replicate (A.lift $ Any :. All :. All :. yInner) x
             yWide = A.replicate (A.lift $ Any :. xOuter :. All :. All) y
             -- keep both stored indices, multiply the values
             pairUp (n,xi) (m,yi) = ((n, m), xi*yi)
         in  A.zipWith
                (Exp.modify2 (atom,atom) (atom,atom) pairUp)
                xWide yWide


{- |
Scale the stored values of a row-sparse matrix by the elements of a vector.

The stored column indices and the column count remain unchanged;
only the value component of every entry is multiplied from the left
with the factor that 'LinAlg.zipScalarVectorWith' pairs it with
(presumably one factor per row, as the function name suggests).
-}
scaleRowRows ::
   (A.Slice ix, A.Shape ix, A.Elt a, A.IsNum a) =>
   Vector ix a -> RowMatrix ix a -> RowMatrix ix a
scaleRowRows s (RowMatrix n x) =
   RowMatrix n
      (LinAlg.zipScalarVectorWith
         (\factor -> Exp.mapSnd (factor *)) s x)