module Data.CRF.Chain1.Constrained.Train
( CRF (..)
, train
) where
import Control.Applicative ((<$>), (<*>))
import System.IO (hSetBuffering, stdout, BufferMode (..))
import Data.Binary (Binary, put, get)
import qualified Data.Set as S
import qualified Data.Map as M
import qualified Numeric.SGD as SGD
import qualified Numeric.SGD.LogSigned as L
import Data.CRF.Chain1.Constrained.Dataset.Internal
import Data.CRF.Chain1.Constrained.Dataset.External (SentL, unknown, unProb)
import Data.CRF.Chain1.Constrained.Dataset.Codec
(mkCodec, Codec, obMax, lbMax, encodeDataL, encodeLabels)
import Data.CRF.Chain1.Constrained.Feature (Feature, featuresIn)
import Data.CRF.Chain1.Constrained.Model
(Model (..), mkModel, FeatIx (..), featToJustInt)
import Data.CRF.Chain1.Constrained.Inference (accuracy, expectedFeaturesIn)
-- | A conditional random field: the codec used to encode external data
-- into its internal representation, paired with the model built over
-- the encoded data.
data CRF a b = CRF
  { codec :: Codec a b  -- ^ Codec used to encode sentences and labels
  , model :: Model      -- ^ The model (including its parameter values)
  }
-- | Binary serialization of a 'CRF'.  The codec is written first and the
-- model second; 'get' reads them back in the same order.
instance (Ord a, Ord b, Binary a, Binary b) => Binary (CRF a b) where
  put crf = do
    put (codec crf)
    put (model crf)
  get = do
    theCodec <- get
    theModel <- get
    return (CRF theCodec theModel)
-- | Train a CRF with stochastic gradient descent.
--
-- The training-data action is executed three times: to build the codec,
-- to encode the training set, and to collect the labels of words marked
-- as unknown.  It should therefore yield the same sentences on each run.
-- The encoded label set of unknown words is handed to the feature
-- extractor and stored in the model's @r0@ field.
train
  :: (Ord a, Ord b)
  => SGD.SgdArgs      -- ^ SGD parameters
  -> Bool             -- ^ Forwarded to 'SGD.withData' for both datasets
                      --   (presumably selects on-disk storage; confirm)
  -> IO [SentL a b]   -- ^ Action producing the training sentences
  -> IO [SentL a b]   -- ^ Action producing the evaluation sentences
  -> (AVec Lb -> [(Xs, Ys)] -> [Feature])
                      -- ^ Feature extraction, given the unknown-word labels
                      --   and the encoded training data
  -> IO (CRF a b)
train sgdArgs onDisk trainIO evalIO extractFeats = do
  -- Progress dots from 'notify' should appear immediately.
  hSetBuffering stdout NoBuffering
  cdc <- fmap mkCodec trainIO
  encTrain <- fmap (encodeDataL cdc) trainIO
  SGD.withData onDisk encTrain $ \trainSet -> do
    encEval <- fmap (encodeDataL cdc) evalIO
    SGD.withData onDisk encEval $ \evalSet -> do
      unkLabels <- fmap (encodeLabels cdc . S.toList . unkSet) trainIO
      feats <- fmap (extractFeats unkLabels) (SGD.loadData trainSet)
      let initial =
            (mkModel (obMax cdc) (lbMax cdc) feats) { r0 = unkLabels }
      learned <- SGD.sgd
        sgdArgs
        (notify sgdArgs initial trainSet evalSet)
        (gradOn initial)
        trainSet
        (values initial)
      return (CRF cdc (initial { values = learned }))
-- | Collect every label that appears in the probability map of a word
-- for which 'unknown' holds, across all the given sentences.
unkSet :: Ord b => [SentL a b] -> S.Set b
unkSet sents = S.fromList (concatMap (concatMap labelsOf) sents)
  where
    -- Known words contribute nothing; unknown words contribute the keys
    -- of their label distribution.
    labelsOf (w, prob)
      | unknown w = M.keys (unProb prob)
      | otherwise = []
-- | Gradient contribution of a single sentence @(xs, ys)@, evaluated at
-- parameter vector @para@.  Feature values observed in the sentence enter
-- with positive sign; the expected feature values computed by
-- 'expectedFeaturesIn' under the current parameters enter with negative
-- sign.
gradOn :: Model -> SGD.Para -> (Xs, Ys) -> SGD.Grad
gradOn model para (xs, ys) =
  SGD.fromLogList (observed ++ expected)
  where
    -- The model with its parameters replaced by the ones being optimized.
    curr = model { values = para }
    observed = map posPart (featuresIn xs ys)
    expected = map negPart (expectedFeaturesIn curr xs)
    posPart (feat, val) = (featToJustInt curr feat, L.fromPos val)
    negPart (FeatIx ix, val) = (ix, L.fromNeg val)
-- | Progress callback passed to 'SGD.sgd'.
--
-- It is called with the current parameters and the iteration index @k@.
-- While @k@ stays within the same pass over the training data, it prints
-- a dot.  When @k@ completes a pass, it prints the pass number together
-- with the 'accuracy' of the model on the evaluation data, or @#@ when
-- the evaluation set is empty.
notify
  :: SGD.SgdArgs -> Model
  -> SGD.Dataset (Xs, Ys)
  -> SGD.Dataset (Xs, Ys)
  -> SGD.Para -> Int -> IO ()
notify SGD.SgdArgs{..} model trainData evalData para k
  -- Compare with the previous iteration (k - 1): same number of
  -- completed passes means we are still inside the current pass.
  | doneTotal k == doneTotal (k - 1) = putStr "."
  | SGD.size evalData > 0 = do
      x <- accuracy (model { values = para }) <$> SGD.loadData evalData
      putStrLn ("\n" ++ "[" ++ show (doneTotal k) ++ "] f = " ++ show x)
  | otherwise =
      putStrLn ("\n" ++ "[" ++ show (doneTotal k) ++ "] f = #")
  where
    -- Number of complete passes over the training data after i iterations.
    doneTotal :: Int -> Int
    doneTotal = floor . done
    -- Fractional number of passes after i iterations of batchSize elements.
    -- The divisor is clamped to at least 1 so an empty training set does
    -- not produce Infinity/NaN (and an arbitrary result from 'floor').
    done :: Int -> Double
    done i
      = fromIntegral (i * batchSize)
      / fromIntegral (max 1 trainSize)
    trainSize = SGD.size trainData