module Matrix.Sparse (
   Matrix,
   bounds,
   fromMap,
   fromRows,
   fromColumns,
   fromDense,
   toRows,
   toColumns,
   toDense,
   getRow,
   getColumn,
   mulVector,
   ) where

import qualified Matrix.Vector as Vector
import qualified Data.Foldable as Fold
import qualified Data.Map as Map
import qualified Data.Array as Array
import Data.Map (Map)
import Data.Array (Array, Ix, accumArray, (!))


-- | A sparse matrix.
--
-- The first field holds the index bounds
-- @((rowLow, colLow), (rowHigh, colHigh))@, in the same form as the
-- bounds of a 'Data.Array.Array' indexed by @(row, column)@.
-- The second field stores entries row by row: a map from row index to a
-- map from column index to value. Entries that are not stored are treated
-- as zero by 'toDense' and contribute nothing in 'mulVector'.
data Matrix i j a = Matrix ((i,j), (i,j)) (Map i (Map j a))
   deriving Show

-- | Apply a function to every stored entry. The bounds and the set of
-- stored positions are left unchanged.
instance Functor (Matrix i j) where
   fmap f (Matrix bnds rows) = Matrix bnds (fmap (fmap f) rows)


-- | The index bounds @((rowLow, colLow), (rowHigh, colHigh))@ the matrix
-- was built with.
bounds :: Matrix i j a -> ((i,j), (i,j))
bounds (Matrix bnds _) = bnds

-- | Build a matrix from entries keyed by @(row, column)@.
--
-- The bounds are stored as given; entry keys are not checked against them.
fromMap :: (Ord i, Ord j) => ((i,j), (i,j)) -> Map (i,j) a -> Matrix i j a
fromMap bnds entries =
   Matrix bnds $
   -- Entries of the same row are merged with 'Map.union'. Their column keys
   -- are distinct (the (i,j) keys are unique), so the union's bias never
   -- matters.
   Map.fromListWith Map.union
      [(i, Map.singleton j x) | ((i,j), x) <- Map.toList entries]

-- | Build a matrix from rows given in the internal layout: a map from row
-- index to a map from column index to value. The rows are stored as-is and
-- are not checked against the bounds.
fromRows ::
   (Ord i, Ord j) => ((i,j), (i,j)) -> Map i (Map j a) -> Matrix i j a
fromRows bnds rows = Matrix bnds rows

-- | Build a matrix from columns: a map from column index to a map from row
-- index to value. The columns are transposed into the internal row-wise
-- layout; entries are not checked against the bounds.
fromColumns ::
   (Ord i, Ord j) => ((i,j), (i,j)) -> Map j (Map i a) -> Matrix i j a
fromColumns bnds columns = Matrix bnds (flipMap columns)

-- | Convert a dense array indexed by @(row, column)@ into a sparse matrix
-- with the same bounds.
--
-- Every array element is stored, including zeros; nothing is filtered out.
fromDense :: (Ix i, Ix j) => Array (i,j) a -> Matrix i j a
fromDense arr =
   fromMap (Array.bounds arr) (Map.fromList (Array.assocs arr))


-- | The stored entries, row by row: row index to (column index to value).
-- The bounds are discarded.
toRows :: (Ord i, Ord j) => Matrix i j a -> Map i (Map j a)
toRows (Matrix _bnds rows) = rows

-- | The stored entries, column by column: column index to
-- (row index to value). The bounds are discarded.
toColumns :: (Ord i, Ord j) => Matrix i j a -> Map j (Map i a)
toColumns (Matrix _bnds rows) = flipMap rows

-- | Expand to a dense array over the matrix bounds. Positions with no
-- stored entry are filled with 0.
--
-- NOTE(review): stored entries outside the bounds make 'accumArray' fail
-- (out-of-range index). The constructors do not prevent such entries.
toDense :: (Ix i, Ix j, Num a) => Matrix i j a -> Array (i,j) a
toDense (Matrix bnds rows) =
   -- Each position occurs at most once in the list; if one did repeat,
   -- the value listed last would win.
   accumArray (\_old new -> new) 0 bnds
      [((i,j), x) | (i, row) <- Map.toList rows, (j, x) <- Map.toList row]


-- | Transpose a nested map: the entry stored at outer key @i@, inner key
-- @j@ ends up at outer key @j@, inner key @i@.
--
-- cf. comfort-graph:Graph.Comfort.Map.flip
flipMap :: (Ord i, Ord j) => Map i (Map j a) -> Map j (Map i a)
flipMap =
   -- The inner maps merged under one new outer key j each come from a
   -- different original outer key i, so their key sets are disjoint and
   -- the error below cannot fire.
   Map.unionsWith (Map.unionWith (error "Map.flip: duplicate key")) .
   Map.elems .
   Map.mapWithKey (\i row -> fmap (Map.singleton i) row)


-- | The stored entries of row @i@ (column index to value). Returns an empty
-- map when row @i@ has no stored entries. The index is not checked
-- against the bounds.
getRow :: (Ord i, Ord j) => i -> Matrix i j a -> Map j a
getRow i (Matrix _ rows) = Map.findWithDefault Map.empty i rows

-- | The stored entries of column @j@ (row index to value). Rows with no
-- stored entry in column @j@ are omitted. The index is not checked
-- against the bounds.
getColumn :: (Ord i, Ord j) => j -> Matrix i j a -> Map i a
getColumn j (Matrix _ rows) = Map.mapMaybe (Map.lookup j) rows


-- | Multiply the matrix by a dense column vector.
--
-- The vector's bounds must equal the matrix's column bounds exactly;
-- otherwise this calls 'error'. The result is indexed by the matrix's row
-- bounds. A row with no stored entries yields 0.
mulVector :: (Ix i, Ix j, Num a) => Matrix i j a -> Array j a -> Array i a
mulVector mat@(Matrix ((m0,n0), (m1,n1)) _) v
   | (n0,n1) == Array.bounds v =
        Vector.generate (m0,m1) (\i -> mulRowVector (getRow i mat) v)
   | otherwise = error "Sparse.mulVector: dimensions mismatch"

-- | Dot product of a sparse row with a dense vector: the sum of @x * v!j@
-- over the row's stored entries @(j, x)@, taken in ascending column order.
-- An empty row gives 0. Every stored column index must lie within the
-- bounds of @v@; otherwise '(!)' fails.
mulRowVector :: (Ix j, Num a) => Map j a -> Array j a -> a
mulRowVector row v =
   -- A single strict left fold, with no intermediate products map. It adds
   -- terms in the same order as Map's 'sum' (foldl' (+) 0).
   Map.foldlWithKey' (\acc j x -> acc + x * v!j) 0 row