-----------------------------------------------------------------------------
-- |
-- Module      :  Math.Tensor.Internal.LinearAlgebra
-- Copyright   :  (c) 2019 Tobias Reinhart and Nils Alex
-- License     :  MIT
-- Maintainer  :  tobi.reinhart@fau.de, nils.alex@fau.de
--
-- Gaussian elimination algorithm based on hmatrix.
-----------------------------------------------------------------------------
module Math.Tensor.Internal.LinearAlgebra (
-- * Gaussian Elimination
gaussianST,
gaussian,
-- * Linearly Independent Columns
independentColumns,
independentColumnsMat,
-- * Pivots
pivotsU,
findPivotMax)

where

import Numeric.LinearAlgebra
import Numeric.LinearAlgebra.Data
import Numeric.LinearAlgebra.Devel

import Data.List (maximumBy)

import Control.Monad
import Control.Monad.ST

-- | Returns the pivot columns of an upper triangular matrix.
--
-- The search starts at position @(0, 0)@. After a pivot has been found at
-- @(r, c)@ it continues at @(r + 1, c + 1)@, until 'findPivot' reports that
-- no further pivot exists. An entry counts as zero when its absolute value is
-- below 'eps' times the largest absolute entry of the matrix.
--
-- @
-- &#x3BB; let mat = (3 >< 4) [1, 0, 2, -3, 0, 0, 1, 0, 0, 0, 0, 0]
-- &#x3BB; mat
-- (3><4)
--  [ 1.0, 0.0, 2.0, -3.0
--  , 0.0, 0.0, 1.0,  0.0
--  , 0.0, 0.0, 0.0,  0.0 ]
-- &#x3BB; pivotsU mat
-- [0,2]
-- @
--

pivotsU :: Matrix Double -> [Int]
pivotsU mat = pivotsFrom (0, 0)
  where
    -- Tolerance relative to the largest absolute entry. It is only forced
    -- once the matrix is known to be non-empty in both dimensions.
    tol = eps * maximum [maximum (map abs r) | r <- toLists mat]
    -- Collect pivot columns, resuming one row down and one column right.
    pivotsFrom pos =
        maybe [] (\(pr, pc) -> pc : pivotsFrom (pr + 1, pc + 1))
              (findPivot mat tol pos)


-- | Numerical tolerance for treating a 'Double' as zero.
--
--   The elimination in this module compares absolute values directly against
--   it. 'pivotsU' instead scales it by the largest absolute matrix entry.
eps :: Double
eps = 1.0e-12

