{- | Normalise is a module in the HasGP Gaussian process library. 
     It contains functions for performing basic normalisation 
     tasks on training examples, and for computing assorted 
     standard statistics.

     Copyright (C) 2011 Sean Holden. sbh11\@cl.cam.ac.uk.
-}
{- This file is part of HasGP.

   HasGP is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   HasGP is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with HasGP.  If not, see <http://www.gnu.org/licenses/>.
-}
module HasGP.Data.Normalise where

import Numeric.LinearAlgebra

import HasGP.Types.MainTypes
import HasGP.Support.Linear as L
import HasGP.Support.Functions as F

-- | Compute the mean of every attribute (column) over a set of examples.
--
--   Each column is reduced with 'L.sumVectorDiv', passing the number of
--   examples (rows) as the divisor. This assumes @L.sumVectorDiv n v@ is
--   @sum v / n@ (helper defined in HasGP.Support.Linear — TODO confirm).
--   With zero rows the divisor is 0, so the result presumably contains
--   non-finite values.
exampleMean :: Inputs  -- ^ Matrix - one row per example
            -> DVector -- ^ Vector of means for each attribute.
exampleMean examples =
    fromList [ L.sumVectorDiv numExamples column | column <- toColumns examples ]
    where
      numExamples = rows examples

-- | Compute the variance of every attribute (column) over a set of examples.
--
--   Each row has the column means (from 'exampleMean') subtracted and is then
--   squared elementwise. Each column of squared deviations is reduced with
--   @L.sumVectorDiv r@, where @r@ is the number of rows. Assuming
--   sumVectorDiv divides the sum by its first argument, this is the
--   population (divide-by-n) variance, not the n-1 unbiased estimate.
exampleVariance :: Inputs  -- ^ Matrix - one row per example
                -> DVector -- ^ Vector of variances for each attribute.
exampleVariance examples = fromList (map (L.sumVectorDiv numExamples) squaredCols)
    where
      numExamples = rows examples
      means       = exampleMean examples
      -- Squared deviation of one example (row) from the column means.
      squaredDev  = mapVector (^2) . subtract means
      squaredCols = toColumns . fromRows . map squaredDev . toRows $ examples

-- | Compute, per attribute (column), both the mean and the variance.
--   The pair is simply ('exampleMean', 'exampleVariance') applied to the
--   same matrix.
exampleMeanVariance :: Inputs               -- ^ Matrix - one row per example
                    -> (DVector, DVector)   -- ^ Means and variances
exampleMeanVariance examples = (means, variances)
    where
      means     = exampleMean examples
      variances = exampleVariance examples
      
-- | Rescale and shift every attribute (column) so that it has the requested
--   mean and variance.
--
--   With @m@ and @v@ the current column means and variances (from
--   'exampleMeanVariance'), each entry of column @j@ becomes
--
--   @(x - m_j) * sqrt newVariance_j / sqrt v_j + newMean_j@.
--
--   A column whose computed variance is zero causes a division by zero, so
--   that column of the result is not finite. Consider removing such columns
--   first (see 'removeRedundantAttributes').
normaliseMeanVariance :: DVector -- ^ Vector of new means required
                      -> DVector -- ^ Vector of new variances required
                      -> Inputs  -- ^ Matrix - one row per example
                      -> Inputs  -- ^ Normalised matrix
normaliseMeanVariance newMean newVariance examples =
    fromRows [ rescale (row - means) + newMean | row <- toRows examples ]
    where
      (means, variances) = exampleMeanVariance examples
      -- Per-column ratio of target to current standard deviation.
      factors = zipVectorWith (\target current -> (sqrt target)/(sqrt current))
                              newVariance variances
      rescale = zipVectorWith (*) factors

-- | Like 'normaliseMeanVariance', but every column (attribute) is given the
--   same target mean and the same target variance. Each scalar is expanded
--   into a vector with one entry per column.
normaliseMeanVarianceSimple :: Double  -- ^ New mean required
                            -> Double  -- ^ New variance required
                            -> Inputs  -- ^ Matrix - one row per example
                            -> Inputs  -- ^ Normalised matrix
