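```-- NOTE(review): this comment presumably accompanied a compiler option that
-- suppressed warnings for this module; neither that option nor the link to
-- further details is present here any more. While working on this module
-- you are encouraged to fix any warnings it produces.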

-----------------------------------------------------------------------------
-- |
-- Module      :  LDA
-- Copyright   :  (c) Lennart Schmitt
-- License     :  BSD-style (see the file libraries/base/LICENSE)
--
-- Maintainer  :  lennart...schmitt@<nospam>gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- This module implements some linear discriminant analysis functions.
-- Imagine you have run a poll and now have values/attributes from every subscriber.
-- Furthermore, you have grouped the subscribers into clusters.
-- The poll data is structured as follows:
--
--  * poll-data of one subscriber = [value] --> Vector value
--
--  * poll-data of one cluster/group of subscribers = [[values]] --> Matrix values
--
--  * poll-data of all clusters/groups = [[[values]]] --> MatrixList values
--
-- Now you want to check if you clustered right and/or how significant the values you asked for are...
--
-----------------------------------------------------------------------------

module Numeric.Statistics.LDA (
fisher,
fisher',
fisherT,
fisherAll,
fisherClassificationFunction,
aprioriProbability,
discriminantCriteria,
isolatedDiscriminant
) where

import Numeric.Matrix
import Numeric.MatrixList
import Numeric.Vector
import Numeric.Function
import Numeric.LinearAlgebra.LAPACK

-- | Centres every row of every matrix on its own mean: each element is
-- replaced by its difference from the average of the row it belongs to.
--
-- > diffAverage [[[-1,1],[2,2]],[[1,3],[4,8]]] == [[[-1.0,1.0],[0.0,0.0]],[[-1.0,1.0],[-2.0,2.0]]]
diffAverage :: RawMatrixList Double -> RawMatrixList Double
diffAverage xss = zipWith centreMatrix xss (averages xss)
  where
    -- Pair every row of a matrix with the matching entry of 'averages'.
    centreMatrix rows rowAverages = zipWith centreRow rows rowAverages
    centreRow row avg = [value - avg | value <- row]

-- | Squares every element of the row-centred data produced by 'diffAverage'.
--
-- > squareDiffAverage [[[-1,1],[2,2]],[[1,3],[4,8]]] == [[[1.0,1.0],[0.0,0.0]],[[1.0,1.0],[4.0,4.0]]]
squareDiffAverage :: RawMatrixList Double -> RawMatrixList Double
squareDiffAverage xss = mapElements square (diffAverage xss)
  where
    square d = d ^ (2 :: Int)

-- | Calculates the overall average of every attribute across all groups.
--
-- The per-group attribute averages are combined into a weighted mean in
-- which every group contributes with its own number of rows, divided by the
-- total number of rows.
--
-- > totalAverages [[[-1,1],[2,2]],[[1,3],[4,8]]] == [1.5,3.5]
totalAverages :: RawMatrixList Double -> RawVector Double
totalAverages xss = map weightedMean averagesPerAttribute
  where
    groupSizes = countRows xss
    total = sum groupSizes
    -- 'averages' of the transposed groups yields one list per group; after
    -- 'transpose' there is one list per attribute holding one average per
    -- group (presumably in group order -- confirm against 'averages').
    averagesPerAttribute = transpose (averages (transposeAll xss))
    -- Weight the average of group g with the size of group g. Pairing the
    -- sizes with the attribute lists instead would weight attribute a with
    -- the size of group a and truncate the result to the number of groups.
    weightedMean groupAverages =
      sum (zipWith (*) groupSizes groupAverages) / total

-- | For every matrix, sums the element-wise products of every ordered pair
-- of its row-centred rows (see 'diffAverage'), yielding one list of sums per
-- matrix.
--
-- > sumOfAverages [[[-1,1],[2,2]],[[1,3],[4,8]]] == [[2.0,0.0,0.0,0.0],[2.0,4.0,4.0,8.0]]
sumOfAverages :: RawMatrixList Double -> RawMatrix Double
sumOfAverages xss = foldVectors sum (map pairwiseProducts (diffAverage xss))
  where
    -- All ordered row pairs, first row varying slowest.
    pairwiseProducts rows = [zipWith (*) a b | a <- rows, b <- rows]

-- | Calculates the spread within the groups: the per-group results of
-- 'sumOfAverages' (on the transposed groups) are summed position-wise and
-- turned into a square matrix.
--
-- > spreadWithinGroups [[[-1,1],[2,2]],[[1,3],[4,8]]] == (2><2) [ 9.0, 9.0, 9.0, 13.0 ]
spreadWithinGroups :: RawMatrixList Double -> Matrix Double
spreadWithinGroups xss = fromListToQuadraticMatrix summed
  where
    perGroup = sumOfAverages (transposeAll xss)
    summed = zipAllWith sum perGroup

-- | For every group, calculates how far its attribute averages lie from the
-- overall averages given by 'totalAverages' (group average minus overall
-- average).
--
-- > spreadFromTotalAverages [[[-1,1],[2,2]],[[1,3],[4,8]]] == [[-1.0,-2.0],[1.0,2.0]]
spreadFromTotalAverages :: RawMatrixList Double -> RawMatrix Double
spreadFromTotalAverages x =
  [ zipWith (-) groupAverages overall
  | groupAverages <- averages (transposeAll x)
  ]
  where
    overall = totalAverages x

-- | Calculates the spread between the groups/clusters.
--
-- For every group the deviations of its attribute averages from the overall
-- averages ('spreadFromTotalAverages') are multiplied pairwise and weighted
-- with the number of rows of that group; the per-group results are summed
-- position-wise and turned into a square matrix.
--
-- > spreadBetweenGroups [[[-1,1],[2,2]],[[1,3],[4,8]]] == (2><2) [ 4.0, 8.0, 8.0, 16.0 ]
spreadBetweenGroups :: RawMatrixList Double -> Matrix Double
spreadBetweenGroups xss =
  fromListToQuadraticMatrix (zipAllWith sum weightedOuterProducts)
  where
    deviations = spreadFromTotalAverages xss
    -- One weight per group: the size of group g scales the deviations of
    -- group g. Zipping the sizes against the attribute positions of a row
    -- would weight attribute a with the size of group a and cut the row
    -- down to the number of groups.
    weightedOuterProducts = zipWith weightedOuter (countRows xss) deviations
    -- All ordered pairs of deviations of one group, first factor varying
    -- slowest, second factor carrying the group weight.
    weightedOuter size dev =
      [a * weighted | a <- dev, weighted <- map (size *) dev]

-- | Calculates the isolated discriminant of every attribute: the ratio of
-- the matching diagonal elements of 'spreadBetweenGroups' and
-- 'spreadWithinGroups'.
--
-- The diagonal is walked once per attribute (column count of the groups),
-- not once per group, because both spread matrices are indexed by attribute.
-- An empty list of groups yields an empty result.
--
-- > isolatedDiscriminant [[[-1,1],[2,2]],[[1,3],[4,8]]] == [0.4444444444444444,1.2307692307692308]
isolatedDiscriminant :: RawMatrixList Double -> RawVector Double
isolatedDiscriminant [] = []
isolatedDiscriminant xss =
  [ (between @@> (i, i)) / (within @@> (i, i)) | i <- [0 .. lastAttribute] ]
  where
    between = spreadBetweenGroups xss
    within = spreadWithinGroups xss
    -- NOTE(review): relies on 'countMatrixesCols' returning the number of
    -- attributes per row, as its use in the Fisher functions suggests.
    lastAttribute = round (countMatrixesCols xss) - 1 :: Int

{- -----------------------------------------------
----------------------Utilities ------------------
------------------------------------------------ -}

-- | Calculates the analysis matrix: the inverse of the within-group spread
-- multiplied (from the left) with the between-group spread.
calcA :: RawMatrixList Double -> Matrix Double
calcA xss = inv within `multiplyR` between
  where
    within = spreadWithinGroups xss
    between = spreadBetweenGroups xss

-- | Calculates the scaling factor @1 / sqrt (v W v' / (cases - groups))@,
-- where @v@ is the first column of 'calcA' taken as a row vector and @W@ is
-- the within-group spread.
scalingFactor :: RawMatrixList Double -> Double
scalingFactor xsss = 1 / sqrt (quadraticForm / degreesOfFreedom)
  where
    within = spreadWithinGroups xsss
    -- First row of the transposed analysis matrix, kept as a 1 x n matrix.
    firstColumn = asRow (head (toRows (trans (calcA xsss))))
    -- The product is a 1 x 1 matrix; its only element is the value needed.
    quadraticForm =
      (firstColumn `multiplyR` within `multiplyR` trans firstColumn) @@> (0, 0)
    degreesOfFreedom = countAllRows xsss - countMatrixes xsss

-- | Scales a matrix of discriminant coefficients by the given factor.
scaledDiscriminantCoefficient :: Matrix Double -> Double -> Matrix Double
scaledDiscriminantCoefficient coefficients factor =
  factor `scalarMultiplication` coefficients

-- | Calculates the constant element of the discriminant function: the
-- negated sum of the products of the flattened coefficients with the given
-- values.
constElement :: Matrix Double -> RawVector Double -> Double
constElement b x = negate (sum (zipWith (*) coefficients x))
  where
    coefficients = toList (flatten b)

-- | Calculates the possible discriminant functions. Every function is the
-- list of scaled coefficients of one column of 'calcA', preceded by its
-- constant element (see 'constElement') for the overall averages.
scaledDiscriminantFunctions :: RawMatrixList Double -> RawVector (LinFunction Double)
scaledDiscriminantFunctions xss =
  [ constantFor coefficients : coefficients | coefficients <- scaledColumns ]
  where
    scaled = scaledDiscriminantCoefficient (calcA xss) (scalingFactor xss)
    -- One coefficient list per column of the scaled analysis matrix.
    scaledColumns = Numeric.Matrix.toLists (trans scaled)
    overallAverages = totalAverages xss
    constantFor coefficients =
      constElement (fromLists [coefficients]) overallAverages

-- | Calculates the scaled discriminant of a list of values by evaluating a
-- scaled discriminant function on it (a synonym for 'calcLinFunction').
scaledDiscriminant :: LinFunction Double -> Values Double -> Double
scaledDiscriminant = calcLinFunction

-- | Calculates all scaled discriminants: for every discriminant function,
-- the function is evaluated on every row of every group.
scaledDiscriminants :: RawMatrixList Double -> RawMatrixList Double
scaledDiscriminants xss =
  [ foldVectors (scaledDiscriminant discriminantFunction) xss
  | discriminantFunction <- scaledDiscriminantFunctions xss
  ]

-- | Calculates the centroid of a group/cluster: the mean of the scaled
-- discriminants of all rows of the group under the given discriminant
-- function.
centroid :: LinFunction Double -> RawVector(Values Double) -> Double
centroid v m = total / count m
  where
    total = sum [scaledDiscriminant v values | values <- m]

-- | Calculates all centroids: one list per discriminant function, holding
-- the 'centroid' of every group.
centroids :: RawMatrixList Double -> RawMatrix Double
centroids xss =
  [ [ centroid discriminantFunction group | group <- xss ]
  | discriminantFunction <- scaledDiscriminantFunctions xss
  ]

-- | Calculates, per discriminant function, the average of the scaled
-- discriminants of all rows of all groups.
averageOfScaledDiscriminants :: RawMatrixList Double -> RawVector Double
averageOfScaledDiscriminants xss =
  map (mean . concat) (scaledDiscriminants xss)
  where
    mean values = foldl1 (+) values / count values

-- | Calculates the total spread between all groups, per discriminant
-- function: the squared distance of every group centroid from the average
-- scaled discriminant, weighted with the row count of the group and summed
-- over the groups.
totalSpreadBetweenGroups :: RawMatrixList Double -> RawVector Double
totalSpreadBetweenGroups xss =
  zipWith spreadFor (centroids xss) (averageOfScaledDiscriminants xss)
  where
    groupSizes = countRows xss
    spreadFor groupCentroids overallMean =
      sum (zipWith (weighted overallMean) groupSizes groupCentroids)
    weighted overallMean size groupCentroid =
      size * (groupCentroid - overallMean) ^ (2 :: Int)

-- | Calculates the total spread within the groups, per discriminant
-- function: the squared distances of all scaled discriminants from the
-- centroid of their own group, summed over all groups.
totalSpreadWithinGroups :: RawMatrixList Double -> RawVector Double
totalSpreadWithinGroups xss =
  zipWith spreadFor (centroids xss) (scaledDiscriminants xss)
  where
    spreadFor groupCentroids groupValues =
      sum (map sum (zipWith squaredDeviations groupCentroids groupValues))
    squaredDeviations groupCentroid values =
      [ (value - groupCentroid) ^ (2 :: Int) | value <- values ]

-- | Calculates the discriminant criteria: per discriminant function, the
-- total spread between the groups divided by the total spread within the
-- groups.
discriminantCriteria :: RawMatrixList Double -> RawVector Double
discriminantCriteria xss = zipWith (/) between within
  where
    between = totalSpreadBetweenGroups xss
    within = totalSpreadWithinGroups xss

-- | Calculates the a priori probability of every group, i.e. the
-- probability that an element belongs to the group: the row count of the
-- group divided by the row count of all groups.
aprioriProbability :: RawMatrixList Double -> RawVector Double
aprioriProbability xsss =
  [ size / total | size <- take groups (countRows xsss) ]
  where
    groups = round (countMatrixes xsss) :: Int
    total = countAllRows xsss

-- | Calculates the constant part of the classification function according
-- to Fisher. For every group this is @-0.5 * s + log p@, where @s@ is the
-- sum of the products of the group's variable coefficients
-- ('fisherClassificationFunctionVar') with the group's attribute averages
-- and @p@ is the a priori probability of the group.
fisherClassificationFunctionConst :: RawMatrixList Double -> RawVector Double
fisherClassificationFunctionConst xsss =
  [ -0.5 * weightedMeans g + log prior
  | (g, prior) <- zip [0 .. lastGroup] (aprioriProbability xsss)
  ]
  where
    lastGroup = round (countMatrixes xsss) - 1 :: Int
    lastAttribute = round (countMatrixesCols xsss) - 1 :: Int
    coefficients = fisherClassificationFunctionVar xsss
    weightedMeans g =
      sum
        [ (coefficients !! g !! j) * (groupMeans @@> (g, j))
        | j <- [0 .. lastAttribute]
        ]
    -- One row per group, one column per attribute.
    groupMeans = fromLists [ map mean (transpose group) | group <- xsss ]
    mean values = foldl1 (+) values / count values

-- | Calculates the variable parts of the classification function according
-- to Fisher: for every group one coefficient per attribute. A coefficient
-- is the product of a column of the inverted (transposed) within-group
-- spread with the group's attribute averages, scaled by the number of all
-- rows minus the number of groups.
fisherClassificationFunctionVar :: RawMatrixList Double -> RawVector (LinFunction Double)
fisherClassificationFunctionVar xsss =
  [ [ degreesOfFreedom * coefficient j g | j <- attributes ] | g <- groups ]
  where
    groups = [0 .. round (countMatrixes xsss) - 1] :: [Int]
    attributes = [0 .. round (countMatrixesCols xsss) - 1] :: [Int]
    coefficient j g =
      sum
        [ (inverseWithin @@> (r, j)) * (groupMeans @@> (g, r))
        | r <- attributes
        ]
    -- (all rows - 1) - (groups - 1)
    degreesOfFreedom = (sum (countRows xsss) - 1) - (countMatrixes xsss - 1)
    inverseWithin = inv (trans (spreadWithinGroups xsss))
    -- One row per group, one column per attribute.
    groupMeans = fromLists [ map mean (transpose group) | group <- xsss ]
    mean values = foldl1 (+) values / count values

-- | Calculates the classification functions according to Fisher: for every
-- group the constant part followed by the variable parts.
fisherClassificationFunction :: RawMatrixList Double -> RawVector (LinFunction Double)
fisherClassificationFunction xsss =
  [ constant : variables | (constant, variables) <- zip constants variableParts ]
  where
    constants = fisherClassificationFunctionConst xsss
    variableParts = fisherClassificationFunctionVar xsss

-- | Classifies a survey (a list of attributes/values) using the Fisher
-- algorithm. The first argument is the context: the groups/clusters with
-- the values of their items. The second argument is the list of values to
-- classify. The result is the ID (starting with 0) of the cluster the
-- values belong to.
fisher :: RawMatrixList Double -> RawVector Double -> Int
fisher xsss = fisher' (fisherClassificationFunction xsss)

-- | Calculates the ID of the cluster the given values belong to, using the
-- Fisher algorithm. The clusters are given as tuples of a cluster ID and
-- the classification function (according to Fisher) of that cluster; the
-- ID of the tuple selected by "fisher'" is returned.
fisherT :: RawVector (Int, LinFunction Double) -> RawVector Double -> Int
fisherT clusterTupels obj = clusterId
  where
    winner = fisher' (map snd clusterTupels) obj
    (clusterId, _) = clusterTupels !! winner

-- | Calculates the ID (starting with 0) of the cluster the given list of
-- attributes belongs to, using the Fisher algorithm. The clusters are
-- represented by their classification functions; every function is
-- evaluated on the attributes and the position reported by 'maxPos' is
-- returned.
fisher' :: RawVector (LinFunction Double) -> RawVector Double -> Int
fisher' clusterFunctions obj =
  maxPos [ calcLinFunction clusterFunction obj | clusterFunction <- clusterFunctions ]

-- | Calculates the cluster of every survey of a poll, using the Fisher
-- algorithm. The function takes the data of a whole poll and classifies
-- every survey of it; the result has one cluster ID per survey, grouped
-- like the input.
fisherAll :: RawMatrixList Double -> RawMatrix Int
fisherAll xsss = foldVectors (fisher' classifiers) xsss
  where
    -- The classification functions depend only on the whole poll, so they
    -- are built once and shared by all surveys instead of being rebuilt
    -- (including a matrix inversion) for every single survey.
    classifiers = fisherClassificationFunction xsss
```