-- While working on this module you are encouraged to remove it and fix
-- any warnings in the module. See
--     http://hackage.haskell.org/trac/ghc/wiki/WorkingConventions#Warnings
-- for details  

-----------------------------------------------------------------------------
-- |
-- Module      :  LDA
-- Copyright   :  (c) Lennart Schmitt
-- License     :  BSD-style (see the file libraries/base/LICENSE)
-- 
-- Maintainer  :  lennart...schmitt@<nospam>gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- This module implements some linear discriminant analysis functions.
-- Imagine you have conducted a poll and now have values/attributes from every subscriber.
-- Furthermore, you have grouped the subscribers into clusters.
-- The poll data is structured as follows:
--
--  * poll-data of one subscriber = [value] --> Vector value
--
--  * poll-data of one cluster/group of subscribers = [[values]] --> Matrix values
--
--  * poll-data of all clusters/groups = [[[values]]] --> MatrixList values
--
-- Now you want to check if you clustered right and/or how significant the values you asked for are...
--
-----------------------------------------------------------------------------

module Numeric.Statistics.LDA (
       fisher,
       fisher',
       fisherT,
       fisherAll,
       fisherClassificationFunction,
       aprioriProbability,
       discriminantCriteria,
       isolatedDiscriminant
) where

import Numeric.Matrix
import Numeric.MatrixList
import Numeric.Vector 
import Numeric.Function
import Numeric.LinearAlgebra.LAPACK


-- | Subtracts from every element the average of the row it belongs to.
--
-- The row averages come from 'averages'; every row is paired with its own
-- average and that average is subtracted from each of the row's elements.
--
-- > diffAverage [[[-1,1],[2,2]],[[1,3],[4,8]]] == [[[-1.0,1.0],[0.0,0.0]],[[-1.0,1.0],[-2.0,2.0]]]
diffAverage :: RawMatrixList Double -> RawMatrixList Double
diffAverage xs = zipWith (zipWith centreRow) xs (averages xs)
  where
    -- shift one row so that its own average becomes zero
    centreRow row avg = map (subtract avg) row

-- | Squares every element of the row-centred data produced by 'diffAverage'.
--
-- > squareDiffAverage [[[-1,1],[2,2]],[[1,3],[4,8]]] == [[[1.0,1.0],[0.0,0.0]],[[1.0,1.0],[4.0,4.0]]]
squareDiffAverage :: RawMatrixList Double -> RawMatrixList Double
squareDiffAverage xss = mapElements (^2) centred
  where
    centred = diffAverage xss

-- | Calculates the overall average of every attribute (column) across all
-- groups. Each group's column average is weighted by the group's row count
-- (as reported by 'countRows') and the weighted sum is divided by the total
-- row count.
--
-- > totalAverages [[[-1,1],[2,2]],[[1,3],[4,8]]] == [1.5,3.5]
totalAverages :: RawMatrixList Double -> RawVector Double
totalAverages xs = map weightedMean averagesPerAttribute
  where
    groupSizes = countRows xs
    totalSize = sum groupSizes
    -- After the outer transpose every inner list holds ONE attribute's
    -- averages, with one entry per group. That is the axis the group sizes
    -- have to be zipped against; zipping them against the outer (attribute)
    -- axis mis-weights unequal groups and drops attributes when there are
    -- more attributes than groups.
    averagesPerAttribute = transpose $ averages $ transposeAll xs
    weightedMean groupAverages = sum (zipWith (*) groupSizes groupAverages) / totalSize

-- | For every matrix of row-centred values (see 'diffAverage') builds the
-- element-wise product of every ordered pair of rows and reduces each of
-- these product vectors with 'sum' via 'foldVectors'.
--
-- > sumOfAverages [[[-1,1],[2,2]],[[1,3],[4,8]]] == [[2.0,0.0,0.0,0.0],[2.0,4.0,4.0,8.0]]
sumOfAverages :: RawMatrixList Double -> RawMatrix Double
sumOfAverages xss = foldVectors sum (map rowPairProducts (diffAverage xss))
  where
    -- all ordered row pairs, first row varying slowest
    rowPairProducts rows = [zipWith (*) a b | a <- rows, b <- rows]

-- | Calculates the spread within the groups/clusters.
--
-- Every matrix is transposed ('transposeAll'), passed through
-- 'sumOfAverages', the per-group results are combined position-wise with
-- 'sum' ('zipAllWith') and the flat result is turned into a square matrix.
--
-- > spreadWithinGroups [[[-1,1],[2,2]],[[1,3],[4,8]]] == (2><2) [ 9.0, 9.0, 9.0, 13.0 ]
spreadWithinGroups :: RawMatrixList Double -> Matrix Double
spreadWithinGroups xss = fromListToQuadraticMatrix combined
  where
    perGroup = sumOfAverages (transposeAll xss)
    combined = zipAllWith sum perGroup

