{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}

module Data.CRF.Chain2.Generic.Train
( CodecSpec (..)
, train
) where

import System.IO (hSetBuffering, stdout, BufferMode (..))
import Control.Applicative ((<$>))
import Data.Maybe (maybeToList)
import qualified Data.Vector as V
import qualified Numeric.SGD as SGD
import qualified Numeric.SGD.LogSigned as L

import Data.CRF.Chain2.Generic.Internal
import Data.CRF.Chain2.Generic.FeatMap
import Data.CRF.Chain2.Generic.External (SentL)
import Data.CRF.Chain2.Generic.Model
import Data.CRF.Chain2.Generic.Inference (expectedFeatures, accuracy)

-- | A codec specification: the two functions 'train' needs to convert
-- labeled sentences ('SentL') into the internal @(Xs o t, Ys t)@ form.
--
--   * 'mkCodec' takes a list of labeled sentences and returns a codec of
--     type @c@ together with a dataset in internal form ('train' uses
--     that dataset as its encoded training data).
--
--   * 'encode' converts labeled sentences into internal form using an
--     existing codec ('train' uses it for the evaluation data).
data CodecSpec a b c o t = CodecSpec
    { mkCodec :: [SentL a b] -> (c, [(Xs o t, Ys t)])
    , encode  :: c -> [SentL a b] -> [(Xs o t, Ys t)] }

-- | Train a CRF model with stochastic gradient descent.
--
-- The codec is built from the training data; the evaluation data (when
-- supplied) is encoded with that same codec.  Standard output is switched
-- to unbuffered mode so that progress messages show up immediately.
-- If the evaluation data 'IO' action is 'Just', the accuracy on the
-- evaluation part is reported after every full pass over the training
-- part; otherwise a @#@ placeholder is printed instead of the accuracy.
-- TODO: Add custom feature extraction function.
train
    :: (Ord a, Ord b, Eq t, Ord f, FeatMap m f)
    => SGD.SgdArgs                  -- ^ Args for SGD
    -> CodecSpec a b c o t          -- ^ Codec specification
    -> FeatGen o t f                -- ^ Feature generation
    -> FeatSel o t f                -- ^ Feature selection
    -> IO [SentL a b]               -- ^ Training data 'IO' action
    -> Maybe (IO [SentL a b])       -- ^ Maybe evaluation data
    -> IO (c, Model m o t f)        -- ^ Resulting codec and model
train sgdArgs CodecSpec { mkCodec = buildCodec, encode = encodeWith }
      ftGen ftSel trainIO evalIO'Maybe = do
    hSetBuffering stdout NoBuffering
    -- The training action runs first: the codec it yields is needed to
    -- encode the evaluation data.
    (codec, trainData) <- fmap buildCodec trainIO
    evalDataM <- maybe
        (return Nothing)
        (fmap (Just . encodeWith codec))
        evalIO'Maybe
    let crf = mkModel ftGen ftSel trainData
        -- Progress callback handed to the SGD driver.
        report = notify sgdArgs crf trainData evalDataM
    para <- SGD.sgdM sgdArgs report (gradOn crf)
        (V.fromList trainData) (values crf)
    -- The returned model is the initial one with the trained parameters.
    return (codec, crf { values = para })

-- | Gradient contribution of a single training example under the given
-- parameter vector.  Features reported by 'presentFeats' enter with a
-- positive sign, features reported by 'expectedFeatures' enter with a
-- negative sign.  Features for which 'index' yields 'Nothing' are dropped.
gradOn
    :: FeatMap m f => Model m o t f
    -> SGD.Para -> (Xs o t, Ys t) -> SGD.Grad
gradOn crf para (xs, ys) =
    SGD.fromLogList (observed ++ expected)
  where
    -- The model evaluated at the current parameter values.
    curr = crf { values = para }
    observed = signed L.fromPos (presentFeats (featGen curr) xs ys)
    expected = signed L.fromNeg (expectedFeatures curr xs)
    -- Map every feature to its parameter index, applying the given sign
    -- conversion to its value; unindexed features produce no entry.
    signed sign feats = do
        (ft, val) <- feats
        FeatIx ix <- maybeToList (index curr ft)
        return (ix, sign val)

-- | Progress callback passed to the SGD driver.
--
-- The number of completed passes over the training data is computed as
-- @floor (k * batchSize / length trainData)@.  While that number is the
-- same for @k@ and @k - 1@ a single dot is printed.  When it changes, a
-- new line is printed with the pass number followed by the accuracy on
-- the evaluation data, or by @#@ when there is no evaluation data.
-- (@k@ is presumably the SGD step counter -- confirm against
-- 'SGD.sgdM'.)
notify
    :: (Eq t, FeatMap m f) => SGD.SgdArgs -> Model m o t f -> [(Xs o t, Ys t)]
    -> Maybe [(Xs o t, Ys t)] -> SGD.Para -> Int -> IO ()
notify SGD.SgdArgs{..} crf trainData evalDataM para k
    | passesNow == passesBefore = putStr "."
    | otherwise =
        putStrLn ("\n" ++ "[" ++ show passesNow ++ "] f = " ++ score)
  where
    -- Accuracy is only evaluated when the full line is actually printed.
    score = case evalDataM of
        Just dataSet -> show (accuracy (crf { values = para }) dataSet)
        Nothing      -> "#"
    passesNow    = completed k
    passesBefore = completed (k - 1)
    -- Whole passes over the training data finished after @i@ steps.
    completed :: Int -> Int
    completed i = floor
        ( fromIntegral (i * batchSize)
        / fromIntegral (length trainData) :: Double )