{- |
Sparse matrices represented as a map of rows,
where every row is itself a map from column index to element.
Elements that are absent from the maps are treated as zero
by 'toDense' and 'mulVector'.
-}
module Matrix.Sparse (
   Matrix,
   bounds,
   fromMap,
   fromRows,
   fromColumns,
   fromDense,
   toRows,
   toColumns,
   toDense,
   getRow,
   getColumn,
   mulVector
   ) where

import qualified Matrix.Vector as Vector

import qualified Data.Foldable as Fold
import qualified Data.Map as Map
import qualified Data.Array as Array
import Data.Map (Map)
import Data.Array (Array, Ix, accumArray, (!))


-- | A sparse matrix: the index bounds @((rowLo,colLo), (rowHi,colHi))@
-- together with the stored elements, grouped by row and then by column.
data Matrix i j a = Matrix ((i,j), (i,j)) (Map i (Map j a))
   deriving Show

-- | Maps over the stored elements only; the bounds are left untouched.
instance Functor (Matrix i j) where
   fmap f (Matrix bnds rows) = Matrix bnds (fmap (fmap f) rows)


-- | The index bounds that were supplied when the matrix was built.
bounds :: Matrix i j a -> ((i,j), (i,j))
bounds (Matrix bnds _) = bnds


-- | Build a matrix from a map keyed by @(row, column)@ pairs.
-- The given bounds are stored as they are;
-- the keys are not checked against them.
fromMap ::
   (Ord i, Ord j) =>
   ((i,j), (i,j)) -> Map (i,j) a -> Matrix i j a
fromMap bnds cells =
   Matrix bnds $
   Map.fromListWith Map.union
      [ (i, Map.singleton j x) | ((i,j), x) <- Map.toList cells ]

-- | Build a matrix from a map of rows.
-- The rows are stored without any check against the bounds.
fromRows ::
   (Ord i, Ord j) =>
   ((i,j), (i,j)) -> Map i (Map j a) -> Matrix i j a
fromRows bnds rows = Matrix bnds rows

-- | Build a matrix from a map of columns
-- by transposing the nesting of the maps.
fromColumns ::
   (Ord i, Ord j) =>
   ((i,j), (i,j)) -> Map j (Map i a) -> Matrix i j a
fromColumns bnds columns = Matrix bnds (flipMap columns)

-- | Convert a dense array.
-- Every element of the array becomes a stored element,
-- and the bounds of the array become the bounds of the matrix.
fromDense :: (Ix i, Ix j) => Array (i,j) a -> Matrix i j a
fromDense arr =
   fromMap (Array.bounds arr) (Map.fromList (Array.assocs arr))


-- | The stored elements grouped by row.
toRows :: (Ord i, Ord j) => Matrix i j a -> Map i (Map j a)
toRows (Matrix _bnds rows) = rows

-- | The stored elements grouped by column.
toColumns :: (Ord i, Ord j) => Matrix i j a -> Map j (Map i a)
toColumns (Matrix _bnds rows) = flipMap rows

-- | Convert to a dense array over the matrix bounds.
-- Positions without a stored element are filled with @0@.
toDense :: (Ix i, Ix j, Num a) => Matrix i j a -> Array (i,j) a
toDense (Matrix bnds rows) =
   accumArray (\_old new -> new) 0 bnds
      [ ((i,j), x)
      | (i, row) <- Map.toList rows
      , (j, x) <- Map.toList row
      ]


-- cf. comfort-graph:Graph.Comfort.Map.flip
-- | Swap the two key levels of a nested map.
-- Each outer entry is turned into a map of singletons
-- and all of these are merged.
-- The innermost combining function is an 'error' call;
-- it would only be reached if two merged entries
-- shared both keys.
flipMap :: (Ord i, Ord j) => Map i (Map j a) -> Map j (Map i a)
flipMap nested =
   Map.unionsWith (Map.unionWith (error "Map.flip: duplicate key"))
      [ fmap (Map.singleton i) inner | (i, inner) <- Map.toList nested ]


-- | The stored elements of one row, keyed by column.
-- A row without stored elements yields the empty map.
getRow :: (Ord i, Ord j) => i -> Matrix i j a -> Map j a
getRow i mat = Map.findWithDefault Map.empty i (toRows mat)

-- | The stored elements of one column, keyed by row.
-- Rows that have no element in that column are omitted.
getColumn :: (Ord i, Ord j) => j -> Matrix i j a -> Map i a
getColumn j mat = Map.mapMaybe (Map.lookup j) (toRows mat)


-- | Multiply the matrix with a dense vector.
-- The column bounds of the matrix must be equal to the bounds of the vector,
-- otherwise 'error' is called.
-- The result is produced by @Vector.generate@ from the row bounds
-- (presumably an array over that index range - confirm in "Matrix.Vector").
mulVector ::
   (Ix i, Ix j, Num a) =>
   Matrix i j a -> Array j a -> Array i a
mulVector mat@(Matrix ((rowLo,colLo), (rowHi,colHi)) _) v
   | (colLo,colHi) == Array.bounds v =
        Vector.generate (rowLo,rowHi) $ \i -> mulRowVector (getRow i mat) v
   | otherwise =
        error "Sparse.mulVector: dimensions mismatch"

-- | Scalar product of a sparse row with a dense vector.
-- Only the stored elements of the row contribute to the sum.
mulRowVector :: (Ix j, Num a) => Map j a -> Array j a -> a
mulRowVector row v = Fold.sum (Map.mapWithKey term row)
   where
      -- product of one stored element with the matching vector entry
      term j x = x * v!j