{-# LANGUAGE TypeFamilies #-}
{- |
Sparse matrices with a fixed number of stored entries per row or per column,
together with products against dense vectors and against each other.

A sparse axis is described by the shape 'Dim': its @set@ component is the shape
that is actually iterated over, its @dim@ component is the shape of the
dense axis that the stored indices refer to.
Every stored cell is a pair @(index, value)@.
-}
module Data.Array.Knead.Arithmetic.Sparse where

import qualified Data.Array.Knead.Arithmetic.LinearAlgebra as LinAlg
import Data.Array.Knead.Arithmetic.LinearAlgebra
          (Vector, Matrix, IOVector, IOMatrix)
import qualified Data.Array.Knead.Parameterized.Physical as Phys
import qualified Data.Array.Knead.Parameterized.Symbolic as SymP
import qualified Data.Array.Knead.Simple.ShapeDependent as ShapeDep
import qualified Data.Array.Knead.Simple.Symbolic as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom)

import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable)


{- |
Shape of a sparse axis.

The 'Shape.C' instance below uses only the @set@ component
for size, indexing and iteration.
The @dim@ component is carried along and is read via 'sparseDim'
when a dense result array has to be created.
-}
data Dim set dim = Dim set dim

{- |
Project the @set@ component out of a 'Dim' value.
-}
sparseSet :: (Expr.Value val) => val (Dim set dim) -> val set
sparseSet =
   Expr.lift1 $ \(MultiValue.Cons (Dim set _dim)) -> MultiValue.Cons set

{- |
Project the @dim@ component out of a 'Dim' value.
-}
sparseDim :: (Expr.Value val) => val (Dim set dim) -> val dim
sparseDim =
   Expr.lift1 $ \(MultiValue.Cons (Dim _set dim)) -> MultiValue.Cons dim

{- |
Reinterpret a 'Dim' value as a plain pair. Inverse of 'dimFromPair'.
-}
pairFromDim :: (Expr.Value val) => val (Dim set dim) -> val (set, dim)
pairFromDim =
   Expr.lift1 $
      \(MultiValue.Cons (Dim set dim)) -> MultiValue.Cons (set, dim)

{- |
Reinterpret a plain pair as a 'Dim' value. Inverse of 'pairFromDim'.
-}
dimFromPair :: (Expr.Value val) => val (set, dim) -> val (Dim set dim)
dimFromPair =
   Expr.lift1 $
      \(MultiValue.Cons (set, dim)) -> MultiValue.Cons (Dim set dim)


{-
All methods are implemented by converting to and from the pair instance
via 'pairFromDim' and 'dimFromPair'.
-}
instance
   (MultiValue.C set, MultiValue.C dim) =>
      MultiValue.C (Dim set dim) where
   type Repr f (Dim set dim) =
            Dim (MultiValue.Repr f set) (MultiValue.Repr f dim)
   cons (Dim set dim) = dimFromPair $ MultiValue.cons (set, dim)
   zero = dimFromPair MultiValue.zero
   undef = dimFromPair MultiValue.undef
   phis bb = fmap dimFromPair . MultiValue.phis bb . pairFromDim
   addPhis bb a b = MultiValue.addPhis bb (pairFromDim a) (pairFromDim b)

instance
   (MultiValue.Compose set, MultiValue.Compose dim) =>
      MultiValue.Compose (Dim set dim) where
   type Composed (Dim set dim) =
            Dim (MultiValue.Composed set) (MultiValue.Composed dim)
   compose (Dim set dim) = dimFromPair $ MultiValue.compose (set,dim)

instance
   (Expr.Compose set, Expr.Compose dim) =>
      Expr.Compose (Dim set dim) where
   type Composed (Dim set dim) =
            Dim (Expr.Composed set) (Expr.Composed dim)
   compose (Dim set dim) = dimFromPair $ Expr.compose (set,dim)

{-
Size, indexing and iteration are delegated to the @set@ component.
-}
instance (Shape.C set, Shape.C dim) => Shape.C (Dim set dim) where
   type Index (Dim set dim) = Shape.Index set
   {-
   Not really useful.
   Only intended for the case that all dimensions match.
   The @set@ components are intersected,
   the @dim@ component is taken unchanged from the second argument.
   -}
   intersectCode
         s0@(MultiValue.Cons (Dim _set0 _dim))
         s1@(MultiValue.Cons (Dim _set1 dim)) = do
      MultiValue.Cons set <- Shape.intersectCode (sparseSet s0) (sparseSet s1)
      return $ MultiValue.Cons $ Dim set dim
   sizeCode sh = Shape.sizeCode $ sparseSet sh
   size (Dim set _dim) = Shape.size set
   flattenIndexRec sh ix = Shape.flattenIndexRec (sparseSet sh) ix
   loop f = Shape.loop f . sparseSet


