{- | Gaussian Process Library. This module contains the definition 
   of the standard squared exponential covariance function, extended 
   for use with Automatic Relevance Determination.

   s_f^2 exp (-1\/2 (x_1 - x_2)^T M (x_1 - x_2)) 

   Parameters are stored as logarithms: log s_f^2 and a vector of
   log length-scales (log l_1,...,log l_n), one per input dimension, where
   M is diag (1\/l_1^2,...,1\/l_n^2).

   Copyright (C) 2011 Sean Holden. sbh11\@cl.cam.ac.uk.
-}
{- HasGP is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   HasGP is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with HasGP.  If not, see <http://www.gnu.org/licenses/>.
-}
module HasGP.Covariance.SquaredExpARD 
   ( 
     SquaredExponentialARD(..)
   ) where

import Numeric.LinearAlgebra
import HasGP.Types.MainTypes
import HasGP.Support.Linear
import HasGP.Support.Functions
import HasGP.Covariance.Basic

-- | Squared exponential covariance with a separate length-scale for each
--   input dimension (Automatic Relevance Determination).
--
--   Both fields hold the /logarithms/ of the hyperparameters; the
--   'CovarianceFunction' instance exponentiates them before use.
data SquaredExponentialARD = SquaredExponentialARD
    { fARD :: Double   -- ^ log of the signal variance, @log s_f^2@
    , m    :: DVector  -- ^ log length-scales, @log l_1, ..., log l_n@
    }

instance CovarianceFunction SquaredExponentialARD where

    -- Hyperparameters on their natural scale, in the order
    -- [s_f^2, l_1, ..., l_n]. The record stores their logarithms, so each
    -- stored value is exponentiated.
    trueHyper se = mapVector exp $ join [fromList [fARD se], m se]

    -- k(x_1, x_2) = s_f^2 exp (-1/2 (x_1 - x_2)^T M (x_1 - x_2))
    -- with M = diag (1/l_1^2, ..., 1/l_n^2), built from the stored
    -- log-parameters.
    covariance se x1 x2 = sf2 * expTerm
        where
          sf2     = exp (fARD se)
          -- diagonal of M: exp (log l_i) ^^ (-2) = 1 / l_i^2
          invL2   = mapVector ((^^(-2)) . exp) (m se)
          expTerm = exp ((-(1/2)) * xAxDiag (x1 - x2) invL2)

    -- Gradient of k(x_1, x_2) with respect to the natural-scale
    -- hyperparameters (s_f^2, l_1, ..., l_n), in the same order as
    -- trueHyper. This is deliberately dK/dtheta, NOT dK/dlog(theta), even
    -- though the record stores log-parameters:
    --
    --   dk/ds_f^2 = exp (-1/2 r)
    --   dk/dl_i   = s_f^2 exp (-1/2 r) (x_1i - x_2i)^2 / l_i^3
    --
    -- where r = (x_1 - x_2)^T M (x_1 - x_2).
    dCovarianceDParameters se x1 x2 = join [fromList [dKDSf2], dKDL]
        where
          diff   = x1 - x2
          sqDiff = mapVector square diff
          sf2    = exp (fARD se)
          ls     = mapVector exp (m se)
          invL2  = mapVector (^^(-2)) ls
          -- dk/ds_f^2 is just the exponential factor of k
          dKDSf2 = exp ((-(1/2)) * xAxDiag diff invL2)
          -- d/dl_i of (-1/2 d_i l_i^(-2)) is d_i l_i^(-3)
          dKDL   = scale (sf2 * dKDSf2)
                         (zipVectorWith (*) sqDiff (mapVector (^^(-3)) ls))

    -- Rebuild a covariance from a list of log-parameters
    -- [log s_f^2, log l_1, ..., log l_n], the format produced by
    -- makeListFromCovariance. The template 'se' only supplies the expected
    -- number of length-scales; any other list shape is rejected.
    makeCovarianceFromList se (f:rest)
        | length rest == dim (m se) = SquaredExponentialARD f (fromList rest)
    makeCovarianceFromList _ _ =
        error "SquaredExpARD needs the correct number of hyperparameters"

    -- Flatten to [log s_f^2, log l_1, ..., log l_n].
    makeListFromCovariance se = fARD se : toList (m se)