```{- | HasGP Gaussian Process Library. This module contains assorted functions
that support the efficient solution of sets of linear equations

Copyright (C) 2011 Sean Holden. sbh11\@cl.cam.ac.uk.
-}
{- This file is part of HasGP.

HasGP is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

HasGP is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with HasGP.  If not, see <http://www.gnu.org/licenses/>.
-}
module HasGP.Support.Solve where

import Numeric.LinearAlgebra
import Data.Packed.ST

import HasGP.Types.MainTypes
import HasGP.Support.Functions as F

{- It is not clear whether using linearSolve from HMatrix would incur a
performance penalty when the matrix is upper or lower triangular. Pro:
it presumably calls into LAPACK. Con: it ignores the triangular
structure, which should allow an O(n^2) solution instead of O(n^3).

To do: time some runs to see whether the hand-written solvers below are
needed.
-}

-- | Solve the upper triangular system @M x = y@ by back substitution.
--
--   The rows of @M@ and the elements of @y@ are fed to the worker in
--   reverse order, starting from an all-zero solution vector, so that
--   @x_n@ is computed first and @x_1@ last.
upperSolve :: DMatrix -> DVector -> DVector
upperSolve m y =
  let n = rows m
  in uS (reverse (toRows m)) (reverse (toList y)) n (constant 0.0 n)

-- | Solve the lower triangular system @M x = y@ by forward substitution.
--
--   The rows of @M@ are paired with the elements of @y@ in their natural
--   order, starting from an all-zero solution vector, so that @x_1@ is
--   computed first and @x_n@ last.
lowerSolve :: DMatrix -> DVector -> DVector
lowerSolve m y =
  let start = constant 0.0 (rows m)
  in lS (toRows m) (toList y) 1 start

-- | Worker for 'lowerSolve'.
--
--   Walks the rows of @M@ together with the matching elements of @y@,
--   filling in the solution element at the given (1-based) index, then the
--   next index up, and so on. It stops as soon as either list runs out and
--   returns the vector built so far.
lS :: [DVector] -> [Double] -> Int -> DVector -> DVector
lS rs ys nStart x0 = go (zip3 rs ys [nStart ..]) x0
  where
    go [] acc = acc
    go ((r, yk, k) : rest) acc = go rest (computeNthElement r yk k acc)

-- | Worker for 'upperSolve'.
--
--   Walks the (already reversed) rows of @M@ together with the matching
--   (already reversed) elements of @y@, filling in the solution element at
--   the given (1-based) index, then the next index down, and so on. It stops
--   as soon as either list runs out and returns the vector built so far.
uS :: [DVector] -> [Double] -> Int -> DVector -> DVector
uS rowsRev ysRev nStart x0 = go (zip3 rowsRev ysRev [nStart, nStart - 1 ..]) x0
  where
    go [] acc = acc
    go ((r, yk, k) : rest) acc = go rest (computeNthElement r yk k acc)

-- | Compute the value of @x_n@ (1-based) when solving a triangular set of
--   equations @M x = y@ by substitution, and return @x@ with that element
--   filled in.
--
--   For a lower triangular @M@ it is assumed that every @x_i@ with @i < n@
--   is already in @x@ and that the remaining elements are 0. For an upper
--   triangular @M@ the roles are swapped: every @x_i@ with @i > n@ is known
--   and the rest are 0. In both cases @x_n@ itself is still 0, so the dot
--   product of the row with @x@ is exactly the sum of the already-known
--   terms.
--
--   A zero diagonal element (a singular @M@) makes the result infinite or
--   NaN rather than raising an error.
computeNthElement :: DVector   -- ^ nth row of M
                  -> Double    -- ^ y_n
                  -> Int       -- ^ n (1-based)
                  -> DVector   -- ^ current x vector
                  -> DVector   -- ^ x vector with x_n computed.
computeNthElement row y n x =
  runSTVector $ do
    let inner = row <.> x
    row' <- thawVector row
    -- The diagonal element M_nn is the pivot that x_n is divided by.
    mN <- readVector row' (n - 1)
    x' <- thawVector x
    writeVector x' (n - 1) ((y - inner) / mN)
    return x'

-- | Solve @T X = B@ column by column using a triangular solver.
--
--   The first argument is a per-vector solver such as 'upperSolve' or
--   'lowerSolve'. The second is the triangular coefficient matrix @T@ it is
--   applied to, for example a factor from a Cholesky decomposition. The
--   third is the right-hand-side matrix @B@. Each column of @B@ is solved
--   independently, and the solutions become the columns of the result.
generalSolve :: (DMatrix -> DVector -> DVector)
             -> DMatrix
             -> DMatrix
             -> DMatrix
generalSolve solver tri = fromColumns . map (solver tri) . toColumns

-- | Find the inverse of a matrix @A@ from its Cholesky decomposition.
--
--   The argument @r@ must be the upper triangular factor satisfying
--   @trans r <> r == A@, since 'upperSolve' is applied to it directly.
--   Because @A^-1 = r^-1 (trans r)^-1@, the identity is first solved
--   against the lower triangular @trans r@, and that result is then solved
--   against @r@. A single solve against @r@ would yield only @r^-1@,
--   not the inverse of @A@.
cholSolve :: DMatrix -> DMatrix
cholSolve r =
  generalSolve upperSolve r
    $ generalSolve lowerSolve (trans r) (ident (rows r))

```