{- |
Sparse matrix with a definite number of non-zero entries per row.

Every cell stores the column index it refers to together with its value.
-}
newtype RowMatrix p coll rows set cols a =
   RowMatrix (Matrix p coll rows (Dim set cols) (Shape.Index cols, a))

{- |
Product of a row-sparse matrix with a dense vector.

The vector elements addressed by the stored column indices are gathered,
multiplied with the stored values (stored value as left factor)
and the products are summed with @Sym.fold1 Expr.add@.

NOTE(review): the @backpermute@ with 'LinAlg.balanceLeft' and
'LinAlg.balanceRight' presumably only rearranges the array
for the subsequent fold; confirm against "LinearAlgebra".
-}
multiplyRowMatrixVector ::
   (Shape.C coll, Shape.C rows, Shape.C cols, Shape.C set,
    MultiValue.PseudoRing a) =>
   RowMatrix p coll rows set cols a ->
   Vector p coll cols a ->
   Vector p coll rows a
multiplyRowMatrixVector (RowMatrix m) v =
   Sym.fold1 Expr.add $
   ShapeDep.backpermute LinAlg.balanceLeft LinAlg.balanceRight $
   Sym.zipWith Expr.mul (Sym.map Expr.snd m) $
   Sym.gather (sparseRealIndex m) v


{- |
Sparse matrix with a definite number of non-zero entries per column.

Every cell stores the row index it refers to together with its value.
-}
newtype ColumnMatrix p coll set rows cols a =
   ColumnMatrix (Matrix p coll (Dim set rows) cols (Shape.Index rows, a))

{- |
Product of a column-sparse matrix with a dense vector.

Every stored value is multiplied with the vector element of its column
(stored value as left factor).
The products are then scattered with 'Expr.add' to the stored row indices
into an array that was initialized with 'Expr.zero'
and has the dense @rows@ shape obtained via 'sparseDim'.
-}
multiplyColumnMatrixVector ::
   (Shape.C coll, Shape.C set, Shape.C rows, Shape.C cols,
    MultiValue.PseudoRing a,
    MultiValueMemory.C a, Storable a,
    MultiValueMemory.C rows, Storable rows,
    MultiValueMemory.C coll, Storable coll,
    MultiValueMemory.Struct rows ~ rowsStruct, LLVM.IsSized rowsStruct,
    MultiValueMemory.Struct coll ~ collStruct, LLVM.IsSized collStruct) =>
   ColumnMatrix p coll set rows cols a ->
   Vector p coll cols a ->
   IOVector p coll rows a
multiplyColumnMatrixVector (ColumnMatrix m) v =
   Phys.scatter Expr.add
      (ShapeDep.fill
         (Expr.modify (atom,(atom,atom)) $
            \(coll, (row,_col)) -> (coll, sparseDim row))
         Expr.zero m) $
   Sym.mapWithIndex
      (Expr.modify2 (atom,atom) (atom,atom) $
       \(coll,_ix) (i,a) -> ((coll,i),a)) $
   ShapeDep.backpermute2
      (\msh _ -> msh)
      id
      (Expr.modify (atom,(atom,atom)) $ \(coll, (_row,col)) -> (coll,col))
      -- different from mulCell for non-commutative multiplications
      (Expr.modify2 (atom,atom) atom $ \(i,a) b -> (i, Expr.mul a b))
      m v


{- |
Product of a column-sparse matrix with a row-sparse matrix,
delivered as a dense matrix.

For every combination of a stored cell of the left matrix
and a stored cell of the right matrix that share the same @glue@ index,
the values are multiplied (left matrix value as left factor).
The products are scattered with 'Expr.add' to the position
given by the stored row and column index,
into the zero-initialized array made by 'fillMatrixMatrix'.
-}
multiplyMatrixMatrix ::
   (Shape.C coll, Shape.C set0, Shape.C set1,
    Shape.C rows, Shape.C cols, Shape.C glue,
    MultiValue.PseudoRing a,
    MultiValueMemory.C a, Storable a,
    MultiValueMemory.C cols, Storable cols,
    MultiValueMemory.C rows, Storable rows,
    MultiValueMemory.C coll, Storable coll,
    MultiValueMemory.Struct cols ~ colsStruct, LLVM.IsSized colsStruct,
    MultiValueMemory.Struct rows ~ rowsStruct, LLVM.IsSized rowsStruct,
    MultiValueMemory.Struct coll ~ collStruct, LLVM.IsSized collStruct) =>
   ColumnMatrix p coll set0 rows glue a ->
   RowMatrix p coll glue set1 cols a ->
   IOMatrix p coll rows cols a
