{- | ClassificationLaplace is a module in the HasGP Gaussian Process
   library. It implements basic two-class Gaussian Process classification 
   using the Laplace approximation. For details see 
   www.gaussianprocesses.org.

   Copyright (C) 2011 Sean Holden. sbh11\@cl.cam.ac.uk.
-}
{- This file is part of HasGP.

   HasGP is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   HasGP is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with HasGP.  If not, see <http://www.gnu.org/licenses/>.
-}
module HasGP.Classification.Laplace.ClassificationLaplace 
   (
     LaplaceValue(..),
     LaplaceConvergenceTest,
     gpCLaplaceLearn,
     convertToP_CG,
     gpCLaplacePredict,
     gpCLaplacePredict',
     gpCLaplaceLogEvidence,
     gpCLaplaceLogEvidenceList,
     gpCLaplaceLogEvidenceVec
   ) where 

import Numeric.LinearAlgebra

import Control.Monad.State

import HasGP.Types.MainTypes
import HasGP.Support.MatrixFunction
import HasGP.Support.Linear
import HasGP.Support.Functions
import HasGP.Support.Solve
import HasGP.Support.Iterate
import HasGP.Covariance.Basic
import HasGP.Likelihood.Basic

-- | Computing the Laplace approximation requires us to deal with 
--   quite a lot of information. To keep things straightforward we 
--   wrap this up in a type.
--
--   The value associated with a state includes f, the evidence, the 
--   objective, the derivative of the objective, the vector a needed to 
--   compute the derivative of the evidence, and the number of iterations.
data LaplaceValue = LaplaceValue {
      fValue::DVector,    -- ^ Latent function values f.
      eValue::Double,     -- ^ Approximate log marginal likelihood (evidence).
      psiValue::Double,   -- ^ Value of the objective.
      dPsiValue::DVector, -- ^ Derivative of the objective.
      aValue::DVector,    -- ^ The vector a; as f = Ka this equals K^-1 f, 
                          --   and is needed for the evidence derivative.
      count::Int          -- ^ Number of iterations performed.
    }

-- | The state is the vector f and the number of iterations.
type LaplaceState = (DVector,Int)

-- | A convergence test is a function that takes two consecutive values 
--   produced during iteration and decides whether the iteration has 
--   converged. An example appears below.
type LaplaceConvergenceTest = (LaplaceValue -> LaplaceValue -> Bool)
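
-- An illustrative convergence test; the name and the thresholds are 
-- assumptions for the sake of the example, not part of the original API. 
-- It stops when successive values of the objective differ by less than 
-- a small tolerance, or when an iteration limit is reached.
exampleConvergenceTest :: LaplaceConvergenceTest
exampleConvergenceTest old new =
    abs (psiValue new - psiValue old) < 1e-8 || count new >= 100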

-- | Compute a single Newton update for the latent variables f.
--
--   Produces the new f, the log marginal likelihood, the objective, the 
--   derivative of the objective, and the vector a, which is needed to 
--   compute the derivative of the log marginal likelihood.
gpCLaplaceUpdate :: LogLikelihood l => CovarianceMatrix 
                 -> Targets                  
                 -> l               -- ^ log likelihood
                 -> LaplaceState    -- ^ Current f and n.
                 -> LaplaceValue
gpCLaplaceUpdate c t like (f,n) = 
    LaplaceValue newF e psi dPsi aV (n+1)
        where
          d = dim t
          diagW = -(zipVectorWith (ddLikelihood like) t f)
          w = diag diagW
          diagRootW = mapVector sqrt diagW
          rootW = diag diagRootW
          -- MUST use abaDiagDiag, or errors accumulate and symmetry is lost.
          ll = trans $ chol ((ident d) + (abaDiagDiag diagRootW c))
          dL = zipVectorWith (dLikelihood like) t f                   
          b = (w <> (asColumn f)) + (asColumn dL)    
          a = b - 
              (rootW <> 
               asColumn (upperSolve (trans ll) 
                         (lowerSolve ll (head $ toColumns (rootW <> c <> b)))))
          aV = head $ toColumns a
          newF = head $ toColumns (c <> a)
          psi = (-(1/2) * (aV <.> newF)) + 
                (sum $ toList $ zipVectorWith (likelihood like) t newF)
          -- Not in the book, but easily proved using the fact that f = Ka: 
          -- dPsi = d(log likelihood) - (inverse of K) f = dL - a.
          dPsi = dL - aV 
          -- log marginal likelihood.
          e = (psi - (sum $ map log $ toList $ takeDiag ll)) 
                        
-- | Iteration to convergence is much nicer if the state is hidden using 
--   the State monad.
--
--   This uses the pure gpCLaplaceUpdate function, and wraps it up in a 
--   state transformer that's usable by the general functions in 
--   HasGP.Support.Iterate.
singleIteration :: LogLikelihood l => CovarianceMatrix 
               -> Targets                 
               -> l                        -- ^ log likelihood
               -> State LaplaceState LaplaceValue
singleIteration c t like = state sI 
    where
      sI (f, n) = (newValue, ((fValue newValue), (n+1)))
          where
            newValue = gpCLaplaceUpdate c t like (f,n)

-- | Iterate the Laplace update to convergence, starting from f = 0, to 
--   find the posterior mode.
--
--   This uses a general function from HasGP.Support.Iterate to implement 
--   the learning algorithm. Convergence testing is done using a 
--   user-supplied function; see the example below.
gpCLaplaceLearn :: LogLikelihood l => CovarianceMatrix 
               -> Targets   
               -> l          -- ^ log likelihood
               -> LaplaceConvergenceTest
               -> LaplaceValue
gpCLaplaceLearn c t like converged = 
    evalState (iterateToConvergence'' doOnce converged) (constant 0.0 (dim t),1)
        where
          doOnce = singleIteration c t like
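
-- For example (a sketch; 'inputs', 'targets' (with labels +/-1), a 
-- covariance function 'cov' and a log likelihood 'like' are assumed to 
-- be in scope):
--
-- > let cM = covarianceMatrix cov inputs
-- >     v  = gpCLaplaceLearn cM targets like exampleConvergenceTest
-- > in (fValue v, eValue v)  -- posterior mode and approximate log evidence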

-- | Converts a pair of fStar and V, as produced by the prediction 
--   functions, to an actual probability, assuming the cumulative 
--   Gaussian likelihood was used.
convertToP_CG :: (Double,Double) -> Double
convertToP_CG (fStar,v) = phiIntegral (fStar / (sqrt (1 + v)))

-- | Predict using a GP classifier based on the Laplace approximation.
--
--   Produces fStar and V rather than the actual probability, as 
--   further approximations are then required to compute it.
gpCLaplacePredict :: (CovarianceFunction cF, LogLikelihood l) => DVector -- ^ f
                  -> Inputs           
                  -> Targets          
                  -> CovarianceMatrix -- ^ Covariance matrix
                  -> cF               -- ^ Covariance function
                  -> l                -- ^ log likelihood
                  -> Input            -- ^ Input to classify
                  -> (Double,Double)
gpCLaplacePredict f inputs t c cov like x = 
    (fStar, (((covariance cov) x x) - (v <.> v)))
        where
          d = dim t
          diagW = -(zipVectorWith (ddLikelihood like) t f)
          w = diag diagW
          diagRootW = mapVector sqrt diagW
          rootW = diag diagRootW
          ll = trans $ chol ((ident d) + (abaDiagDiag diagRootW c))   
          kxxStar = covarianceWithPoint cov inputs x 
          fStar = kxxStar <.> (zipVectorWith (dLikelihood like) t f) 
          v = lowerSolve ll (zipVectorWith (*) diagRootW kxxStar)

-- | Predict using a GP classifier based on the Laplace approximation.
--
--   The same as gpCLaplacePredict but applies to a collection of new 
--   inputs supplied as the rows of a matrix.
--
--   Produces a list of pairs of fStar and V rather than the actual 
--   probabilities, as further approximations are then required to 
--   compute these.
gpCLaplacePredict' :: (CovarianceFunction cF, LogLikelihood l) => DVector -- ^ f
                   -> Inputs          
                   -> Targets         
                   -> CovarianceMatrix 
                   -> cF               -- ^ Covariance function
                   -> l                -- ^ log likelihood
                   -> Inputs           -- ^ Inputs to classify
                   -> [(Double,Double)]
gpCLaplacePredict' f inputs t c cov like x = 
    map predict $ toRows x
        where
          predict = gpCLaplacePredict f inputs t c cov like
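
-- A sketch of the full prediction pipeline, combining gpCLaplaceLearn, 
-- gpCLaplacePredict and convertToP_CG ('inputs', 'targets', 'cov', 'like' 
-- and a new input 'x' are assumed to be in scope, with the cumulative 
-- Gaussian likelihood in use):
--
-- > let cM = covarianceMatrix cov inputs
-- >     f  = fValue (gpCLaplaceLearn cM targets like exampleConvergenceTest)
-- > in convertToP_CG (gpCLaplacePredict f inputs targets cM cov like x)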

-- | Compute the log marginal likelihood and its first derivative for the 
--   Laplace approximation for GP classification.
--
--   The convergence test input determines when the iteration performed 
--   by gpCLaplaceLearn stops. Note that a covariance function 
--   contains its own parameters and can compute its own derivative, so 
--   theta does not need to be passed separately.
--
--   Outputs the NEGATIVE log marginal likelihood and a vector of its 
--   derivatives. The derivatives are with respect to the actual, NOT log 
--   parameters.
gpCLaplaceLogEvidence :: (CovarianceFunction cF, LogLikelihood l) => Inputs 
                      -> Targets 
                      -> cF                  -- ^ Covariance function
                      -> l                   -- ^ log likelihood
                      -> LaplaceConvergenceTest 
                      -> (Double, DVector)
gpCLaplaceLogEvidence i t cov like converged = 
    (-z, -dZ)
        where
          d = dim t
          cM = covarianceMatrix cov i
          LaplaceValue f z psi dPsi aV n = 
              gpCLaplaceLearn cM t like converged
          diagW = -(zipVectorWith (ddLikelihood like) t f)
          w = diag diagW
          diagRootW = mapVector sqrt diagW
          rootW = diag diagRootW
          ll = trans $ chol ((ident d) + (abaDiagDiag diagRootW cM))
          r = rootW <> (generalSolve upperSolve (trans ll) 
                        (generalSolve lowerSolve ll rootW))
          c = generalSolve lowerSolve ll (rootW <> cM)
          s2 = flatten $ 
               (-0.5) * (diag ((takeDiag cM) - 
                               (abDiagOnly (trans c) c))) <> 
                          (asColumn (zipVectorWith (dddLikelihood like) t f))
          cList = makeMatricesFromPairs (dCovarianceDParameters cov) i
          s1a = fromList $ map (abaVV aV) cList
          s1b = fromList $ map (sum . toList . (abDiagOnly r)) cList
          s1 = 0.5 * (s1a - s1b)
          b = map (<> (asColumn $ zipVectorWith (dLikelihood like) t f)) cList
          s3 = map (\x -> x - (cM <> r <> x)) b 
          dZ = s1 + (fromList $ map (s2 <.>) (map flatten s3)) 
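
-- Note the sign convention: gpCLaplaceLogEvidence returns the NEGATIVE 
-- log marginal likelihood, so the approximate log evidence itself is 
-- recovered by negating the first component of the result. A sketch 
-- ('inputs', 'targets', 'cov', 'like' and a convergence test 'test' 
-- assumed in scope):
--
-- > let (negLogZ, _) = gpCLaplaceLogEvidence inputs targets cov like test
-- > in negate negLogZ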

-- | A version of gpCLaplaceLogEvidence that's usable by the
--   conjugate gradient function included in the hmatrix library. Computes 
--   the log evidence and its first derivative for the Laplace approximation 
--   for GP classification. The issue is that while it makes sense for a 
--   covariance function to be implemented as a class, so that any can easily 
--   be used, we need to supply the evidence and its derivatives directly as 
--   functions of the hyperparameters, and these have to be supplied as 
--   vectors of Doubles. The solution is to include a function in the 
--   CovarianceFunction class that takes a list and returns a new covariance 
--   function of the required type having the specified hyperparameters.
--
--   Parameters: the same as for gpCLaplaceLogEvidence, plus a list of 
--   log hyperparameters. Outputs: the negative log marginal likelihood 
--   and a vector of its first derivatives. 
--   
--   In addition to the above, this assumes that we want derivatives with 
--   respect to the log parameters, and so converts using 
--   df/d(log p) = p df/dp.
gpCLaplaceLogEvidenceList :: (CovarianceFunction cF, LogLikelihood l) => Inputs 
                          -> Targets 
                          -> cF 
                          -> l 
                          -> LaplaceConvergenceTest
                          -> [Double] -- ^ log hyperparameters
                          -> (Double, DVector)
gpCLaplaceLogEvidenceList i t cov like converged hyper = 
    (negZ, zipVectorWith (*) (fromList $ map exp hyper) negDZ)
        where 
          cov2 = makeCovarianceFromList cov hyper
          (negZ, negDZ) = gpCLaplaceLogEvidence i t cov2 like converged

-- | This is the same as gpCLaplaceLogEvidenceList but takes a vector 
--   instead of a list.
gpCLaplaceLogEvidenceVec :: (CovarianceFunction cF, LogLikelihood l) => Inputs 
                         -> Targets 
                         -> cF 
                         -> l 
                         -> LaplaceConvergenceTest
                         -> DVector 
                         -> (Double, DVector)
gpCLaplaceLogEvidenceVec i t cov like converged 
    = (gpCLaplaceLogEvidenceList i t cov like converged) . toList
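
-- A sketch of hyperparameter optimisation, wiring gpCLaplaceLogEvidenceVec 
-- to the GSL conjugate gradient minimiser bundled with hmatrix 
-- (Numeric.GSL.Minimization). The tolerances, iteration limit and step 
-- size below are arbitrary illustrative choices; 'inputs', 'targets', 
-- 'cov', 'like', a convergence test 'test' and a starting vector of log 
-- hyperparameters 'startLogHyper' are assumed to be in scope.
--
-- > import Numeric.GSL.Minimization (minimizeVD, MinimizeMethodD(ConjugatePR))
-- >
-- > optimisedCov = makeCovarianceFromList cov (toList solution)
-- >   where
-- >     negEvidence = gpCLaplaceLogEvidenceVec inputs targets cov like test
-- >     (solution, _) = minimizeVD ConjugatePR 1e-3 100 0.1 1e-2
-- >                                (fst . negEvidence) (snd . negEvidence)
-- >                                startLogHyper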