{- | HasGP Gaussian Process Library. This module contains assorted
     functions that support GP calculations and are specifically
     related to linear algebra.

     Copyright (C) 2011 Sean Holden. sbh11\@cl.cam.ac.uk.
-}
{- This file is part of HasGP.

   HasGP is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   HasGP is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with HasGP.  If not, see <http://www.gnu.org/licenses/>.
-}
module HasGP.Support.Linear where

import Data.Packed.ST
import Control.Monad (mapM_,zipWithM_)
import Control.Monad.ST

import Numeric.LinearAlgebra

import HasGP.Types.MainTypes
import HasGP.Support.Functions as F

-- | Add up every element of a vector, starting from an accumulator of zero.
sumVector :: DVector -> Double
sumVector vec = foldVector (\element total -> element + total) 0 vec

-- | Sum the elements of a vector and divide the total by the given 'Int'.
--   A divisor of 0 follows ordinary 'Double' division semantics.
sumVectorDiv :: Int -> DVector -> Double
sumVectorDiv d v = total / denominator
    where
      total       = sumVector v
      denominator = fromIntegral d

-- | Euclidean length (2-norm) of a vector, computed with 'pnorm'.
lengthV :: (Normed a b) => a b -> RealOf b
lengthV v = pnorm PNorm2 v

-- | Generate a vector equal to the first column of a matrix.
--
--   Calls 'error' with a descriptive message if the matrix has no columns
--   (instead of the uninformative failure of a partial 'head').
toVector :: Matrix Double -> Vector Double
toVector x =
    case toColumns x of
      (firstCol:_) -> firstCol
      []           -> error "toVector: matrix has no columns"

-- | Return a copy of a vector with the element at a specified position
--   replaced by a new value.
--
--   NOTE: indices here are 1-based, whereas hmatrix itself numbers from 0.
--
--   This is a pure function: the input vector @v@ is left unchanged and a
--   new vector is returned. Internally 'thawVector' produces a mutable copy
--   of @v@, the single element is written into that copy, and
--   'runSTVector' freezes it as the result. Working in 'ST' avoids
--   rebuilding the whole vector element by element.
--
--   Calls 'error' if the index is outside @1 .. dim v@.
replaceInVector :: DVector -> Int -> Double -> DVector
replaceInVector v i n
    | i < 1 || i > dim v = error "Index out of range in replaceInVector"
    | otherwise          = runSTVector $ do
        v2 <- thawVector v
        writeVector v2 (i - 1) n   -- convert the 1-based index to hmatrix's 0-based one
        return v2
       
-- | Efficiently pre-multiply a matrix by a diagonal matrix, @diag(v) M@,
--   where the diagonal is passed as the vector @v@: row @i@ of @M@ is
--   scaled by element @i@ of @v@.
--
--   Calls 'error' if the length of @v@ differs from the number of rows of
--   @M@, rather than letting 'zipWith' silently truncate to the shorter of
--   the two and return a wrongly-shaped result.
preMultiply :: DVector -> DMatrix -> DMatrix
preMultiply v m
    | length coeffs /= length rs = error "Incorrect dimensions in preMultiply"
    | otherwise                  = fromRows $ zipWith scale coeffs rs
    where
      coeffs = toList v
      rs     = toRows m

-- | Efficiently post-multiply a matrix by a diagonal matrix, @M diag(v)@,
--   where the diagonal is passed as the vector @v@: column @j@ of @M@ is
--   scaled by element @j@ of @v@.
--
--   Calls 'error' if the length of @v@ differs from the number of columns
--   of @M@, rather than letting 'zipWith' silently truncate to the shorter
--   of the two and return a wrongly-shaped result.
postMultiply :: DMatrix -> DVector -> DMatrix
postMultiply m v
    | length coeffs /= length cs = error "Incorrect dimensions in postMultiply"
    | otherwise                  = fromColumns $ zipWith scale coeffs cs
    where
      coeffs = toList v
      cs     = toColumns m

-- | Quadratic form @x^T A x@ for a diagonal matrix @A@, given by its
--   diagonal (the second argument). Equals @sum_i a_i * x_i^2@.
--
--   Calls 'error' if @x@ and the diagonal have different lengths.
xAxDiag :: DVector -> DVector -> Double
xAxDiag x a
    | dim a /= n = error "Incorrect dimensions in xAxDiag"
    | otherwise  = a <.> squared
    where
      n       = dim x
      squared = x * x

-- | Diagonal of the product @A B@ (intended for two square matrices),
--   computed without forming the full product: entry @i@ is the dot
--   product of row @i@ of @A@ with column @i@ of @B@.
abDiagOnly :: DMatrix -> DMatrix -> DVector
abDiagOnly a b =
    fromList [ row <.> col | (row, col) <- zip (toRows a) (toColumns b) ]

-- | Compute @A B A@ where @A@ is diagonal. The first argument is the
--   diagonal of @A@ (length @d@) and @B@ must be @d x d@. Entry @(i,j)@ of
--   the result is @b_ij * (a_i * a_j)@.
--
--   Calls 'error' if @B@ is not @d x d@. Without this check, a larger @B@
--   was silently truncated to its first @d*d@ elements (row-major),
--   giving a meaningless result.
abaDiagDiag :: DVector -> DMatrix -> DMatrix
abaDiagDiag a b
    | length bRows /= d || any ((/= d) . dim) bRows =
        error "Incorrect dimensions in abaDiagDiag"
    | otherwise = (d><d) (zipWith (*) bL aA)
    where
      d     = dim a
      aL    = toList a
      -- Row-major outer product of the diagonal with itself: a_i * a_j.
      aA    = [(a1 * a2) | a1 <- aL, a2 <- aL]
      bRows = toRows b
      -- Elements of B in row-major order, matching the layout of aA.
      bL    = toList $ join bRows

-- | Quadratic form @a^T B a@ for a vector @a@ and a matrix @B@.
--   @a@ is first multiplied, as a row vector, by @B@; the resulting row
--   is then dotted with @a@.
abaVV :: DVector -> DMatrix -> Double
abaVV a b = rowTimesB <.> a
    where
      rowTimesB = flatten (asRow a <> b)