-- | Calculates, per group, the difference between the group's attribute
-- averages and the overall attribute averages ('totalAverages').
--
-- > spreadFromTotalAverages [[[-1.0,1.0],[2,2]],[[1,3],[4,8]]] == [[-1.0,-2.0],[1.0,2.0]]
spreadFromTotalAverages :: RawMatrixList Double -> RawMatrix Double
spreadFromTotalAverages x =
    [zipWith (-) groupAverages overall | groupAverages <- averagesPerGroup]
  where
    overall = totalAverages x
    averagesPerGroup = averages (transposeAll x)

-- | Calculates the spread between the groups/clusters.
--
-- For every group the deviation of its attribute averages from the overall
-- averages ('spreadFromTotalAverages') is combined with the same deviation
-- scaled by the group's row count ('countRows'): all pairwise products of the
-- two vectors are formed. The per-group results are then added position-wise
-- and turned into a square matrix.
--
-- > spreadBetweenGroups [[[-1,1],[2,2]],[[1,3],[4,8]]] == (2><2) [ 4.0, 8.0, 8.0, 16.0 ]
spreadBetweenGroups :: RawMatrixList Double -> Matrix Double
spreadBetweenGroups xss =
    fromListToQuadraticMatrix . zipAllWith sum $ zipWith pairwiseProducts deviations weighted
  where
    -- one deviation vector per group, computed once
    deviations = spreadFromTotalAverages xss
    -- Scale the WHOLE deviation vector of group g by that group's size.
    -- Zipping the sizes against the attributes inside a row would use the
    -- wrong weights and truncate rows when attributes outnumber groups.
    weighted = zipWith (\n row -> map (n *) row) (countRows xss) deviations
    -- every product of an element of us with an element of vs, us varying slowest
    pairwiseProducts us vs = [u * v | u <- us, v <- vs]


-- | Calculates the isolated discriminant of every attribute: the diagonal of
-- 'spreadBetweenGroups' divided element-wise by the diagonal of
-- 'spreadWithinGroups'.
--
-- > isolatedDiscriminant [[[-1,1],[2,2]],[[1,3],[4,8]]] == [0.4444444444444444,1.2307692307692308]
isolatedDiscriminant :: RawMatrixList Double -> RawVector Double
isolatedDiscriminant xss = zipWith (/) (diagonal between) (diagonal within)
  where
    between = spreadBetweenGroups xss
    within = spreadWithinGroups xss
    -- Walk the matrix's own diagonal. The spread matrices are indexed by
    -- attribute, so bounding the index by the number of groups reads out of
    -- range (more groups than attributes) or drops attributes (fewer groups).
    diagonal m = zipWith (!!) (Numeric.Matrix.toLists m) [0 ..]

{- -----------------------------------------------
----------------------Utilities ------------------
------------------------------------------------ -}
                             
-- | Calculates the analysis matrix: the inverse of the within-group spread
-- multiplied (via 'multiplyR') by the between-group spread.
calcA :: RawMatrixList Double -> Matrix Double
calcA xss = inv within `multiplyR` between
  where
    within = spreadWithinGroups xss
    between = spreadBetweenGroups xss

