-----------------------------------------------------------------------------
-- |
-- Module      :  Matrix.LU
-- Copyright   :  (c) Matthew Donadio 2003
-- License     :  GPL
--
-- Maintainer  :  m.p.donadio@ieee.org
-- Stability   :  experimental
-- Portability :  portable
--
-- Module implementing LU decomposition and related functions
--
-----------------------------------------------------------------------------

module Matrix.LU (lu, lu_solve, improve, inverse, lu_det, solve, det) where

import qualified Matrix.Matrix as Matrix
import qualified Matrix.Vector as Vector
import qualified Data.List as List
import Data.Array
import Data.Ord (comparing)


-- TODO: modify for partial pivoting / permutation matrix
-- TODO: add singularity check

-- | LU decomposition via Crout's Algorithm
--
-- Both factors are packed into a single array of the same bounds as @A@:
-- entries strictly below the diagonal hold L (its unit diagonal is implicit
-- and not stored), entries on and above the diagonal hold U.
--
-- No pivoting is performed: a zero diagonal entry of U is divided by, which
-- yields infinities / NaN in the result.
--
-- The two cases of @luij@ below are formulas (2.3.13) and (2.3.12) from NRIC
-- with some variable renaming (they are also in G&VL).
--
-- Rows and columns are assumed to share the same lower index bound
-- (1-based and 0-based matrices both work).
lu :: Array (Int,Int) Double -- ^ A
   -> Array (Int,Int) Double -- ^ LU(A)
lu a = a'
    where bnds@((lo,_),_) = bounds a
          -- a' is defined in terms of itself; laziness resolves the
          -- recurrence because each entry only needs entries with a smaller
          -- summation index.
          a' = array bnds [ ((i,j), luij i j) | (i,j) <- range bnds ]
          -- sum of L(i,k) * U(k,j) for k strictly below m
          partialDot m i j = sum [ a'!(i,k) * a'!(k,j) | k <- [lo .. m-1] ]
          luij i j
            | i > j     = (a!(i,j) - partialDot j i j) / a'!(j,j) -- L, (2.3.13)
            | otherwise =  a!(i,j) - partialDot i i j             -- U, (2.3.12)

-- | Solution to Ax=b via LU decomposition
--
-- Takes the packed factorisation produced by 'lu' (unit-diagonal L strictly
-- below the diagonal, U on and above it).
--
-- Forward substitution is formula (2.3.6) in NRIC, remembering that the
-- diagonal of L is 1; backward substitution is formula (2.3.7) in NRIC.
--
-- The solution is indexed by the row bounds of the factor matrix; @b@ is
-- expected to use the same index range. An empty system gives an empty
-- solution.
lu_solve :: Array (Int,Int) Double -- ^ LU(A)
         -> Array Int Double -- ^ b
         -> Array Int Double -- ^ x
lu_solve a b = x
    where ((lo,_),(hi,_)) = bounds a
          -- forward pass: L y = b (no division, L has a unit diagonal)
          y = array (lo,hi) [ (i, forward i) | i <- [lo .. hi] ]
          forward i = b!i - sum [ a!(i,j) * y!j | j <- [lo .. i-1] ]
          -- backward pass: U x = y; lazy array lets x refer to later entries
          x = array (lo,hi) [ (i, backward i) | i <- [lo .. hi] ]
          backward i = (y!i - sum [ a!(i,j) * x!j | j <- [i+1 .. hi] ]) / a!(i,i)

-- | Improve a solution to Ax=b via LU decomposition
--
-- One step of iterative refinement, formula (2.7.4) from NRIC: form the
-- residual r = A x - b, solve A d = r with the existing factorisation, and
-- return x - d.
--
-- Indices follow the row bounds of @A@ (assumed square); @b@ and @x@ are
-- expected to use the same index range.
improve :: Array (Int,Int) Double -- ^ A
        -> Array (Int,Int) Double -- ^ LU(A)
        -> Array Int Double -- ^ b
        -> Array Int Double -- ^ x
        -> Array Int Double -- ^ x'
improve a a_lu b x = array (lo,hi) [ (i, x!i - err!i) | i <- [lo .. hi] ]
    where ((lo,_),(hi,_)) = bounds a
          -- residual A x - b
          rhs = array (lo,hi)
                  [ (i, sum [ a!(i,j) * x!j | j <- [lo .. hi] ] - b!i)
                  | i <- [lo .. hi] ]
          -- correction d solving A d = rhs
          err = lu_solve a_lu rhs

