-----------------------------------------------------------------------------
-- |
-- Module      :  Matrix.Matrix
-- Copyright   :  (c) Matthew Donadio 2003
-- License     :  GPL
--
-- Maintainer  :  m.p.donadio@ieee.org
-- Stability   :  experimental
-- Portability :  portable
--
-- Basic matrix routines
--
-----------------------------------------------------------------------------

module Matrix.Matrix where

import Matrix.Vector (generate)
import Data.Array
import Data.Complex

-- | Matrix-matrix multiplication: A x B = C
--
-- The column index range of A must equal the row index range of B,
-- otherwise an error is raised. The result takes its row bounds from A
-- and its column bounds from B.
mm_mult :: (Ix i, Ix j, Ix k, Num a)
        => Array (i,j) a -- ^ A
        -> Array (j,k) a -- ^ B
        -> Array (i,k) a -- ^ C
mm_mult a b
    | (ac0,ac1) /= (br0,br1) = error "mm_mult: inside dimensions inconsistent"
    | otherwise =
        generate ((ar0,bc0),(ar1,bc1)) $ \(i,j) ->
            -- dot product of row i of A with column j of B
            sum [ a!(i,k) * b!(k,j) | k <- range (ac0,ac1) ]
  where
    ((ar0,ac0),(ar1,ac1)) = bounds a
    ((br0,bc0),(br1,bc1)) = bounds b

-- | Matrix-vector multiplication: A x b = c
--
-- The column index range of A must equal the index range of b,
-- otherwise an error is raised. The result is indexed by A's row range.
mv_mult :: (Ix i, Ix j, Num a)
        => Array (i,j) a -- ^ A
        -> Array j a     -- ^ b
        -> Array i a     -- ^ c
mv_mult a b
    | (ac0,ac1) /= bounds b = error "mv_mult: dimensions inconsistent"
    | otherwise =
        generate (ar0,ar1) $ \i ->
            -- dot product of row i of A with the vector b
            sum [ a!(i,k) * bk | (k,bk) <- assocs b ]
  where
    ((ar0,ac0),(ar1,ac1)) = bounds a

-- | Transpose of a matrix
--
-- Element @(j,i)@ of the result is element @(i,j)@ of the input, and the
-- bounds are swapped accordingly. Only the index layout changes, so no
-- numeric constraint on the elements is required.
m_trans :: (Ix i, Ix j)
        => Array (i,j) a -- ^ A
        -> Array (j,i) a -- ^ A^T
m_trans a = ixmap ((n0,m0),(n1,m1)) (\(j,i) -> (i,j)) a
  where
    ((m0,n0),(m1,n1)) = bounds a

-- | Hermitian transpose (conjugate transpose) of a matrix
--
-- Element @(j,i)@ of the result is the complex conjugate of element
-- @(i,j)@ of the input, and the bounds are swapped accordingly.
-- 'conjugate' only needs 'Num' on the component type, so that is the
-- only constraint imposed.
m_hermit :: (Ix i, Ix j, Num a)
         => Array (i,j) (Complex a) -- ^ A
         -> Array (j,i) (Complex a) -- ^ A^H
m_hermit a = fmap conjugate (ixmap ((n0,m0),(n1,m1)) (\(j,i) -> (i,j)) a)
  where
    ((m0,n0),(m1,n1)) = bounds a


-- | Bounds of the first index component of a matrix, i.e. the index
-- range of a single column.
columnBounds :: (Ix i, Ix j) => Array (i,j) a -> (i,i)
columnBounds a = (fst lo, fst hi)
  where
    (lo, hi) = bounds a

-- | Bounds of the second index component of a matrix, i.e. the index
-- range of a single row.
rowBounds :: (Ix i, Ix j) => Array (i,j) a -> (j,j)
rowBounds a = (snd lo, snd hi)
  where
    (lo, hi) = bounds a

-- | Extract column @j@ of a matrix as a vector indexed by the matrix's
-- first index component.
getColumn :: (Ix i, Ix j) => j -> Array (i,j) e -> Array i e
getColumn j a = ixmap (columnBounds a) (\k -> (k,j)) a

-- | Extract row @k@ of a matrix as a vector indexed by the matrix's
-- second index component.
getRow :: (Ix i, Ix j) => i -> Array (i,j) e -> Array j e
getRow k a = ixmap (rowBounds a) (\j -> (k,j)) a

-- | Split a matrix into its columns, in increasing order of the second
-- index component.
toColumns :: (Ix i, Ix j) => Array (i,j) a -> [Array i a]
toColumns a = [ getColumn j a | j <- range (rowBounds a) ]

-- | Split a matrix into its rows, in increasing order of the first
-- index component.
toRows :: (Ix i, Ix j) => Array (i,j) a -> [Array j a]
toRows a = [ getRow i a | i <- range (columnBounds a) ]


{- |
Build a matrix from a list of columns. Every column must have exactly the
row bounds @bnds@, otherwise an error is raised. Columns are numbered from
0 in list order, so the result has bounds
@((m0,0), (m1, length columns - 1))@.

The row bounds are passed explicitly because they cannot be recovered
from an empty list of columns.
-}
fromColumns :: (Ix i) => (i,i) -> [Array i a] -> Array (i,Int) a
fromColumns bnds@(m0,m1) columns
   | all ((bnds ==) . bounds) columns =
        array ((m0,0), (m1, length columns - 1))
              [ ((i,k), x) | (k,col) <- zip [0..] columns, (i,x) <- assocs col ]
   | otherwise = error "Matrix.fromColumns: column bounds mismatch"

-- | Build a matrix from a list of rows. Every row must have exactly the
-- column bounds @bnds@, otherwise an error is raised. Rows are numbered
-- from 0 in list order, so the result has bounds
-- @((0,n0), (length rows - 1, n1))@.
--
-- The column bounds are passed explicitly because they cannot be
-- recovered from an empty list of rows.
fromRows :: (Ix j) => (j,j) -> [Array j a] -> Array (Int,j) a
fromRows bnds@(n0,n1) rows
   | all ((bnds ==) . bounds) rows =
        -- Concatenated rows are already in row-major order, which is the
        -- order 'listArray' fills the index range in.
        listArray ((0,n0), (length rows - 1, n1)) (concatMap elems rows)
   | otherwise = error "Matrix.fromRows: row bounds mismatch"



-- | Outer product of two vectors: element @(i,j)@ of the result is
-- @x!i * y!j@. The result's bounds combine the bounds of @x@ (rows)
-- and @y@ (columns).
outer :: (Ix i, Ix j, Num a) => Array i a -> Array j a -> Array (i,j) a
outer x y =
   -- x varies in the outer generator, matching row-major index order.
   listArray ((m0,n0),(m1,n1)) [ xi * yj | xi <- elems x, yj <- elems y ]
  where
    (m0,m1) = bounds x
    (n0,n1) = bounds y

-- | Inner (dot) product of two vectors with identical bounds; raises an
-- error if the bounds differ. This is the plain bilinear sum of products,
-- so no complex conjugation is applied to either argument.
inner :: (Ix i, Num a) => Array i a -> Array i a -> a
inner x y
   | bounds x == bounds y = sum (zipWith (*) (elems x) (elems y))
   | otherwise = error "inner: dimensions mismatch"