{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RecordWildCards #-}
-- | Training of constrained chain-1 CRF models on DAG-encoded data.
--
-- Besides the 'CRF' type and the 'train' entry point, the module exports
-- helpers that collect label sets from external sentences ('oovChosen',
-- 'anyChosen', 'anyInterps') and 'dagProb', which 'train' relies on
-- (indirectly) to discard ill-formed training elements.
module Data.CRF.Chain1.Constrained.DAG.Train
(
-- * Model
CRF (..)
-- * Training
, train
-- * Label-set construction
, oovChosen
, anyChosen
, anyInterps
-- * Dataset verification
, dagProb
) where
import Control.Applicative ((<$>), (<*>))
import qualified Control.Arrow as Arr
import Control.Monad (when)
import System.IO (hSetBuffering, stdout, BufferMode (..))
import Data.Binary (Binary, put, get)
import qualified Data.Set as S
import qualified Data.Map as M
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.Foldable as F
import qualified Numeric.SGD.Momentum as SGD
import qualified Data.Number.LogFloat as LogFloat
import qualified Numeric.SGD.LogSigned as L
import qualified Data.MemoCombinators as Memo
import Data.DAG (DAG)
import qualified Data.DAG as DAG
import Data.CRF.Chain1.Constrained.Core (X, Y, Lb, AVec, Feature)
import qualified Data.CRF.Chain1.Constrained.Core as C
import qualified Data.CRF.Chain1.Constrained.Model as Md
import qualified Data.CRF.Chain1.Constrained.Dataset.Internal as Int
import qualified Data.CRF.Chain1.Constrained.DAG.Dataset.Codec as Cd
import qualified Data.CRF.Chain1.Constrained.DAG.Dataset.External as E
import Data.CRF.Chain1.Constrained.DAG.Feature (featuresIn)
import qualified Data.CRF.Chain1.Constrained.DAG.Inference as I
import qualified Data.CRF.Chain1.Constrained.DAG.Probs as P
-- | A CRF model paired with the codec used to encode the data it operates on.
-- Values of this type are produced by 'train'.
data CRF a b = CRF
    { codec :: Cd.Codec a b
    -- ^ Codec built from the training data.
    , model :: Md.Model
    -- ^ Internal model; after 'train' its @values@ field holds the
    -- parameters returned by SGD.
    }
-- | Serialisation writes the codec first and the model second;
-- deserialisation reads them back in the same order.
instance (Ord a, Ord b, Binary a, Binary b) => Binary (CRF a b) where
    put crf = do
        put (codec crf)
        put (model crf)
    get = do
        cd <- get
        md <- get
        return (CRF cd md)
-- | Train a 'CRF' with stochastic gradient descent.
--
-- The training-data action is run three times (to build the codec, to
-- encode the dataset, and to compute the label set passed to the feature
-- selector); the evaluation-data action is run once.  Encoded training
-- elements rejected by 'verifyDataset' are dropped and, when any were
-- dropped, their number is printed to stdout.  As a side effect, stdout is
-- switched to 'NoBuffering'.
train
    :: (Ord a, Ord b)
    => SGD.SgdArgs
    -- ^ SGD settings, forwarded to 'SGD.sgd' and to the progress reporter
    -> Bool
    -- ^ Forwarded as the first argument of 'SGD.withData'
    -> ([E.SentL a b] -> S.Set b)
    -- ^ Computes, from the training data, the label set stored (after
    -- encoding) in the model's @r0@ field
    -> (AVec Lb -> [DAG () (X, Y)] -> [Feature])
    -- ^ Feature selection, given the encoded @r0@ and the training dataset
    -> IO [E.SentL a b]
    -- ^ Training data
    -> IO [E.SentL a b]
    -- ^ Evaluation data
    -> IO (CRF a b)
    -- ^ The codec together with the model carrying the trained parameters
train sgdArgs onDisk mkR0 featSel trainIO evalIO = do
    hSetBuffering stdout NoBuffering
    codec <- fmap Cd.mkCodec trainIO
    encoded <- fmap (Cd.encodeDataL codec) trainIO
    let total = length encoded
        verified = verifyDataset encoded
        kept = length verified
    -- Report how many training elements failed verification, if any.
    when (kept < total) . putStrLn $ concat
        [ "Discarded "
        , show (total - kept), "/", show total
        , " elements from the training dataset" ]
    SGD.withData onDisk verified $ \trainData -> do
        evalEncoded <- fmap (Cd.encodeDataL codec) evalIO
        SGD.withData onDisk evalEncoded $ \evalData -> do
            labelSet <- fmap mkR0 trainIO
            let r0 = Cd.encodeLabels codec (S.toList labelSet)
            selected <- fmap (featSel r0) (SGD.loadData trainData)
            let crfModel =
                    (Md.mkModel (Cd.obMax codec) (Cd.lbMax codec) selected)
                        { Md.r0 = r0 }
            trained <- SGD.sgd sgdArgs
                (notify sgdArgs crfModel trainData evalData)
                (gradOn crfModel) trainData (Md.values crfModel)
            return $ CRF codec (crfModel { Md.values = trained })
-- | Gradient contribution of a single DAG under the given parameters.
--
-- Pairs returned by 'featuresIn' for the labelled DAG enter with a positive
-- sign ('L.fromPos'); pairs returned by 'I.expectedFeaturesIn' for the DAG
-- stripped of its @Y@ components enter with a negative sign ('L.fromNeg').
gradOn :: Md.Model -> SGD.Para -> DAG a (X, Y) -> SGD.Grad
gradOn model para dag =
    SGD.fromLogList (observed ++ expected)
  where
    -- The model with the supplied parameter vector plugged in.
    current = model { Md.values = para }
    observed =
        [ (Md.featToJustInt current ft, L.fromPos v)
        | (ft, v) <- featuresIn dag ]
    expected =
        [ (i, L.fromNeg v)
        | (Md.FeatIx i, v) <- I.expectedFeaturesIn current (fmap fst dag) ]
-- | Progress callback handed to 'SGD.sgd' by 'train'.
--
-- Let @doneTotal i = floor (i * batchSize / size trainData)@.  When
-- @doneTotal k@ equals @doneTotal (k - 1)@ a single dot is printed;
-- otherwise a statistics report is printed: the minimum and maximum
-- parameter values, the log-likelihood of the training data and the accuracy
-- on the evaluation data.
--
-- A @#@ placeholder is printed instead of the accuracy when the evaluation
-- dataset is empty, and instead of the minimum/maximum when the parameter
-- vector is empty ('U.minimum' and 'U.maximum' are partial and would
-- otherwise abort training with an exception).
notify
    :: SGD.SgdArgs -> Md.Model
    -> SGD.Dataset (DAG a (X, Y))
    -> SGD.Dataset (DAG a (X, Y))
    -> SGD.Para -> Int -> IO ()
notify SGD.SgdArgs{..} model trainData evalData para k
    | doneTotal k == doneTotal (k - 1) = putStr "."
    | otherwise = putStrLn "" >> report
  where
    report = do
        let crf = model { Md.values = para }
        llh <- show
             . LogFloat.logFromLogFloat
             . P.parLikelihood crf
           <$> SGD.loadData trainData
        acc <-
            if SGD.size evalData > 0
                then show . I.accuracy crf <$> SGD.loadData evalData
                else return "#"
        putStrLn $ "[" ++ show (doneTotal k) ++ "] stats:"
        putStrLn $ "min(params) = " ++ paramStat U.minimum
        putStrLn $ "max(params) = " ++ paramStat U.maximum
        putStrLn $ "log(likelihood(train)) = " ++ llh
        putStrLn $ "acc(eval) = " ++ acc

    -- Render a statistic of the parameter vector, guarding the partial
    -- vector folds against an empty vector.
    paramStat stat
        | U.null para = "#"
        | otherwise   = show (stat para)

    -- Whole number of @trainSize@-sized chunks covered after @i@ steps.
    doneTotal :: Int -> Int
    doneTotal = floor . done

    done :: Int -> Double
    done i
        = fromIntegral (i * batchSize)
        / fromIntegral trainSize

    trainSize = SGD.size trainData
-- | Total mass of the DAG.
--
-- The mass of an edge is the sum of the values attached to its @Y@ label.
-- The result is the sum, over all paths that begin with an initial edge and
-- end in a node without outgoing edges, of the product of the edge masses
-- along the path.
dagProb :: DAG a (X, Y) -> Double
dagProb dag =
    sum . map edgeTotal . filter isInitial $ DAG.dagEdges dag
  where
    isInitial e = DAG.isInitialEdge e dag

    -- Memoised on the underlying edge identifier.  This binding deliberately
    -- takes no arguments so that a single memo table is shared by every
    -- lookup made during one call of 'dagProb'.
    edgeTotal =
        Memo.wrap DAG.EdgeID DAG.unEdgeID Memo.integral edgeTotal'

    -- Mass of all paths starting with the given edge.
    edgeTotal' e = labelMass e * nodeTotal (DAG.endsWith e dag)

    -- Sum of the values stored in the edge's @Y@ label.
    labelMass e = sum [ p | (_, p) <- C.unY (snd (DAG.edgeLabel e dag)) ]

    -- Mass of all paths leaving the given node; 1 for a node with no
    -- outgoing edges.
    nodeTotal v
        | null outs = 1
        | otherwise = sum (map edgeTotal outs)
      where
        outs = DAG.outgoingEdges v dag
-- | Keep only those DAGs whose 'dagProb' lies within @1e-9@ of 1.
verifyDataset :: [DAG a (X, Y)] -> [DAG a (X, Y)]
verifyDataset dags = [ dag | dag <- dags, nearOne (dagProb dag) ]
  where
    tolerance = 1e-9
    nearOne p = 1 - tolerance <= p && p <= 1 + tolerance
-- | The set of labels chosen (keys of the 'E.choice' distribution) for
-- words that 'E.unknown' flags as unknown.
oovChosen :: Ord b => [E.SentL a b] -> S.Set b
oovChosen = collect chosenIfUnknown
  where
    chosenIfUnknown w =
        if E.unknown (E.word w)
            then M.keys (E.unProb (E.choice w))
            else []
-- | The set of labels chosen (keys of the 'E.choice' distribution) for any
-- word in the dataset.
anyChosen :: Ord b => [E.SentL a b] -> S.Set b
anyChosen = collect (\w -> M.keys (E.unProb (E.choice w)))
-- | The union of the labels listed in the @lbs@ field of any word and the
-- labels returned by 'anyChosen'.
anyInterps :: Ord b => [E.SentL a b] -> S.Set b
anyInterps sents = S.union interps (anyChosen sents)
  where
    interps = collect (S.toList . E.lbs . E.word) sents
-- | Apply the given function to every word of every sentence and gather all
-- resulting labels into a set.
collect :: Ord b => (E.WordL a b -> [b]) -> [E.SentL a b] -> S.Set b
collect onWord sents = S.fromList
    [ lb
    | sent <- sents
    , lb <- F.concatMap onWord sent ]