normaliseMeanVarianceSimple newMean newVariance examples =
    normaliseMeanVariance (perColumn newMean) (perColumn newVariance) examples
    where
      -- A vector holding the given value once for every column.
      perColumn value = constant value (cols examples)

-- | Linearly rescale every attribute (column) so that its smallest value maps
--   to @newMin@ and its largest value maps to @newMax@.
--
--   For a column with minimum @lo@ and maximum @hi@, each entry @x@ becomes
--   @f * x + (newMin - f * lo)@, where @f = (newMax - newMin) / (hi - lo)@.
--
--   A constant column (@hi == lo@) divides by zero and yields non-finite
--   values. Remove such columns first (see 'removeRedundantAttributes').
--   'minElement' / 'maxElement' are applied to every column, so presumably
--   a matrix with zero rows is rejected by hmatrix — TODO confirm.
normaliseBetweenLimits :: Double -- ^ New min required
                       -> Double -- ^ New max required
                       -> Inputs -- ^ Matrix - one row per example
                       -> Inputs -- ^ Normalised matrix
normaliseBetweenLimits newMin newMax examples =
    -- The parameters are deliberately not called min/max, which would
    -- shadow the Prelude functions of the same name.
    fromColumns $ zipWith shift offsets (zipWith scale factors columns)
    where
      columns = toColumns examples
      minV    = map minElement columns
      maxV    = map maxElement columns
      -- Per-column slope of the linear map.
      factors = zipWith (\lo hi -> (newMax - newMin) / (hi - lo)) minV maxV
      -- Per-column intercept, chosen so that lo maps exactly onto newMin.
      offsets = zipWith (\lo f -> newMin - (f * lo)) minV factors
      shift offset column = mapVector (offset +) column

-- | Flag, for each column of the matrix, whether all of its values are equal.
--
--   The result has one 'Bool' per column, in column order. 'True' marks a
--   redundant (constant) attribute. An empty or single-entry column counts
--   as constant. Comparison uses '==', so any NaN in a column with two or
--   more entries makes it non-constant.
findRedundantAttributes :: Inputs  -- ^ Matrix - one row per example
                        -> [Bool]  -- ^ List - True elements mark redundancy
findRedundantAttributes examples =
    [ isConstant (toList column) | column <- toColumns examples ]
    where
      -- Every neighbouring pair of values must compare equal.
      isConstant values = and (zipWith (==) values (drop 1 values))

-- | List the positions of the redundant (constant) columns, numbered from 1,
--   in increasing order. A column counts as redundant when
--   'findRedundantAttributes' marks it 'True'.
listRedundantAttributes :: Inputs -- ^ Matrix - one row per example
                        -> [Int]  -- ^ List - positions of redundant attributes
listRedundantAttributes examples =
    [ position
    | (True, position) <- zip (findRedundantAttributes examples) [1 ..]
    ]

-- | Drop every redundant (constant) column from the matrix, keeping the
--   remaining columns in their original order. Redundancy is decided by
--   'findRedundantAttributes'.
removeRedundantAttributes :: Inputs -- ^ Matrix - one row per example
                          -> Inputs -- ^ Modified matrix - one row per example
removeRedundantAttributes examples =
    fromColumns [ column | (False, column) <- zip redundant (toColumns examples) ]
    where
      -- One flag per column, in the same order as toColumns.
      redundant = findRedundantAttributes examples

-- | Build a matrix containing ONLY the listed columns, in the order given.
--
--   Column numbers are 1-based. They are converted to 0-based indices and
--   passed to 'extractRows' on the transposed matrix, and the result is
--   transposed back.
retainAttributes :: [Int]   -- ^ List of columns to keep (numbered from 1).
                 -> Inputs  -- ^ Matrix - one row per example
                 -> Inputs  -- ^ Modified matrix - one row per example
retainAttributes l m = (trans . extractRows zeroBased . trans) m
    where
      zeroBased = [ column - 1 | column <- l ]

