{- | HasGP Gaussian Process Library. This module contains assorted functions
     that support the efficient solution of sets of linear equations,
     in particular triangular systems such as those given by a Cholesky
     factor.

     Copyright (C) 2011 Sean Holden. sbh11\@cl.cam.ac.uk.
-}
{- This file is part of HasGP.

   HasGP is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   HasGP is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with HasGP.  If not, see <http://www.gnu.org/licenses/>.
-}
module HasGP.Support.Solve where

import Numeric.LinearAlgebra
import Data.Packed.ST

import HasGP.Types.MainTypes
import HasGP.Support.Functions as F

import Control.Monad.ST

{- Design note: it is not clear whether using linearSolve from HMatrix
   would incur a performance penalty when the matrix is upper or lower
   triangular. For: it presumably calls an optimised LAPACK routine.
   Against: a triangular matrix has structure that allows an O(n^2)
   substitution instead of a general O(n^3) solve.

   TODO: time both approaches to decide whether the hand-written
   triangular solvers below are actually needed.
-}

-- | Solve @M x = y@ for @x@, where @M@ is upper triangular, by back
--   substitution.
--
--   The rows of @M@ and the entries of @y@ are processed from last to
--   first, starting from an all-zero @x@ and filling in one entry per row
--   with 'computeNthElement'. Diagonal entries are not checked, so a zero
--   on the diagonal yields infinite or NaN entries rather than an error.
--
--   Calls 'error' if the length of @y@ differs from the number of rows of
--   @M@; without this check the reversed row and right-hand-side lists
--   would be paired out of step and a wrong answer returned silently.
upperSolve :: DMatrix -> DVector -> DVector
upperSolve m y
    | dim y /= n = error $ "HasGP.Support.Solve.upperSolve: matrix has "
                           ++ show n ++ " rows but the right-hand side has "
                           ++ show (dim y) ++ " entries"
    | otherwise  = uS mR yL n x
    where
      n = rows m
      x = constant 0.0 n
      -- Back substitution starts at the last equation, so walk bottom-up.
      mR = reverse $ toRows m
      yL = reverse $ toList y

-- | Solve @M x = y@ for @x@, where @M@ is lower triangular, by forward
--   substitution.
--
--   The rows of @M@ and the entries of @y@ are processed from first to
--   last, starting from an all-zero @x@ and filling in one entry per row
--   with 'computeNthElement'. Diagonal entries are not checked, so a zero
--   on the diagonal yields infinite or NaN entries rather than an error.
--
--   Calls 'error' if the length of @y@ differs from the number of rows of
--   @M@, instead of silently ignoring surplus entries of @y@ or leaving
--   unsolved entries of @x@ at zero.
lowerSolve :: DMatrix -> DVector -> DVector
lowerSolve m y
    | dim y /= n = error $ "HasGP.Support.Solve.lowerSolve: matrix has "
                           ++ show n ++ " rows but the right-hand side has "
                           ++ show (dim y) ++ " entries"
    | otherwise  = lS mR yL 1 x
    where
      n = rows m
      x = constant 0.0 n
      mR = toRows m
      yL = toList y

-- | Forward-substitution worker for 'lowerSolve'.
--
--   Takes the rows of @M@ and the entries of @y@ in top-to-bottom order,
--   the 1-based index @n@ of the unknown that the first row determines
--   (1 when called from 'lowerSolve'), and the current @x@. Each step fills
--   in @x_n@ and moves on to @n+1@. Recursion stops as soon as either list
--   is exhausted, returning the @x@ built so far.
lS :: [DVector] -> [Double] -> Int -> DVector -> DVector
-- '$!' evaluates each intermediate x immediately instead of building a
-- chain of n suspended computations that is only forced at the end.
lS (row:rest) (yN:ys) n x = lS rest ys (n + 1) $! computeNthElement row yN n x
lS _          _       _ x = x

-- | Back-substitution worker for 'upperSolve'.
--
--   Takes the rows of @M@ and the entries of @y@ in bottom-to-top order,
--   the 1-based index @n@ of the unknown that the first row determines
--   (the row count when called from 'upperSolve'), and the current @x@.
--   Each step fills in @x_n@ and moves on to @n-1@. Recursion stops as soon
--   as either list is exhausted, returning the @x@ built so far.
uS :: [DVector] -> [Double] -> Int -> DVector -> DVector
-- '$!' evaluates each intermediate x immediately instead of building a
-- chain of n suspended computations that is only forced at the end.
uS (row:rest) (yN:ys) n x = uS rest ys (n - 1) $! computeNthElement row yN n x
uS _          _       _ x = x
    
-- | Compute the n-th unknown (1-based) of a triangular system @M x = y@ by
--   substitution, and return @x@ with that entry filled in.
--
--   Precondition: every entry of @x@ that has not been solved yet,
--   including @x_n@ itself, is 0. The dot product of the row with @x@ then
--   contains only the already-solved terms, giving
--
--   > x_n = (y_n - sum_{i /= n} M_{n,i} x_i) / M_{n,n}
--
--   This serves both forward substitution ('lS', where the entries with
--   index less than @n@ are known) and back substitution ('uS', where the
--   entries with index greater than @n@ are known). The pivot @M_{n,n}@ is
--   not checked: a zero pivot produces an infinite or NaN entry.
computeNthElement :: DVector  -- ^ nth row of M
                  -> Double   -- ^ y_n
                  -> Int      -- ^ n (1-based index of the unknown)
                  -> DVector  -- ^ current x vector
                  -> DVector  -- ^ x vector with x_n computed.
computeNthElement row y n x =
    runSTVector $ do
      -- Work on a mutable copy of x and set only entry n. The pivot is read
      -- directly from the immutable row; there is no need to copy the row.
      x' <- thawVector x
      writeVector x' i ((y - (row <.> x)) / (row @> i))
      return x'
  where
    i = n - 1  -- hmatrix vectors are 0-indexed

-- | Apply a triangular vector solver to every column of a right-hand-side
--   matrix.
--
--   The first argument is the solver (typically 'upperSolve' or
--   'lowerSolve'). The second is the triangular matrix @T@, for example a
--   Cholesky factor. The third is the right-hand side @M@. Each column of
--   @M@ is solved independently against @T@, and the solutions are
--   reassembled, in their original order, as the columns of the result.
generalSolve :: (DMatrix -> DVector -> DVector)
             -> DMatrix
             -> DMatrix
             -> DMatrix
generalSolve solver l m =
    fromColumns [ solver l column | column <- toColumns m ]

-- | Invert an upper triangular matrix @l@, typically a Cholesky factor.
--
--   Each column of the identity is solved against @l@ with 'upperSolve',
--   so the result is @l^-1@. This is the inverse of the /factor/, not of
--   the matrix it was decomposed from. If @A = trans l <> l@ (the
--   convention of hmatrix's @chol@; confirm this for the caller's factor),
--   then @A^-1 = l^-1 <> trans (l^-1)@.
cholSolve :: DMatrix -> DMatrix
cholSolve l = generalSolve upperSolve l (ident (rows l))