{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}

module Data.CRF.Chain1.Constrained.Train
( CRF (..)
, train
) where

import Control.Applicative ((<$>), (<*>))
import System.IO (hSetBuffering, stdout, BufferMode (..))
import Data.Binary (Binary, put, get)
import qualified Data.Set as S
import qualified Data.Map as M
import qualified Numeric.SGD as SGD
import qualified Numeric.SGD.LogSigned as L

import Data.CRF.Chain1.Constrained.Dataset.Internal
import Data.CRF.Chain1.Constrained.Dataset.External (SentL, unknown, unProb)
import Data.CRF.Chain1.Constrained.Dataset.Codec
    (mkCodec, Codec, obMax, lbMax, encodeDataL, encodeLabels)
import Data.CRF.Chain1.Constrained.Feature (Feature, featuresIn)
import Data.CRF.Chain1.Constrained.Model
    (Model (..), mkModel, FeatIx (..), featToJustInt)
import Data.CRF.Chain1.Constrained.Inference (accuracy, expectedFeaturesIn)

-- | A conditional random field model bundled with the codec that was
-- used to encode the data it was trained on.
data CRF a b = CRF
    { codec :: Codec a b
      -- ^ The codec is used to transform data into internal
      -- representation, where each observation and each label is
      -- represented by a unique integer number.
    , model :: Model
      -- ^ The actual model, which is a map from 'Feature's to potentials.
    }

-- | Binary serialization: the codec is written first, then the model,
-- and both are read back in that same order.
instance (Ord a, Ord b, Binary a, Binary b) => Binary (CRF a b) where
    put (CRF cdc mdl) = do
        put cdc
        put mdl
    get = CRF <$> get <*> get

-- | Train the CRF using the stochastic gradient descent method.
-- The resulting model will contain features extracted with
-- the user supplied extraction function.
-- You can use the functions provided by the "Data.CRF.Chain1.Feature.Present"
-- and "Data.CRF.Chain1.Feature.Hidden" modules for this purpose.
--
-- When the evaluation dataset is non-empty, the iterative training
-- process will notify the user about the current accuracy on the
-- evaluation part every full iteration over the training part.
-- With an empty evaluation dataset only the iteration number is
-- printed.
--
-- TODO: Accept custom r0 construction function.
train
    :: (Ord a, Ord b)
    => SGD.SgdArgs                          -- ^ Args for SGD
    -> Bool                                 -- ^ Store dataset on a disk
    -> IO [SentL a b]                       -- ^ Training data 'IO' action
    -> IO [SentL a b]                       -- ^ Evaluation data
    -> (AVec Lb -> [(Xs, Ys)] -> [Feature]) -- ^ Feature selection
    -> IO (CRF a b)                         -- ^ Resulting model
train sgdArgs onDisk trainIO evalIO extractFeats = do
    hSetBuffering stdout NoBuffering
    -- The training action is run once to build the codec and a second
    -- time to encode the sentences with it.
    cdc <- fmap mkCodec trainIO
    encodedTrain <- fmap (encodeDataL cdc) trainIO
    SGD.withData onDisk encodedTrain $ \trainData -> do
        -- Encode the evaluation sentences with the same codec.
        encodedEval <- fmap (encodeDataL cdc) evalIO
        SGD.withData onDisk encodedEval $ \evalData -> do
            -- Default label set: labels seen on unknown words
            -- (third run of the training action).
            defLabels <- encodeLabels cdc . S.toList . unkSet <$> trainIO
            -- Features selected from the encoded training data.
            feats <- extractFeats defLabels <$> SGD.loadData trainData
            -- Initial model and the per-step progress callback.
            let initial =
                    (mkModel (obMax cdc) (lbMax cdc) feats)
                        { r0 = defLabels }
                report = notify sgdArgs initial trainData evalData
            para <- SGD.sgd sgdArgs
                report
                (gradOn initial)
                trainData
                (values initial)
            return (CRF cdc (initial { values = para }))

-- | Collect labels assigned to unknown words (with empty list
-- of potential interpretations).
unkSet :: Ord b => [SentL a b] -> S.Set b
unkSet =
    S.fromList . concatMap (concatMap labelsOfUnknown)
  where
    -- Labels of a single word: the keys of its label map when the
    -- word is 'unknown', nothing otherwise.
    labelsOfUnknown wordL =
        if unknown (fst wordL)
            then M.keys (unProb (snd wordL))
            else []

-- | Gradient contribution of a single training element, computed
-- with the model whose parameter values are replaced by @para@.
-- Values returned by 'featuresIn' enter with a positive sign, values
-- returned by 'expectedFeaturesIn' with a negative sign.
gradOn :: Model -> SGD.Para -> (Xs, Ys) -> SGD.Grad
gradOn model para (xs, ys) =
    SGD.fromLogList (positive ++ negative)
  where
    current = model { values = para }
    positive =
        [ (featToJustInt current feat, L.fromPos val)
        | (feat, val) <- featuresIn xs ys
        ]
    negative =
        [ (ix, L.fromNeg val)
        | (FeatIx ix, val) <- expectedFeaturesIn current xs
        ]

-- | Progress callback for SGD.  Prints a dot for a step that does not
-- complete a full pass over the training data.  Otherwise prints the
-- number of completed passes together with the accuracy on the
-- evaluation dataset, or @#@ when the evaluation dataset is empty.
notify
    :: SGD.SgdArgs
    -> Model
    -> SGD.Dataset (Xs, Ys)         -- ^ Training dataset
    -> SGD.Dataset (Xs, Ys)         -- ^ Evaluation dataset
    -> SGD.Para
    -> Int
    -> IO ()
notify sgdArgs model trainData evalData para k
    | passesNow == passesAt (k - 1) = putStr "."
    | SGD.size evalData > 0 = do
        evalSents <- SGD.loadData evalData
        let acc = accuracy (model { values = para }) evalSents
        reportLine (show acc)
    | otherwise = reportLine "#"
  where
    passesNow = passesAt k

    -- Shared layout of the per-pass report line.
    reportLine score =
        putStrLn ("\n" ++ "[" ++ show passesNow ++ "] f = " ++ score)

    -- Number of whole passes over the training data after @i@ steps.
    passesAt :: Int -> Int
    passesAt i = floor (progress i)

    -- Fraction of passes done after @i@ steps of 'SGD.batchSize' elements.
    progress :: Int -> Double
    progress i =
        fromIntegral (i * SGD.batchSize sgdArgs) / fromIntegral trainSize

    trainSize = SGD.size trainData