-- | Count the entries of the two-class confusion matrix.
--
--   Labels in both vectors must be exactly +1 (positive) or -1 (negative).
--   Any other value raises an error naming the offending pair. The result
--   is (a,b,c,d):
--
--   * a - correct negatives (correct -1, predicted -1)
--   * b - predict positive when correct is negative
--   * c - predict negative when correct is positive
--   * d - correct positives (correct +1, predicted +1)
--
--   Raises an error if the two vectors have different lengths.
confusionMatrix :: Targets    -- ^ Vector of correct labels (+1 / -1)
                -> Outputs    -- ^ Vector of predicted labels (+1 / -1)
                -> (Double,Double,Double,Double)
confusionMatrix correct predicted =
    cm (toList correct) (toList predicted) (0,0,0,0)
        where
          cm [] [] result = result
          -- The four counts are forced on each step so that long vectors
          -- do not accumulate a chain of unevaluated (+1) thunks.
          cm (h1:t1) (h2:t2) (a,b,c,d) =
              a `seq` b `seq` c `seq` d `seq`
              case (h1, h2) of
                (1.0, 1.0)   -> cm t1 t2 (a,b,c,d+1)
                (1.0, -1.0)  -> cm t1 t2 (a,b,c+1,d)
                (-1.0, 1.0)  -> cm t1 t2 (a,b+1,c,d)
                (-1.0, -1.0) -> cm t1 t2 (a+1,b,c,d)
                -- Previously an opaque non-exhaustive-pattern crash.
                _            -> error ("confusionMatrix: labels must be +1 or -1, "
                                       ++ "got correct = " ++ show h1
                                       ++ ", predicted = " ++ show h2)
          cm _ _ _
              = error "Correct and predicted vectors must have the same length"

-- | Print the confusion matrix counts (see 'confusionMatrix') followed by
--   derived statistics: accuracy, recall, false-positive, true-negative and
--   false-negative rates, precision, and F1.
--
--   The ratios are plain 'Double' divisions. When a denominator is zero
--   (e.g. no positive examples), the printed value is NaN or Infinity
--   rather than an error.
printConfusionMatrix :: Targets -- ^ Vector of targets
                     -> Outputs -- ^ Vector of actual outputs
                     -> IO ()
printConfusionMatrix correct predicted = mapM_ putStrLn report
  where
    (a, b, c, d) = confusionMatrix correct predicted
    n         = a+b+c+d
    recall    = d/(d+c)
    precision = d/(d+b)
    fMeasure  = (2 * recall * precision)/(recall + precision)
    rule      = "------------------------------------------------"
    report =
      [ rule
      , "Correct -1, Predicted -1: a = " ++ show a
      , "Correct -1, Predicted +1: b = " ++ show b
      , "Correct +1, Predicted -1: c = " ++ show c
      , "Correct +1, Predicted +1: d = " ++ show d
      , rule
      , "Number of examples: n = a+b+c+d = " ++ show n
      , "Accuracy:               a+d/n   = " ++ show ((a+d)/n)
      , "Recall/True Positive:   d/d+c   = " ++ show recall
      , "False Positive:         b/b+a   = " ++ show (b/(b+a))
      , "True Negative:          a/b+a   = " ++ show (a/(b+a))
      , "False Negative:         c/d+c   = " ++ show (c/(d+c))
      , "Precision:              d/d+b   = " ++ show precision
      , "F Measure (beta = 1)            = " ++ show fMeasure
      , rule
      ]

-- | Print the total number of labels and how many are +1 and -1.
--
--   Only entries exactly equal to 1.0 are counted as +1. Every other entry
--   is reported as -1, so this is only meaningful when all labels are +1
--   or -1.
countLabels :: Targets -> IO ()
countLabels v = mapM_ putStrLn
    [ "Total number of labels: " ++ show total
    , "Number of +1 labels:    " ++ show positives
    , "Number of -1 labels:    " ++ show (total - positives)
    ]
  where
    total     = dim v
    positives = length (filter (== 1.0) (toList v))