multiplyMatrixMatrix sx@(ColumnMatrix x) sy@(RowMatrix y) =
   Phys.scatter Expr.add (fillMatrixMatrix Expr.zero sx sy) $
   Sym.mapWithIndex
      (Expr.modify2 (atom,atom) (atom,atom) $
       \(coll,_ix) (i,a) -> ((coll,i),a)) $
   ShapeDep.backpermute2
      (Expr.modify2 (atom,(atom,atom)) (atom,(atom,atom)) $
       \(coll, (rows,glues)) (_coll, (_glues,cols)) ->
          (coll, (rows, glues, cols)))
      (Expr.modify (atom,(atom,atom,atom)) $
       \(coll, (row, glue, _col)) -> (coll, (row, glue)))
      (Expr.modify (atom,(atom,atom,atom)) $
       \(coll, (_row, glue, col)) -> (coll, (glue, col)))
      (Expr.modify2 (atom,atom) (atom,atom) $
       \(i,a) (j,b) -> ((i,j), Expr.mul a b))
      x y

{- |
Dense matrix of shape @(coll, (rows, cols))@
where every element is the given expression.

The @coll@ and dense @rows@ shape are taken from the left operand,
the dense @cols@ shape is taken from the right operand,
both dense shapes via 'sparseDim'.
The element combiner is 'asTypeOf',
thus both intermediate arrays contribute the same value @a@.
-}
fillMatrixMatrix ::
   (Shape.C coll, Shape.C rows, Shape.C cols, MultiValue.C a) =>
   Exp a ->
   ColumnMatrix p coll set0 rows glue a ->
   RowMatrix p coll glue set1 cols a ->
   SymP.Array p (coll, (rows, cols)) a
fillMatrixMatrix a (ColumnMatrix x) (RowMatrix y) =
   ShapeDep.backpermute2
      (Expr.modify2 (atom,atom) atom $
       \(coll,row) col -> (coll, (row,col)))
      (Expr.modify (atom,(atom,atom)) $ \(coll, (row,_col)) -> (coll, row))
      (Expr.modify (atom,(atom,atom)) $ \(_coll, (_row,col)) -> col)
      asTypeOf
      (ShapeDep.fill
         (Expr.modify (atom,(atom,atom)) $
            \(coll, (row,_col)) -> (coll, sparseDim row))
         a x)
      (ShapeDep.fill
         (Expr.modify (atom,(atom,atom)) $
            \(_coll, (_row,col)) -> sparseDim col)
         a y)


{- |
Turn a column-sparse matrix into the row-sparse representation
of its transpose by applying 'LinAlg.transpose' to the wrapped array.
-}
transposeColumnMatrix ::
   (Shape.C coll, Shape.C set, Shape.C rows, Shape.C cols) =>
   ColumnMatrix p coll set rows cols a ->
   RowMatrix p coll cols set rows a
transposeColumnMatrix (ColumnMatrix x) = RowMatrix $ LinAlg.transpose x

{- |
Replace every cell @(i, a)@ by @(coll, i)@,
where @coll@ is the first component of the index of the cell.
The result is suited as index array for 'Sym.gather'
on a vector with shape @(coll, dim)@.
-}
sparseRealIndex ::
   (Shape.C coll, Shape.C rows, Shape.C cols) =>
   Matrix p coll rows cols (i, a) ->
   Matrix p coll rows cols (Shape.Index coll, i)
sparseRealIndex =
   Sym.mapWithIndex
      (Expr.modify2 (atom,atom) (atom,atom) $
       \(coll,_ix) (i,_a) -> (coll, i))


{- |
Scale the rows of a row-sparse matrix:
every stored value is multiplied from the left
by the vector element that belongs to the row of the cell.
The stored column indices are kept.
-}
scaleRowRows ::
   (Shape.C coll, Shape.C rows, Shape.C cols, Shape.C set,
    MultiValue.PseudoRing a) =>
   Vector p coll rows a ->
   RowMatrix p coll rows set cols a ->
   RowMatrix p coll rows set cols a
scaleRowRows v (RowMatrix m) =
   RowMatrix $
   ShapeDep.backpermute2 (flip const) (Expr.mapSnd Expr.fst) id mulCell v m

{- |
Scale the columns of a row-sparse matrix:
every stored value is multiplied from the left (see 'mulCell')
by the vector element addressed by the stored column index of the cell.
The stored column indices are kept.
-}
scaleRowColumns ::
   (Shape.C coll, Shape.C rows, Shape.C cols, Shape.C set,
    MultiValue.PseudoRing a) =>
   Vector p coll cols a ->
   RowMatrix p coll rows set cols a ->
   RowMatrix p coll rows set cols a
scaleRowColumns v (RowMatrix m) =
   RowMatrix $
   Sym.zipWith (flip mulCell) m $
   -- same index array as in 'multiplyRowMatrixVector'
   Sym.gather (sparseRealIndex m) v

{- |
Multiply the value of an @(index, value)@ cell from the left by @x@
and keep the index.
-}
mulCell :: (MultiValue.PseudoRing a) => Exp a -> Exp (i, a) -> Exp (i, a)
mulCell = Expr.modify2 atom (atom, atom) $ \x (i,y) -> (i, Expr.mul x y)