-- | Matrix inversion via LU decomposition
--
-- Section (2.4) from NRIC: @A@ is factored once with 'lu', and column j of
-- the inverse is the solution of A x = e_j, where e_j is the j-th unit
-- vector. The result has the same bounds as @A@ (assumed square, rows and
-- columns sharing the same lower bound).
--
-- TODO: build in improve
inverse :: Array (Int,Int) Double -- ^ A
        -> Array (Int,Int) Double -- ^ A^-1
inverse a0 =
    array bnds [ ((i,j), v) | j <- [lo .. hi]
                            , (i,v) <- assocs (lu_solve a_lu (unit j)) ]
    where bnds@((lo,_),(hi,_)) = bounds a0
          a_lu = lu a0
          -- j-th unit vector, indexed like the rows of a0
          unit j = listArray (lo,hi) [ if i == j then 1 else 0 | i <- [lo .. hi] ]

-- | Determinant of a matrix via LU decomposition
--
-- Formula (2.5.1) from NRIC: since 'lu' produces a unit-diagonal L and does
-- no row exchanges, the determinant is the product of U's diagonal entries.
-- An empty matrix gives 1.
lu_det :: Array (Int,Int) Double -- ^ LU(A)
       -> Double -- ^ det(A)
lu_det a = product [ a!(i,i) | i <- range (lo,hi) ]
    where ((lo,_),(hi,_)) = bounds a


-- | LU solver using original matrix
--
-- Factors @A@ with 'lu' (no pivoting) and solves with 'lu_solve'. When
-- solving for several right-hand sides, factor once and call 'lu_solve'
-- directly instead.
solve :: Array (Int,Int) Double -- ^ A
      -> Array Int Double -- ^ b
      -> Array Int Double -- ^ x
solve a b = lu_solve (lu a) b

-- | determinant using original matrix
--
-- Computed as 'lu_det' of 'lu', i.e. without pivoting or a singularity
-- check. Whenever the factorisation meets a zero pivot, the division in 'lu'
-- produces infinities, and the result may be NaN (this can happen even for
-- some nonsingular matrices, e.g. one with a zero in the top-left corner).
-- Prefer 'det', which pivots.
_det :: Array (Int,Int) Double -- ^ A
     -> Double -- ^ det(A)
_det = lu_det . lu

{- |
Determinant computation by implicit LU decomposition with permutations.

The first column is eliminated using its entry of largest magnitude as the
pivot (partial pivoting), and the function recurses on the remaining
submatrix. An empty matrix has determinant 1; if the pivot is exactly zero
(the whole column is zero), the result is 0.

NOTE(review): squareness is not checked; a non-square input returns a value
rather than an error.
-}
det :: Array (Int,Int) Double -- ^ A
    -> Double -- ^ det(A)
det a
   | rangeSize bnds == 0 = 1
   | pivot == 0          = 0
   | otherwise           = sign * pivot * det reduced
   where bnds@((m0,n0), (m1,n1)) = bounds a
         -- first column and the row index of its largest-magnitude entry
         v = Matrix.getColumn n0 a
         (maxi, maxv) = List.maximumBy (comparing (abs . snd)) (assocs v)
         pivot = a!(maxi,n0)
         -- moving row maxi to the top takes (maxi - m0) adjacent row swaps
         sign = if even (rangeSize (m0,maxi) - 1) then 1 else -1
         -- subtract the outer product of the first column with the pivot row
         -- scaled by 1/pivot, then drop the pivot row and the first column
         -- by re-indexing into the updated matrix
         reduced =
            ixmap ((m0,n0), (pred m1, pred n1))
                  (\(i,j) -> (if i < maxi then i else succ i, succ j)) $
            Vector.sub a $ Matrix.outer v $
            Vector.scale (recip maxv) $ Matrix.getRow maxi a


-------------------------------------------------------------------------------
-- tests
-------------------------------------------------------------------------------

{-

a = array ((1,1),(3,3)) [ ((1,1), 1.0), ((1,2), 2.0), ((1,3),  3.0),
                          ((2,1), 2.0), ((2,2), 5.0), ((2,3),  3.0),
                          ((3,1), 1.0), ((3,2), 0.0), ((3,3),  8.0) ]
a' = array ((1,1),(3,3)) [ ((1,1), -40.0), ((1,2), 16.0), ((1,3),  9.0),
                           ((2,1),  13.0), ((2,2), -5.0), ((2,3), -3.0),
                           ((3,1),   5.0), ((3,2), -2.0), ((3,3), -1.0) ]

a_lu = lu a
b = array (1,3) [ (1, 1.0), (2, 2.0), (3, 5.0) ]
x   = lu_solve a_lu b
x'  = improve a a_lu b x
x'' = improve a a_lu b x'

verify = a' == inverse a && -- tests lu, lu_solve, and inverse
         det a == -1 &&     -- tests lu_det
         x == x' &&         -- tests improve
         x' == x''

-}