-- | Training of a CRF: a codec is built from the training data, a model
-- is constructed from the extracted features and its parameters are
-- fitted with "Numeric.SGD".  The result is returned as a 'CRF' value,
-- which bundles the codec with the trained model.
module Data.CRF.Chain1.Train
( CRF (..)
, train
) where
import Control.Applicative ((<$>), (<*>))
import System.IO (hSetBuffering, stdout, BufferMode (..))
import Data.Binary (Binary, put, get)
import qualified Data.Vector as V
import qualified Numeric.SGD as SGD
import qualified Numeric.SGD.LogSigned as L
import Data.CRF.Chain1.Dataset.Internal
import Data.CRF.Chain1.Dataset.External (SentL)
import Data.CRF.Chain1.Dataset.Codec (mkCodec, Codec, encodeDataL)
import Data.CRF.Chain1.Feature (Feature, featuresIn)
import Data.CRF.Chain1.Model (Model (..), mkModel, FeatIx (..), featToInt)
import Data.CRF.Chain1.Inference (accuracy, expectedFeaturesIn)
-- | A trained conditional random field together with the codec that was
-- used to encode its training data.
data CRF a b = CRF
    { codec :: Codec a b  -- ^ Codec built from the training data by 'mkCodec'.
    , model :: Model      -- ^ Model whose 'values' hold the fitted parameters.
    }
-- | Serialisation writes the codec first and the model second;
-- deserialisation reads them back in the same order.
instance (Ord a, Ord b, Binary a, Binary b) => Binary (CRF a b) where
    put (CRF cdc mdl) = do
        put cdc
        put mdl
    get = do
        cdc <- get
        mdl <- get
        return (CRF cdc mdl)
-- | Train a 'CRF' with stochastic gradient descent.
--
-- Steps, in order:
--
--   1. Standard output is switched to unbuffered mode so that progress
--      output written by the notifier shows up immediately.
--   2. The training data is read and passed to 'mkCodec', which yields
--      the codec together with the encoded training set.
--   3. If evaluation data is supplied, it is read and encoded with the
--      same codec via 'encodeDataL'.
--   4. An initial model is built from the features selected by the
--      feature-extraction function, and its parameters are fitted with
--      'SGD.sgdM'.
train
    :: (Ord a, Ord b)
    => SGD.SgdArgs
    -- ^ Arguments forwarded to 'SGD.sgdM'.
    -> IO [SentL a b]
    -- ^ Action producing the training data.
    -> Maybe (b, IO [SentL a b])
    -- ^ Optional evaluation data; the @b@ component is handed to
    -- 'encodeDataL' (presumably a default label for labels unknown to
    -- the codec -- confirm against "Data.CRF.Chain1.Dataset.Codec").
    -> ([(Xs, Ys)] -> [Feature])
    -- ^ Feature selection applied to the encoded training data.
    -> IO (CRF a b)
    -- ^ Codec plus the model with fitted parameters.
train sgdArgs trainIO evalIO'Maybe extractFeats = do
    hSetBuffering stdout NoBuffering
    (cdc, encoded) <- fmap mkCodec trainIO
    evalEncoded <- maybe
        (return Nothing)
        (\(dflt, evalIO) -> fmap (Just . encodeDataL dflt cdc) evalIO)
        evalIO'Maybe
    let initial = mkModel (extractFeats encoded)
        report  = notify sgdArgs initial encoded evalEncoded
    fitted <- SGD.sgdM
        sgdArgs
        report
        (gradOn initial)
        (V.fromList encoded)
        (values initial)
    return (CRF cdc (initial { values = fitted }))
-- | Gradient contribution of a single training element.
--
-- The model is evaluated with the supplied parameter vector.  Values
-- reported by 'featuresIn' for the labelled element enter with a
-- positive sign, values reported by 'expectedFeaturesIn' under the
-- current parameters enter with a negative sign; both are combined
-- with 'SGD.fromLogList'.
gradOn :: Model -> SGD.Para -> (Xs, Ys) -> SGD.Grad
gradOn crf para (xs, ys) =
    SGD.fromLogList (observed ++ expected)
  where
    -- The model with the parameters of the current SGD step plugged in.
    current  = crf { values = para }
    -- Features present in the labelled data, positive sign.
    observed =
        [ (featToInt current ft, L.fromPos v)
        | (ft, v) <- featuresIn xs ys ]
    -- Features expected under the current model, negative sign.
    expected =
        [ (i, L.fromNeg v)
        | (FeatIx i, v) <- expectedFeaturesIn current xs ]
-- | Progress callback handed to 'SGD.sgdM'.
--
-- For iteration number @k@ the number of completed passes over the
-- training data is computed as @floor (k * batchSize / trainSize)@.
-- While that number is the same as for the previous iteration, a single
-- dot is printed.  When it changes, a new line with the pass number is
-- printed, together with the accuracy on the evaluation data if such
-- data is available, or @#@ otherwise.
--
-- Arguments: SGD settings (only the batch size is used), the model, the
-- encoded training data (only its length is used), optional encoded
-- evaluation data, the current parameters and the iteration number.
notify
    :: SGD.SgdArgs -> Model -> [(Xs, Ys)] -> Maybe [(Xs, Ys)]
    -> SGD.Para -> Int -> IO ()
notify sgdArgs crf trainData evalDataM para k
    -- Compare with the previous iteration (@k - 1@): no pass boundary
    -- has been crossed, so only signal progress.
    | passes == doneTotal (k - 1) = putStr "."
    | Just dataSet <- evalDataM = do
        let x = accuracy (crf { values = para }) dataSet
        putStrLn ("\n" ++ "[" ++ show passes ++ "] acc = " ++ show x)
    | otherwise =
        putStrLn ("\n" ++ "[" ++ show passes ++ "] acc = #")
  where
    -- Number of full passes completed at the current iteration.
    passes :: Int
    passes = doneTotal k
    -- Number of full passes completed after the given iteration.
    doneTotal :: Int -> Int
    doneTotal = floor . done
    -- Fraction of the training data processed after the given iteration,
    -- measured in passes.
    done :: Int -> Double
    done i
        = fromIntegral (i * batchSize)
        / fromIntegral trainSize
    -- Explicit accessor instead of a record wildcard, so no language
    -- extension is needed here.
    batchSize :: Int
    batchSize = SGD.batchSize sgdArgs
    trainSize :: Int
    trainSize = length trainData