-- | Calculates the scaling factor of the discriminant coefficients.
--
-- With @v@ the first column of 'calcA' (held as a one-row matrix) and @W@ the
-- within-group spread, the result is
-- @1 / sqrt ((v W v') / (countAllRows - countMatrixes))@.
scalingFactor :: RawMatrixList Double -> Double
scalingFactor xsss = 1 / sqrt (quadraticForm / degreesOfFreedom)
  where
    withinSpread = spreadWithinGroups xsss
    -- first row of the transposed analysis matrix, i.e. its first column
    firstColumn = asRow (head (toRows (trans (calcA xsss))))
    -- (v * W) * v', evaluated left to right
    sandwich = multiplyR (multiplyR firstColumn withinSpread) (trans firstColumn)
    quadraticForm = head (head (Numeric.Matrix.toLists sandwich))
    caseCount = countAllRows xsss
    groupCount = countMatrixes xsss
    degreesOfFreedom = caseCount - groupCount

-- | Scales a coefficient matrix by the given factor; this is
-- 'scalarMultiplication' with its arguments swapped.
scaledDiscriminantCoefficient :: Matrix Double -> Double -> Matrix Double
scaledDiscriminantCoefficient = flip scalarMultiplication

-- | Calculates the constant element of a discriminant function: the negated
-- sum of the element-wise products of the flattened coefficient matrix @b@
-- and the vector @x@.
constElement :: Matrix Double -> RawVector Double -> Double
constElement b x = negate (sum (zipWith (*) coefficients x))
  where
    coefficients = toList (flatten b)

-- | Calculates the candidate discriminant functions.
--
-- Each row of the transposed, scaled analysis matrix is one coefficient
-- list; the function's constant element (see 'constElement', evaluated
-- against 'totalAverages') is put in front of it.
scaledDiscriminantFunctions :: RawMatrixList Double -> RawVector (LinFunction Double)
scaledDiscriminantFunctions xss =
    [ constElement (fromLists [coeffs]) overallAverages : coeffs
    | coeffs <- scaledCoefficients
    ]
  where
    scaledMatrix = scaledDiscriminantCoefficient (calcA xss) (scalingFactor xss)
    scaledCoefficients = Numeric.Matrix.toLists (trans scaledMatrix)
    overallAverages = totalAverages xss

-- | Evaluates a scaled discriminant function for one list of attribute
-- values; a thin alias for 'calcLinFunction'.
scaledDiscriminant :: LinFunction Double -> Values Double -> Double
scaledDiscriminant = calcLinFunction

-- | Calculates all scaled discriminants: for every function returned by
-- 'scaledDiscriminantFunctions' the function is applied, via 'foldVectors',
-- to the vectors of the input.
scaledDiscriminants :: RawMatrixList Double -> RawMatrixList Double
scaledDiscriminants xss =
    [ foldVectors (scaledDiscriminant linFun) xss
    | linFun <- scaledDiscriminantFunctions xss
    ]

-- | Calculates the centroid of a group/cluster: the mean of the
-- discriminant function's value over all of the group's value vectors.
centroid :: LinFunction Double -> RawVector(Values Double) -> Double
centroid v m = total / count m
  where
    total = sum (map (scaledDiscriminant v) m)

-- | Calculates all centroids: one list per discriminant function, holding
-- the 'centroid' of every group under that function.
centroids :: RawMatrixList Double -> RawMatrix Double
centroids xss =
    [ map (centroid linFun) xss
    | linFun <- scaledDiscriminantFunctions xss
    ]

-- | Calculates, per discriminant function, the average of all scaled
-- discriminants (the per-group results are concatenated before averaging).
averageOfScaledDiscriminants :: RawMatrixList Double -> RawVector Double
averageOfScaledDiscriminants xss = map (mean . concat) (scaledDiscriminants xss)
  where
    mean values = foldl1 (+) values / count values

-- | Calculates the total spread between all groups, one value per
-- discriminant function: the sum over the groups of the group's row count
-- ('countRows') times the squared distance between the group's centroid and
-- the overall average of the scaled discriminants.
totalSpreadBetweenGroups :: RawMatrixList Double -> RawVector Double
totalSpreadBetweenGroups xss =
    zipWith spread (centroids xss) (averageOfScaledDiscriminants xss)
  where
    groupSizes = countRows xss
    -- pair every group size with that group's centroid instead of indexing
    -- both lists with (!!)
    spread groupCentroids overallMean =
        sum [n * (c - overallMean) ^ 2 | (n, c) <- zip groupSizes groupCentroids]

-- | Calculates the total spread within the groups, one value per
-- discriminant function: the sum of the squared distances between every
-- scaled discriminant and the centroid of the group it belongs to.
totalSpreadWithinGroups :: RawMatrixList Double -> RawVector Double
totalSpreadWithinGroups xss =
    zipWith spread (centroids xss) (scaledDiscriminants xss)
  where
    -- walk centroids and per-group discriminants in lock step instead of
    -- indexing both lists with (!!)
    spread groupCentroids groupDiscriminants =
        sum (zipWith squaredDistances groupCentroids groupDiscriminants)
    squaredDistances c values = sum (map (\v -> (v - c) ^ 2) values)

-- | Calculates the discriminant criteria: per discriminant function, the
-- total spread between the groups divided by the total spread within them.
discriminantCriteria :: RawMatrixList Double -> RawVector Double
discriminantCriteria xss = zipWith (/) between within
  where
    between = totalSpreadBetweenGroups xss
    within = totalSpreadWithinGroups xss

-- | Calculation of the a priori probability, more precisely the probability
-- that an element belongs to a group: every entry of 'countRows' divided by
-- 'countAllRows'.
aprioriProbability :: RawMatrixList Double -> RawVector Double
aprioriProbability xsss = map (/ total) (countRows xsss)
  where
    -- mapping over the counts directly avoids enumerating floating-point
    -- indices and rounding them back for a partial (!!) lookup
    total = countAllRows xsss

-- | Calculates the constant part of the classification function according
-- to Fisher.
--
-- For group @g@ the result is @-0.5 * sum_j (b_gj * m_gj) + log p_g@, where
-- @b@ comes from 'fisherClassificationFunctionVar', @m_gj@ is the mean of
-- column @j@ within group @g@ and @p@ comes from 'aprioriProbability'.
fisherClassificationFunctionConst :: RawMatrixList Double -> RawVector Double
fisherClassificationFunctionConst xsss = map constantOf [0 .. lastGroup]
  where
    constantOf g' = negate (0.5 * weightedMeanSum g) + log (priors !! g)
      where
        g = round g'
    weightedMeanSum g = sum (map term [0 .. lastCol])
      where
        term j' = (coefficients !! g !! j) * (groupMeans @@> (g, j))
          where
            j = round j'
    lastGroup = countMatrixes xsss - 1
    lastCol = countMatrixesCols xsss - 1
    priors = aprioriProbability xsss
    coefficients = fisherClassificationFunctionVar xsss
    -- one row per group, one column per attribute
    groupMeans = fromLists (map (map mean . transpose) xsss)
    mean column = foldl1 (+) column / count column

-- | Calculates the variable parts of the classification function according
-- to Fisher.
--
-- The coefficient of attribute @j@ for group @g@ is
-- @(rows - groups) * sum_r (V[r,j] * m_gr)@, where @V@ is the inverse of the
-- transposed within-group spread, @m_gr@ is the mean of column @r@ within
-- group @g@, @rows@ is the sum of 'countRows' and @groups@ is 'countMatrixes'.
fisherClassificationFunctionVar :: RawMatrixList Double -> RawVector (LinFunction Double)
fisherClassificationFunctionVar xsss = map coefficientsOf [0 .. lastGroup]
  where
    coefficientsOf g =
        map (\j -> scale * weightedMean (round j) (round g)) [0 .. lastCol]
    weightedMean j g = sum (map term [0 .. lastCol])
      where
        term rr = (invSpread @@> (r, j)) * (groupMeans @@> (g, r))
          where
            r = round rr
    -- (rows - 1) - (groups - 1)
    scale = lastRow - lastGroup
    lastRow = sum (countRows xsss) - 1
    lastGroup = countMatrixes xsss - 1
    lastCol = countMatrixesCols xsss - 1
    invSpread = inv (trans (spreadWithinGroups xsss))
    -- one row per group, one column per attribute
    groupMeans = fromLists (map (map mean . transpose) xsss)
    mean column = foldl1 (+) column / count column

-- | Calculates the classification function according to Fisher: for every
-- group the constant part is put in front of the group's variable parts.
fisherClassificationFunction :: RawMatrixList Double -> RawVector (LinFunction Double)
fisherClassificationFunction xsss = zipWith (:) constants variables
  where
    constants = fisherClassificationFunctionConst xsss
    variables = fisherClassificationFunctionVar xsss

-- | Classifies a survey (a vector/list of attribute values) into a cluster.
-- The first argument is the context, i.e. the groups/clusters together with
-- the values of their items; the second is the vector to classify. The
-- result is the ID (starting with 0) of the cluster the vector belongs to.
-- This function uses the Fisher algorithm.
fisher :: RawMatrixList Double -> RawVector Double -> Int
fisher xsss attributes = fisher' classifiers attributes
  where
    classifiers = fisherClassificationFunction xsss

-- | Calculates the ID of the cluster the given values belong to. The
-- clusters are given as tuples of a cluster ID and the cluster's
-- classification function (according to Fisher); the ID of the tuple whose
-- position is selected by @fisher'@ is returned. This function uses the
-- Fisher algorithm.
fisherT :: RawVector (Int, LinFunction Double) -> RawVector Double -> Int
fisherT clusterTupels obj = fst (clusterTupels !! winner)
  where
    -- position of the selected classification function
    winner = fisher' (map snd clusterTupels) obj

-- | Calculates the ID (starting with 0) of the cluster the given list of
-- attributes belongs to. The clusters are represented by their
-- classification functions; every function is evaluated for the attributes
-- and 'maxPos' picks the position from the results. This function uses the
-- Fisher algorithm.
fisher' :: RawVector (LinFunction Double) -> RawVector Double -> Int
fisher' clusterFunctions obj = maxPos scores
  where
    scores = [calcLinFunction clusterFunction obj | clusterFunction <- clusterFunctions]

-- | Calculates the cluster of every survey of a poll. This function takes
-- the data of a whole poll and classifies every survey of the poll. This
-- function uses the Fisher algorithm.
fisherAll :: RawMatrixList Double -> RawMatrix Int
fisherAll xsss = foldVectors (fisher' classifiers) xsss
  where
    -- The classification functions depend only on the whole poll, so they
    -- are built once here rather than once per classified survey.
    classifiers = fisherClassificationFunction xsss