{-# LANGUAGE RecordWildCards #-}

-- | Passive-aggressive optimization. Mainly based on:
--
-- Zakov, Shay and Goldberg, Yoav and Elhadad, Michael and Ziv-Ukelson, Michal
-- "Rich Parameterization Improves RNA Structure Prediction"
-- RECOMB 2011
--
-- and
--
-- Crammer, Koby and Dekel, Ofer and Keshet, Joseph and Shalev-Shwartz, Shai
-- and Singer, Yoram
-- "Online Passive-Aggressive Algorithms"
-- Journal of Machine Learning Research (2006)
--
-- TODO as always: move out of here and put in its own library

module BioInf.PassiveAggressive where

import qualified Data.Vector.Unboxed as VU
import Data.List as L
import Data.Set as S
import Control.Arrow
import Data.Map as M
import Text.Printf

import Biobase.TrainingData
import BioInf.Keys

import qualified BioInf.Params as P
import qualified BioInf.Params.Import as P
import qualified BioInf.Params.Export as P

import Statistics.ConfusionMatrix
import Statistics.PerformanceMetrics

import Data.PrimitiveArray as PA
import Data.PrimitiveArray.Ix



-- | Default implementation of P/A: one passive-aggressive (PA-I) update step
-- for a single training example.
--
-- The update treats lower scores as better. Features occurring only in the
-- known structure have their weight lowered by @tau@. Features occurring only
-- in the predicted structure have their weight raised by @tau@. So the new
-- weights are @w + tau * x@ with @x = phi(predicted) - phi(known)@, where
-- @phi@ counts feature occurrences.
--
-- Following Crammer et al. (2006), PA-I, the step size is
--
-- > tau  = min aggressiveness (loss / ||x||^2)
-- > loss = max 0 (S(known) - S(predicted) + sqrt (1 - F1))
--
-- Here @||x||^2@ is the squared Euclidean norm of the feature-count
-- difference.
--
-- Returns @(new parameters, tau, F1, applied (index, delta) changes)@.
-- If the known and predicted feature multisets coincide, or the prediction
-- already reaches an F1 of at least 0.999, the input parameters come back
-- unchanged as @(params, 0, 1, [])@.
--
-- Calls 'error' (via 'errorKnownTooGood') if the known structure scores more
-- than @epsilon@ (0.1) lower than the prediction. That should be impossible
-- for a prediction that is optimal under the current weights.

defaultPA :: Double -> P.Params -> TrainingData -> (P.Params,Double,Double,[(Int,Double)])
defaultPA aggressiveness params td@TrainingData{..}
  | L.null $ pOnly++kOnly = (params,0,1,[])
  | sty >= 0.999 = (params,0,1,[])
  | otherwise = ( new
                , tau
                , sty
                , changes
                )
  where
    -- apply the additive updates; 'accum' adds one delta per list entry, so
    -- a feature listed k times moves by k * tau
    new = P.fromList . VU.toList $ VU.accum (+) cur changes
    pFeatures = featureVector primary predicted
    kFeatures = featureVector primary secondary
    -- multiset differences: a feature occurring more often in one structure
    -- shows up (difference in counts) times in the respective list
    pOnly = pFeatures L.\\ kFeatures
    kOnly = kFeatures L.\\ pFeatures
    changes = zip kOnly (repeat $ negate tau) ++ zip pOnly (repeat tau)
    cur = VU.fromList . P.toList $ params
    pScore = sum . L.map (cur VU.!) $ pFeatures
    kScore = sum . L.map (cur VU.!) $ kFeatures
    -- squared Euclidean norm of x. 'pOnly' and 'kOnly' never share an index,
    -- so the multiplicity of a feature in their concatenation is exactly
    -- |x_f|. The previous (number of changes)^2 overestimated ||x||^2 and
    -- therefore under-corrected the loss.
    normSq :: Double
    normSq = fromIntegral . sum . L.map (\c -> c * c) . M.elems
           $ M.fromListWith (+) [ (i, 1 :: Int) | i <- pOnly ++ kOnly ]
    -- hinge loss: structural margin plus sqrt (1 - F1). Clamped at 0 so that
    -- a small negative score difference (within epsilon) never yields a
    -- negative step in the wrong direction.
    loss = max 0 $ kScore - pScore + sqrt (1-sty)
    -- weight calculation
    tau
      | kScore + epsilon < pScore
          = error $ "S(known) < S(predicted)\n" ++ errorKnownTooGood td cur kFeatures pFeatures
      | sty >= 0.999
          = 0
      | otherwise
          = min aggressiveness $ loss / normSq
    -- currently optimizing using F_1; an undefined F_1 ('Left') is treated
    -- as perfect (1)
    sty = case fmeasure (mkConfusionMatrix td) of
            Left  _ -> 1
            Right v -> v
    -- tolerance for S(known) being below S(predicted) before we give up
    epsilon = 0.1

-- | Build the diagnostic text for the case where the known structure scores
-- more than 'epsilon' below the predicted one. This should never happen and
-- is treated as an error by 'defaultPA'.
--
-- The text has a first line with the known score, the predicted score and
-- their difference (each computed by summing the weights at the given
-- feature indices). It is followed by the primary sequence and the example's
-- comment lines joined by newlines.

errorKnownTooGood TrainingData{..} weights knownFs predFs = scoreLine ++ seqLines
  where
    totalOf fs = sum [ weights VU.! ix | ix <- fs ]
    knownS = totalOf knownFs
    predS  = totalOf predFs
    scoreLine = printf "S(known) = %7.4f, S(pred) = %7.4f, S(known) - S(pred) = %7.4f\n"
                  knownS predS (knownS - predS)
    seqLines = printf "%s\n%s\n" primary (intercalate "\n" comments)

-- | Pull in the statistical interface. Once a 'TrainingData' example can be
-- turned into a confusion matrix, metrics such as 'fmeasure' come for free.
--
-- The known elements ('secondary') and the predicted elements ('predicted')
-- are converted to sets and compared:
--
-- * tp: elements present in both sets
-- * fp: predicted but not known
-- * fn: known but not predicted
-- * tn: @n * (n - 1) `div` 2@, where @n@ is the length of 'primary', minus
--   the number of elements in either set
--
-- NOTE Unfortunately, StatisticalMethods has heavy dependencies.
-- NOTE(review): tn presumes the elements are base pairs stored in a single
-- orientation (e.g. i < j); otherwise it can become negative. Confirm
-- against "Biobase.TrainingData".

instance MkConfusionMatrix TrainingData where
  mkConfusionMatrix TrainingData{..} = ConfusionMatrix
    { tp = count (known `S.intersection` found)
    , fp = count (found `S.difference` known)
    , fn = count (known `S.difference` found)
    , tn = Right . fromIntegral $ candidates - S.size (known `S.union` found)
    }
    where
      known = S.fromList secondary
      found = S.fromList predicted
      count s = Right . fromIntegral $ S.size s
      len = length primary
      candidates = (len * (len - 1)) `div` 2