{- | ClassificationEP is a module in the HasGP Gaussian Process
   library. It implements basic two-class Gaussian Process Classification
   using the Expectation Propagation (EP) approximation. Targets should be +1/-1. 
   
   Copyright (C) 2011 Sean Holden. sbh11\@cl.cam.ac.uk.
-}
{- This file is part of HasGP.

   HasGP is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   HasGP is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with HasGP.  If not, see <http://www.gnu.org/licenses/>.
-}
module HasGP.Classification.EP.ClassificationEP 
   ( 
     EPValue(eValue,siteState,count),
     EPConvergenceTest,
     EPSiteState,
     EPState,
     SiteOrder,
     generateRandomSiteOrder,
     generateFixedSiteOrder,
     gpClassifierEPEvidence,
     gpClassifierEPLearn,
     gpClassifierEPPredict,
     gpClassifierEPLogEvidence,
     gpClassifierEPLogEvidenceList,
     gpClassifierEPLogEvidenceVec
   ) where 

import Numeric.LinearAlgebra

import Control.Monad.State    
import System.Random

import HasGP.Types.MainTypes
import HasGP.Support.MatrixFunction
import HasGP.Support.Linear
import HasGP.Support.Functions
import HasGP.Support.Solve
import HasGP.Support.Iterate
import HasGP.Support.Random
import HasGP.Covariance.Basic
import HasGP.Likelihood.Basic

-- | The result of one complete EP sweep. It bundles the approximate log
--   marginal likelihood, the site state reached after the sweep and the
--   number of sweeps completed so far. Convergence tests
--   ('EPConvergenceTest') receive values of this type, so they can base
--   their decision on any of these fields.
data EPValue = EPValue
  { eValue    :: Double
    -- ^ Approximate log marginal likelihood (EP evidence) after the sweep.
  , siteState :: EPSiteState
    -- ^ Site state after the sweep.
  , count     :: Int
    -- ^ Number of sweeps completed, including this one.
  }

-- | A user-supplied convergence criterion for EP learning, given two
--   'EPValue's. The two arguments are presumably the previous and the
--   current sweep results, and 'True' presumably means converged; the exact
--   contract is fixed by @iterateToConvergence''@. Any rule based on the
--   evidence, the site state or the iteration count can be expressed.
type EPConvergenceTest = EPValue -> EPValue -> Bool

-- | Everything carried from one single-site EP update to the next. Every
--   vector has one entry per site (training point).
data EPSiteState = EPSiteState
  { var      :: DMatrix
    -- ^ Current approximate posterior covariance. It starts as the prior
    --   covariance matrix, receives a rank-one update for each site update
    --   and is recomputed from scratch after each sweep.
  , tauTilde :: DVector
    -- ^ Site parameters tau-tilde.
  , mu       :: DVector
    -- ^ Current approximate posterior mean, i.e. var times nuTilde.
  , nuTilde  :: DVector
    -- ^ Site parameters nu-tilde.
  , tauMinus :: DVector
    -- ^ Cavity parameter tau_{-i} from the latest update of each site.
  , muMinus  :: DVector
    -- ^ Cavity mean mu_{-i} from the latest update of each site.
  }

-- | The state threaded through EP learning by the 'State' monad. It holds
--   the site state, a random number generator (consumed, for example, when
--   choosing a random site order) and the number of sweeps completed so far.
type EPState = (EPSiteState, StdGen, Int)

-- | A stateful action yielding the order, as 1-based site numbers, in which
--   the sites are visited during one sweep. It runs over 'EPState', so it
--   may consume and replace the random number generator.
type SiteOrder = State EPState [Int]

-- | The starting site state for EP. The covariance field is the supplied
--   covariance matrix, and the five per-site vectors are all zero.
generateInitialSiteState :: CovarianceMatrix
                         -> Int     -- ^ Number of sites
                         -> EPSiteState
generateInitialSiteState k n = EPSiteState k zeros zeros zeros zeros zeros
  where
    -- Vectors are immutable values, so a single zero vector can be shared.
    zeros = constant 0.0 n

