module Data.CRF.Chain2.Generic.Train
( CodecSpec (..)
, train
) where
import System.IO (hSetBuffering, stdout, BufferMode (..))
import Control.Applicative ((<$>))
import Data.Maybe (maybeToList)
import qualified Data.Vector as V
import qualified Numeric.SGD as SGD
import qualified Numeric.SGD.LogSigned as L
import Data.CRF.Chain2.Generic.Internal
import Data.CRF.Chain2.Generic.FeatMap
import Data.CRF.Chain2.Generic.External (SentL)
import Data.CRF.Chain2.Generic.Model
import Data.CRF.Chain2.Generic.Inference (expectedFeatures, accuracy)
-- | A codec specification: the pair of functions 'train' needs in order
-- to turn external, labeled sentences into the internal representation.
-- The type @c@ is the codec itself.
data CodecSpec a b c o t = CodecSpec
    { -- | Build a codec from a dataset and return it together with the
      -- encoded form of that same dataset.
      mkCodec :: [SentL a b] -> (c, [(Xs o t, Ys t)])
      -- | Encode a dataset using an already constructed codec.
    , encode  :: c -> [SentL a b] -> [(Xs o t, Ys t)]
    }
-- | Train a CRF model with stochastic gradient descent.
--
-- The codec is built from the training data and then reused to encode
-- the optional evaluation data.  Standard output is switched to
-- unbuffered mode before training starts, so that the progress output
-- of 'notify' appears immediately.
train
    :: (Ord a, Ord b, Eq t, Ord f, FeatMap m f)
    => SGD.SgdArgs                  -- ^ Arguments handed over to the SGD routine
    -> CodecSpec a b c o t          -- ^ Codec specification
    -> FeatGen o t f                -- ^ Feature generation
    -> FeatSel o t f                -- ^ Feature selection
    -> IO [SentL a b]               -- ^ Action producing the training data
    -> Maybe (IO [SentL a b])       -- ^ Optional action producing evaluation data
    -> IO (c, Model m o t f)        -- ^ The codec and the model with trained values
train sgdArgs spec ftGen ftSel trainIO evalIO'Maybe = do
    hSetBuffering stdout NoBuffering
    (codec, trainData) <- fmap (mkCodec spec) trainIO
    -- Evaluation data (when supplied) is encoded with the codec derived
    -- from the training data.
    evalDataM <- maybe
        (return Nothing)
        (fmap (Just . encode spec codec))
        evalIO'Maybe
    let crf = mkModel ftGen ftSel trainData
    para <- SGD.sgdM
        sgdArgs
        (notify sgdArgs crf trainData evalDataM)
        (gradOn crf)
        (V.fromList trainData)
        (values crf)
    return (codec, crf { values = para })
-- | Gradient contribution of a single training element, computed for
-- the model whose parameter values are replaced by the given 'SGD.Para'.
--
-- Features present in the labeled element enter with a positive sign
-- ('L.fromPos'), features reported by 'expectedFeatures' with a negative
-- sign ('L.fromNeg').  Features without an index in the model are
-- silently skipped.
gradOn
    :: FeatMap m f => Model m o t f
    -> SGD.Para -> (Xs o t, Ys t) -> SGD.Grad
gradOn crf para (xs, ys) =
    SGD.fromLogList (observed ++ expected)
  where
    -- The model under the current parameter vector.
    model = crf { values = para }

    -- Zero or one feature index, depending on whether the model knows
    -- the feature.
    slot ft = maybeToList (index model ft)

    observed =
        [ (ix, L.fromPos val)
        | (ft, val) <- presentFeats (featGen model) xs ys
        , FeatIx ix <- slot ft ]

    expected =
        [ (ix, L.fromNeg val)
        | (ft, val) <- expectedFeatures model xs
        , FeatIx ix <- slot ft ]
-- | Progress callback passed to 'SGD.sgdM'.
--
-- The last argument @k@ is the step counter supplied by the SGD routine
-- (presumably the number of batches processed so far -- confirm against
-- the @sgd@ package).  While the number of completed passes over the
-- training data stays the same as for step @k - 1@, only a dot is
-- printed.  When it changes, a new line is started that shows the number
-- of completed passes and either the 'accuracy' of the current model on
-- the evaluation data or @#@ when no evaluation data was given.
notify
    :: (Eq t, FeatMap m f) => SGD.SgdArgs -> Model m o t f -> [(Xs o t, Ys t)]
    -> Maybe [(Xs o t, Ys t)] -> SGD.Para -> Int -> IO ()
notify sgdArgs crf trainData evalDataM para k
    -- Compare with the *previous* step (k - 1); the original text applied
    -- @k@ to @1@ as if it were a function, which cannot type-check.
    | doneTotal k == doneTotal (k - 1) = putStr "."
    | Just dataSet <- evalDataM = do
        let x = accuracy (crf { values = para }) dataSet
        putStrLn ("\n" ++ "[" ++ show (doneTotal k) ++ "] f = " ++ show x)
    | otherwise =
        putStrLn ("\n" ++ "[" ++ show (doneTotal k) ++ "] f = #")
  where
    -- Number of whole passes over the training data after step @i@.
    doneTotal :: Int -> Int
    doneTotal = floor . done

    -- Fraction of passes over the training data after step @i@.
    done :: Int -> Double
    done i
        = fromIntegral (i * SGD.batchSize sgdArgs)
        / fromIntegral trainSize

    trainSize = length trainData