-- Find the next pivot of an upper triangular (row echelon) matrix, starting
-- the search at position (i, j).
--
-- Column j is scanned for an entry in a row >= i that counts as non-zero with
-- respect to the tolerance e. If there is none, the search moves on to column
-- j + 1. Returns Just (row, column) of the first such entry reported by 'find',
-- or Nothing once i reaches the number of rows or j reaches the number of
-- columns.
--
-- An entry x counts as non-zero when abs x is not below e AND x is not
-- exactly zero. The second condition matters when e == 0, which is what
-- 'pivotsU' passes for an all-zero matrix (e = eps * 0). Without it, every
-- zero entry would be reported as a pivot.
findPivot :: Matrix Double -> Double -> (Int, Int) -> Maybe (Int, Int)
findPivot mat e (i, j)
    | n == j    = Nothing
    | m == i    = Nothing
    | otherwise = case nonZeros of
                    []           -> findPivot mat e (i, j + 1)
                    (pr, pc) : _ -> Just (pr, pc + j)
  where
    m = rows mat
    n = cols mat
    -- column j as an (m x 1) matrix, so the indices from 'find' are (row, 0)
    column = mat ¿ [j]
    isNonZero x = abs x /= 0 && not (abs x < e)
    nonZeros = filter (\(i', _) -> i' >= i) $ find isNonZero column

-- | Find pivot element below position (i, j) with greatest absolute value in the ST monad.
--
-- @m@ and @n@ bound the rows and columns that are scanned. Column @j@ is read
-- over rows @i .. m-1@, and entries whose absolute value is below 'eps' are
-- ignored. If no entry remains, the search moves on to column @j + 1@.
-- Otherwise the result is @Just (row, column)@ of the remaining entry with
-- the largest absolute value. The result is 'Nothing' when @i == m@ or
-- @j == n@, or when every column up to the last has been exhausted.
-- The matrix is only read, never written.

findPivotMax :: Int -> Int -> Int -> Int -> STMatrix s Double -> ST s (Maybe (Int, Int))
findPivotMax m n i j mat
    | n == j || m == i = return Nothing
    | otherwise = do
        magnitudes <- forM [i .. m - 1] $ \k -> do
            x <- readMatrix mat k j
            return (k, abs x)
        case filter (\(_, a) -> not (a < eps)) magnitudes of
            [] | n == j + 1 -> return Nothing
               | otherwise  -> findPivotMax m n i (j + 1) mat
            candidates ->
                let (best, _) = maximumBy (\(_, a) (_, b) -> compare a b) candidates
                in return (Just (best, j))

-- Gaussian elimination of the sub matrix below position (i, j), in place.
--
-- m and n are used as the row and column bounds of mat. Each step does:
--   1. find the pivot (r, p) with 'findPivotMax';
--   2. swap row r into row i;
--   3. for every row below i whose entry in column p is not below 'eps',
--      add -(entry / pivot) times row i, then flush the entries of that row
--      from column p onwards that fell below 'eps' to exactly 0;
--   4. continue at (i + 1, p + 1).
-- Stops as soon as 'findPivotMax' finds no further pivot.

gaussian' :: Int -> Int -> Int -> Int -> STMatrix s Double -> ST s ()
gaussian' m n i j mat = findPivotMax m n i j mat >>= maybe (return ()) eliminate
  where
    -- perform one elimination step for the pivot at (r, p), then recurse
    eliminate (r, p) = do
        rowOper (SWAP i r (FromCol j)) mat
        pv <- readMatrix mat i p
        forM_ [i + 1 .. m - 1] (clearBelow pv p)
        gaussian' m n (i + 1) (p + 1) mat
    -- cancel entry (r, p) against the pivot value pv sitting in row i
    clearBelow pv p r = do
        rv <- readMatrix mat r p
        unless (abs rv < eps) $ do
            rowOper (AXPY (-rv / pv) i r (FromCol p)) mat
            forM_ [p .. n - 1] $ \k ->
                modifyMatrix mat r k (\x -> if abs x < eps then 0 else x)

-- | Gaussian elimination performed in-place in the @'ST'@ monad.
--
-- The two 'Int' arguments are the row and column counts of the matrix. The
-- pure wrapper 'gaussian' passes @rows@ and @cols@ of the input. Elimination
-- starts at the top-left entry @(0, 0)@.

gaussianST :: Int -> Int -> STMatrix s Double -> ST s ()
gaussianST m n mat = gaussian' m n 0 0 mat


-- | Gaussian elimination as pure function. Involves a copy of the input matrix.
--
-- The input is thawed into a mutable copy, and 'gaussianST' reduces that
-- copy. The result is then frozen, so the argument itself is never modified.
--
-- @
-- &#x3BB; let mat = (3 >< 4) [1, 1, -2, 0, 0, 2, -6, -4, 3, 0, 3, 1]
-- &#x3BB; mat
-- (3><4)
--  [ 1.0, 1.0, -2.0,  0.0
--  , 0.0, 2.0, -6.0, -4.0
--  , 3.0, 0.0,  3.0,  1.0 ]
-- &#x3BB; gaussian mat
-- (3><4)
--  [ 3.0, 0.0,  3.0,                1.0
--  , 0.0, 2.0, -6.0,               -4.0
--  , 0.0, 0.0,  0.0, 1.6666666666666667 ]
-- @
--

gaussian :: Matrix Double -> Matrix Double
gaussian mat = runST $
    thawMatrix mat >>= \work ->
        gaussianST (rows mat) (cols mat) work >> freezeMatrix work

-- | Returns the indices of a maximal linearly independent subset of the columns
--   in the matrix.
--
--   The matrix is first brought to row echelon form with 'gaussian'. The pivot
--   columns of the result are then read off with 'pivotsU'.
--
--   NOTE(review): 'gaussian' treats values below the absolute threshold 'eps'
--   as zero, but 'pivotsU' uses 'eps' scaled by the largest entry. For a matrix
--   whose entries are all below 'eps', this may report spurious pivots.
--   Confirm whether such inputs can occur.
--
-- @
-- &#x3BB; let mat = (3 >< 4) [1, 1, -2, 0, 0, 2, -6, -4, 3, 0, 3, 1]
-- &#x3BB; mat
-- (3><4)
--  [ 1.0, 1.0, -2.0,  0.0
--  , 0.0, 2.0, -6.0, -4.0
--  , 3.0, 0.0,  3.0,  1.0 ]
-- &#x3BB; independentColumns mat
-- [0,1,3]
-- @
--

independentColumns :: Matrix Double -> [Int]
independentColumns = pivotsU . gaussian

-- | Returns a sub matrix containing a maximal linearly independent subset of
--   the columns in the matrix.
--
--   The columns are chosen by 'independentColumns' and kept in their original
--   order. If that list is empty, the result is a single zero column with the
--   same number of rows as the input.
--
-- @
-- &#x3BB; let mat = (3 >< 4) [1, 1, -2, 0, 0, 2, -6, -4, 3, 0, 3, 1]
-- &#x3BB; mat
-- (3><4)
--  [ 1.0, 1.0, -2.0,  0.0
--  , 0.0, 2.0, -6.0, -4.0
--  , 3.0, 0.0,  3.0,  1.0 ]
-- &#x3BB; independentColumnsMat mat
-- (3><3)
--  [ 1.0, 1.0,  0.0
--  , 0.0, 2.0, -4.0
--  , 3.0, 0.0,  1.0 ]
-- @
--

independentColumnsMat :: Matrix Double -> Matrix Double
independentColumnsMat mat
    | null cs   = (rows mat >< 1) (repeat 0)
    | otherwise = mat ¿ cs
  where
    cs = independentColumns mat