-- | Single-site EP update, step 1 of 3. The steps are 'cavityParameters',
--   'marginalMoments' and 'siteParameters'. This step removes site i's
--   contribution from its current marginal, giving the cavity parameters
--
--   >  tauMinusI = 1 / varI - tauTildeI
--   >  nuMinusI  = (1 / varI) * muI - nuTildeI
--
--   which are returned as (tauMinusI, nuMinusI).
cavityParameters :: Double -- ^ varI: current marginal variance of site i
                 -> Double -- ^ tauTildeI: site parameter tau-tilde of site i
                 -> Double -- ^ muI: current marginal mean of site i
                 -> Double -- ^ nuTildeI: site parameter nu-tilde of site i
                 -> (Double,Double)
cavityParameters varI tauTildeI muI nuTildeI =
    let precisionI = 1 / varI
    in (precisionI - tauTildeI, precisionI * muI - nuTildeI)

-- | Single-site EP update, step 2 of 3. Computes the mean and variance of
--   the tilted distribution for site i and returns them as
--   (muHatI, varHatI). With z = tI * muMinusI / sqrt (1 + varMinusI) and
--   r = nOverPhi z:
--
--   >  muHatI  = muMinusI + (tI * varMinusI / sqrt (1 + varMinusI)) * r
--   >  varHatI = varMinusI - (varMinusI^2 / (1 + varMinusI)) * r * (z + r)
--
--   Assuming 'nOverPhi' z is N(z) divided by Phi(z), these are the moments
--   for the cumulative-Gaussian (probit) likelihood.
marginalMoments :: Double -- ^ muMinusI: cavity mean of site i
                -> Double -- ^ tI: target of site i (+1 or -1)
                -> Double -- ^ varMinusI: cavity variance of site i
                -> (Double,Double)
marginalMoments muMinusI tI varMinusI =
    let onePlusVar = 1 + varMinusI
        rootOnePlusVar = sqrt onePlusVar
        zI = (tI * muMinusI) / rootOnePlusVar
        ratio = nOverPhi zI
        meanHat = muMinusI + (((tI * varMinusI) / rootOnePlusVar) * ratio)
        varHat = varMinusI
                 - ((((square varMinusI) / onePlusVar) * ratio) * (zI + ratio))
    in (meanHat, varHat)

-- | Single-site EP update, step 3 of 3. Chooses new site parameters so that
--   the cavity combined with the site matches the tilted moments. Returns
--   (deltaTauTilde, tauTildeINew, nuTildeINew), where
--
--   >  deltaTauTilde = 1 / varHatI - tauMinusI - tauTildeI
--   >  tauTildeINew  = tauTildeI + deltaTauTilde
--   >  nuTildeINew   = (1 / varHatI) * muHatI - nuMinusI
siteParameters :: Double -- ^ tauTildeI: current site parameter tau-tilde
               -> Double -- ^ varHatI: tilted variance
               -> Double -- ^ tauMinusI: cavity parameter tau_{-i}
               -> Double -- ^ muHatI: tilted mean
               -> Double -- ^ nuMinusI: cavity parameter nu_{-i}
               -> (Double,Double,Double)
siteParameters tauTildeI varHatI tauMinusI muHatI nuMinusI =
    (delta, tauTildeI + delta, precisionHat * muHatI - nuMinusI)
  where
    precisionHat = 1 / varHatI
    delta = precisionHat - tauMinusI - tauTildeI

-- | Perform one complete EP update of site i, where i is 1-based. The
--   update proceeds as follows:
--
--   * compute the cavity parameters, the tilted moments and the new site
--     parameters;
--   * apply a rank-one update to the covariance;
--   * recompute the mean as the covariance times nuTilde.
--
--   The stored tauMinus and muMinus entries for site i are overwritten with
--   the freshly computed cavity values; their previous values are not read.
updateOneSite :: Targets -- ^ Labels
              -> Int     -- ^ Number of sites
              -> Int     -- ^ Site to update
              -> EPSiteState
              -> EPSiteState
updateOneSite t n i (EPSiteState sigma tTilde mean nTilde cavTau cavMu) =
    EPSiteState sigma' tTilde' mean' nTilde' cavTau' cavMu'
  where
    -- Site numbers are 1-based, but '@>' and '@@>' index from 0.
    ix = i - 1
    tTildeI = tTilde @> ix
    sigmaII = sigma @@> (ix, ix)
    -- Step 1: cavity parameters, plus the cavity variance and mean.
    (tauCav, nuCav) =
        cavityParameters sigmaII tTildeI (mean @> ix) (nTilde @> ix)
    varCav = (1 / tauCav)
    meanCav = (nuCav / tauCav)
    -- Step 2: moments of the tilted distribution.
    (meanHat, varHat) = marginalMoments meanCav (t @> ix) varCav
    -- Step 3: new site parameters.
    (dTau, tTildeINew, nTildeINew) =
        siteParameters tTildeI varHat tauCav meanHat nuCav
    -- 'replaceInVector' appears to take 1-based positions, so it is given
    -- the site number i itself.
    tTilde' = replaceInVector tTilde i tTildeINew
    nTilde' = replaceInVector nTilde i nTildeINew
    cavTau' = replaceInVector cavTau i tauCav
    cavMu' = replaceInVector cavMu i meanCav
    -- Rank-one update using column ix of the old covariance.
    col = subMatrix (0, ix) (n, 1) sigma
    sigma' = sigma - (scale (dTau / (1 + (dTau * sigmaII)))
                            (col <> (trans col)))
    mean' = flatten $ (sigma' <> (asColumn nTilde'))

-- | Draw a random permutation of 1 .. n, returning it together with the
--   updated generator. At each step an index into the remaining elements
--   is drawn with 'randomR'. That element is removed, and it is consed onto
--   the result, so the first element drawn ends up last. The
--   'generateRandomSiteOrder' action wraps this function.
randomPermutation :: StdGen           -- ^ Random number generator
                  -> Int              -- ^ Size of list required
                  -> (StdGen, [Int])  -- ^ New generator and result.
randomPermutation gen0 size = pick gen0 (size - 1) [1 .. size] []
  where
    -- Invariant: hi is one less than the length of pool, i.e. the largest
    -- valid index into it. The elements of pool are distinct.
    pick gen _  []     acc = (gen, acc)
    pick gen _  [only] acc = (gen, only : acc)
    pick gen hi pool   acc =
      let (r, gen') = randomR (0, hi) gen
          chosen = pool !! r
          remaining = [x | x <- pool, x /= chosen]
      in pick gen' (hi - 1) remaining (chosen : acc)

-- | A 'SiteOrder' that visits every site once, in random order. It draws a
--   permutation of the site numbers with 'randomPermutation' and stores the
--   advanced generator back into the state. The site state and iteration
--   count are left unchanged.
generateRandomSiteOrder :: SiteOrder
generateRandomSiteOrder = do
  (sites, gen, iters) <- get
  let numSites = dim (tauTilde sites)
      (gen', order) = randomPermutation gen numSites
  put (sites, gen', iters)
  return order

-- | A deterministic 'SiteOrder' that visits the sites as 1, 2, ..., n. It
--   only reads the state and never modifies it.
generateFixedSiteOrder :: SiteOrder
generateFixedSiteOrder = fmap allSites get
  where
    allSites (sites, _, _) = [1 .. dim (tauTilde sites)]

-- | Apply 'updateOneSite' to each listed site, from left to right,
--   threading the site state through the successive updates.
updateAllSites :: Targets 
               -> Int     -- ^ Number of sites
               -> [Int]   -- ^ Sites to update
               -> EPSiteState
               -> EPSiteState
updateAllSites t n order start = go order start
  where
    go [] st = st
    go (s : rest) st = go rest (updateOneSite t n s st)

-- | Recompute the Gaussian approximation from the site parameters alone,
--   after a full sweep. Writing S = diag tauTilde, it returns the triple
--   (L, Sigma, mu), where
--
--   >  L     = lower Cholesky factor of  I + sqrt(S) K sqrt(S)
--   >  V     = L^(-1) sqrt(S) K
--   >  Sigma = K - V^T V
--   >  mu    = Sigma nuTilde
--
--   This assumes that 'abaDiagDiag' d K computes diag d * K * diag d and
--   that 'preMultiply' d K computes diag d * K. Recomputing presumably
--   discards round-off accumulated by the rank-one updates of
--   'updateOneSite'.
recomputeApproximation :: CovarianceMatrix
                       -> Int     -- ^ Number of sites
                       -> DVector -- ^ tauTilde
                       -> DVector -- ^ nuTilde
                       -> (DMatrix, DMatrix, DVector)
recomputeApproximation k n tT nT = (lower, sigma, mean)
  where
    sqrtS = mapVector sqrt tT
    -- 'chol' yields an upper-triangular factor; transpose to get the lower.
    lower = trans (chol ((ident n) + (abaDiagDiag sqrtS k)))
    v = generalSolve lowerSolve lower (preMultiply sqrtS k)
    sigma = k - ((trans v) <> v)
    mean = flatten (sigma <> (asColumn nT))

-- | Compute the EP approximation to the log marginal likelihood (the
--   evidence) from a site state. The terms are grouped as in Rasmussen and
--   Williams (2006), Algorithm 3.5. Write S = diag tauTilde and
--   T = diag tauMinus. The L argument is assumed to be the lower Cholesky
--   factor of I + sqrt(S) K sqrt(S), as produced when the approximation is
--   recomputed during learning. The value returned is then
--
--   >    0.5 * sum_i log (1 + tauTilde_i / tauMinus_i) - sum_i log L_ii
--   >  + sum_i logPhi (t_i * muMinus_i / sqrt (1 + 1 / tauMinus_i))
--   >  + 0.5 * nuTilde^T (K - V^T V - (T + S)^(-1)) nuTilde
--   >  + 0.5 * muMinus^T T (T + S)^(-1) (S muMinus - 2 nuTilde)
--
--   where V = L^(-1) sqrt(S) K.
gpClassifierEPEvidence :: CovarianceMatrix 
                       -> Targets 
                       -> DMatrix -- ^ L matrix
                       -> EPSiteState
                       -> Double  -- ^ log marginal likelihood.
gpClassifierEPEvidence k t l siteSt =
    terms1and4 + term3 + terms2and5
  where
    tT = tauTilde siteSt
    nT = nuTilde siteSt
    tM = tauMinus siteSt
    mM = muMinus siteSt
    -- Cavity variances 1 / tauMinus_i.
    oneOverTauMinus = mapVector (1/) $ tM
    -- Half the log-determinant of L L^T, i.e. sum_i log L_ii.
    sumLog = sum $ toList $ mapVector log $ takeDiag l
    terms1and4 =
        (0.5 * (sum $ toList $ mapVector (log . (1+))
                (zipVectorWith (*) tT oneOverTauMinus))) - sumLog
    term3 =
        sum $ toList $
        mapVector logPhi (zipVectorWith (/)
                          (zipVectorWith (*) t mM)
                          (mapVector (sqrt . (1+)) oneOverTauMinus))
    -- V = L^(-1) sqrt(S) K, so V^T V = K sqrt(S) (L L^T)^(-1) sqrt(S) K.
    vM = generalSolve lowerSolve l
         (preMultiply (mapVector sqrt tT) k)
    invTPlusSTilde =
        diag $ mapVector (1/) $ tT + tM
    part1 =
        0.5 * ((asRow nT) <>
               (k - ((trans vM) <> vM) - invTPlusSTilde) <>
               (asColumn nT))
    part2 =
        0.5 * ((asRow mM) <> (diag tM) <> invTPlusSTilde <>
               ((asColumn (zipVectorWith (*) tT mM)) -
                (asColumn $ (scale 2 nT))))
    -- part1 and part2 are 1x1 matrices; extract the scalar.
    terms2and5 = (part1 + part2) @@> (0,0)

-- | One complete EP sweep as a 'State' action. It performs these steps:
--
--   * obtain a site order from the supplied 'SiteOrder' action;
--   * update every site in that order;
--   * recompute the covariance and mean from the new site parameters;
--   * increment the iteration counter and compute the evidence.
--
--   The new site state is stored in the monadic state. The result is an
--   'EPValue' holding the evidence, that site state and the new iteration
--   count.
doOneUpdate :: CovarianceMatrix    
            -> Targets            
            -> SiteOrder           -- ^ Supplier of update order. 
            -> State EPState EPValue
doOneUpdate k t siteOrder = do
  order <- siteOrder
  (current, gen, iters) <- get
  let numSites = dim (tauTilde current)
      swept = updateAllSites t numSites order current
      (lower, sigma, meanVec) =
        recomputeApproximation k numSites (tauTilde swept) (nuTilde swept)
      -- Keep the swept site parameters and cavity values; replace only the
      -- covariance and mean with the freshly recomputed ones.
      refreshed = swept { var = sigma, mu = meanVec }
      iters' = iters + 1
  put (refreshed, gen, iters')
  return (EPValue (gpClassifierEPEvidence k t lower refreshed) refreshed iters')
  
-- | The EP learning algorithm. Starting from 'generateInitialSiteState'
--   (covariance K, all site vectors zero), it repeats 'doOneUpdate' under
--   the control of @iterateToConvergence''@ and the supplied convergence
--   test. It returns the final 'EPValue' and the final 'EPState'. The
--   generator is seeded with mkStdGen 0, so random site orders are
--   reproducible between runs.
gpClassifierEPLearn :: CovarianceMatrix 
                    -> Targets         
                    -> SiteOrder
                    -> EPConvergenceTest
                    -> (EPValue, EPState)
gpClassifierEPLearn k t siteOrder converged =
    let sites0 = generateInitialSiteState k (dim t)
        start = (sites0, mkStdGen 0, 0)
        loop = iterateToConvergence'' (doOneUpdate k t siteOrder) converged
    in runState loop start
 
-- | Prediction with a GP classifier trained by EP (compare Rasmussen and
--   Williams (2006), Algorithm 3.6). Each row of the final argument is a
--   test input, and the result has one entry per test input, in row order.
--   Write S = diag tauTilde and let L be the lower Cholesky factor of
--   I + sqrt(S) K sqrt(S). Each entry is then computed as
--
--   >  z      = sqrt(S) L^(-T) L^(-1) sqrt(S) K nuTilde
--   >  fMean  = k(X, x*)^T (nuTilde - z)
--   >  v      = L^(-1) sqrt(S) k(X, x*)
--   >  fVar   = k(x*, x*) - v^T v
--   >  result = phiIntegral (fMean / sqrt (1 + fVar))
--
--   Assuming 'phiIntegral' is the standard normal CDF, each entry is the
--   predicted probability of class +1.
gpClassifierEPPredict :: (CovarianceFunction c) => EPSiteState -- ^ Site state after EP learning
                      -> Inputs           -- ^ Inputs in training set
                      -> Targets          -- ^ Targets in training set (only its length is used)
                      -> CovarianceMatrix -- ^ Covariance matrix of the training inputs
                      -> c                -- ^ Covariance Function
                      -> Inputs           -- ^ New inputs, one per row
                      -> DVector          -- ^ One prediction per new input
gpClassifierEPPredict sites i t k c xStars
    = fromList $ map phiIntegral (zipWith (/) fStar (map (sqrt . (1+)) vfStar))
  where
    nT = nuTilde sites
    tT = tauTilde sites
    d = dim t
    rootSTildeV = mapVector sqrt tT
    rootSTilde = diag rootSTildeV
    -- 'chol' yields an upper-triangular factor; transpose to get the lower.
    l = trans $ chol ((ident d) + (abaDiagDiag rootSTildeV k))
    z = rootSTilde
        <> (asColumn
            (upperSolve (trans l)
             (lowerSolve l (flatten $
                            (rootSTilde <> k <> (asColumn nT))))))
    xStarsRows = toRows xStars
    -- Row j holds the covariances between test input j and every
    -- training input.
    covarianceWithTestInputs =
        fromRows [covarianceWithPoint c i xStar | xStar <- xStarsRows]
    -- Latent predictive means.
    fStar =
        toList $ flatten $
                   (covarianceWithTestInputs <> ((asColumn nT) - z))
    v = [lowerSolve l (rootSTilde <> kxStar) |
         kxStar <- toRows covarianceWithTestInputs]
    vTv = zipWith (<.>) v v
    kxStarxStar = zipWith (covariance c) xStarsRows xStarsRows
    -- Latent predictive variances.
    vfStar = zipWith (-) kxStarxStar vTv

-- | Run EP learning and return the negated log evidence together with a
--   negated gradient vector, for use in hyperparameter optimisation.
--   Targets should be +1 or -1. With S = diag tauTilde and
--   B = I + sqrt(S) K sqrt(S) = L L^T, the gradient is built from
--
--   >  b     = nuTilde - sqrt(S) B^(-1) sqrt(S) K nuTilde
--   >  R     = b b^T - sqrt(S) B^(-1) sqrt(S)
--   >  dZ_j  = 0.5 * trace (R dK_j)
--
--   Here dK_j is the j-th matrix built by 'makeMatricesFromPairs' from
--   'dCovarianceDParameters'. That is presumably the derivative of K with
--   respect to hyperparameter j. The formula also assumes 'abDiagOnly'
--   returns the diagonal of a matrix product. The second component of the
--   result is minus the elementwise product of trueHyper c and dZ.
gpClassifierEPLogEvidence :: (CovarianceFunction c) => c -- ^ Covariance
                          -> Inputs 
                          -> Targets 
                          -> SiteOrder
                          -> EPConvergenceTest
                          -> (Double, DVector)
gpClassifierEPLogEvidence c i t siteOrder converged =
    (negate logEvidence, negate scaledGradient)
  where
    k = covarianceMatrix c i
    -- Only the final EPValue is needed; the final EPState is discarded.
    (learned, _) = gpClassifierEPLearn k t siteOrder converged
    sites = siteState learned
    nT = nuTilde sites
    tT = tauTilde sites
    logEvidence = eValue learned
    numSites = dim t
    sqrtSV = mapVector sqrt tT
    sqrtS = diag sqrtSV
    l = trans $ chol ((ident numSites) + (abaDiagDiag sqrtSV k))
    b = (asColumn nT) -
        (sqrtS <>
         (asColumn $ upperSolve (trans l)
                       (lowerSolve l (flatten $
                                      sqrtS <> k <> (asColumn nT)))))
    r = (b <> (trans b)) -
        (sqrtS <>
         (generalSolve upperSolve (trans l)
                       (generalSolve lowerSolve l sqrtS)))
    dLogEvidence = map ((0.5 *) . sum . toList . (abDiagOnly r))
                       (makeMatricesFromPairs (dCovarianceDParameters c) i)
    scaledGradient = zipVectorWith (*) (trueHyper c) (fromList dLogEvidence)

-- | Variant of 'gpClassifierEPLogEvidence' that takes hyperparameters as a
--   list. It builds a covariance function from the given one and the list
--   using 'makeCovarianceFromList', then delegates.
gpClassifierEPLogEvidenceList :: (CovarianceFunction c) => Inputs 
                          -> Targets
                          -> c -- ^ Covariance 
                          -> SiteOrder
                          -> EPConvergenceTest
                          -> [Double]
                          -> (Double, DVector)
gpClassifierEPLogEvidenceList i t cov siteOrder converged hyper =
    let withHyper = makeCovarianceFromList cov hyper
    in gpClassifierEPLogEvidence withHyper i t siteOrder converged

-- | Variant of 'gpClassifierEPLogEvidence' that takes hyperparameters as a
--   vector. It converts the vector to a list, builds a covariance function
--   with 'makeCovarianceFromList', then delegates.
gpClassifierEPLogEvidenceVec :: (CovarianceFunction c) => Inputs 
                          -> Targets 
                          -> c -- ^ Covariance
                          -> SiteOrder
                          -> EPConvergenceTest
                          -> DVector
                          -> (Double, DVector)
gpClassifierEPLogEvidenceVec i t cov siteOrder converged hyper =
    let withHyper = makeCovarianceFromList cov (toList hyper)
    in gpClassifierEPLogEvidence withHyper i t siteOrder converged