{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RecordWildCards #-}
module Data.CRF.Chain1.Constrained.DAG.Train
(
CRF (..)
, train
, oovChosen
, anyChosen
, anyInterps
, verifyDAG
, Error(..)
) where
import Control.Applicative ((<$>), (<*>))
import Control.Monad (when, guard)
import System.IO (hSetBuffering, stdout, BufferMode (..))
import Data.Binary (Binary, put, get)
import qualified Data.Set as S
import qualified Data.Map as M
import qualified Data.Vector.Unboxed as U
import qualified Data.Foldable as F
import qualified Numeric.SGD.Momentum as SGD
import qualified Data.Number.LogFloat as LogFloat
import qualified Numeric.SGD.LogSigned as L
import qualified Data.MemoCombinators as Memo
import Data.DAG (DAG)
import qualified Data.DAG as DAG
import Data.CRF.Chain1.Constrained.Core (X, Y, Lb, AVec, Feature)
import qualified Data.CRF.Chain1.Constrained.Core as C
import qualified Data.CRF.Chain1.Constrained.Model as Md
import qualified Data.CRF.Chain1.Constrained.DAG.Dataset.Codec as Cd
import qualified Data.CRF.Chain1.Constrained.DAG.Dataset.External as E
import Data.CRF.Chain1.Constrained.DAG.Feature (featuresIn)
import qualified Data.CRF.Chain1.Constrained.DAG.Inference as I
import qualified Data.CRF.Chain1.Constrained.DAG.Probs as P
-- | A trained CRF model packaged with the codec that was used to encode the
-- training data, so that external sentences can be encoded the same way at
-- prediction time.
data CRF a b = CRF
  { codec :: Cd.Codec a b
    -- ^ Codec between external values (type parameters @a@, @b@) and the
    -- internal encoded representation; built by 'Cd.mkCodec' in 'train'.
  , model :: Md.Model
    -- ^ The model, including its (trained) parameter vector.
  }
-- | Binary serialisation: the codec is written first and the model second;
-- 'get' reads them back in that same order.
instance (Ord a, Ord b, Binary a, Binary b) => Binary (CRF a b) where
  put (CRF cdc mdl) = do
    put cdc
    put mdl
  get = do
    cdc <- get
    mdl <- get
    return (CRF cdc mdl)
-- | Train a CRF model with stochastic gradient descent.
--
-- The training action is executed several times (to build the codec, to
-- encode the data and to compute the label set @r0@) rather than being
-- bound once; the evaluation action is executed once. Training DAGs that do
-- not pass 'verifyDAG' are dropped, and a message is printed when that
-- happens.
train
  :: (Ord a, Ord b)
  => SGD.SgdArgs                  -- ^ SGD parameters
  -> Bool                         -- ^ store datasets on disk?
  -> ([E.SentL a b] -> S.Set b)   -- ^ selects the default label set @r0@
  -> (AVec Lb -> [DAG () (X, Y)] -> [Feature])
                                  -- ^ feature selection
  -> IO [E.SentL a b]             -- ^ training data
  -> IO [E.SentL a b]             -- ^ evaluation data
  -> IO (CRF a b)
train sgdArgs onDisk mkR0 featSel trainIO evalIO = do
  hSetBuffering stdout NoBuffering
  cdc <- fmap Cd.mkCodec trainIO
  encoded <- fmap (Cd.encodeDataL cdc) trainIO
  let kept = verifyDataset encoded
  reportDiscarded (length encoded) (length kept)
  SGD.withData onDisk kept $ \trainData -> do
    evalEncoded <- fmap (Cd.encodeDataL cdc) evalIO
    SGD.withData onDisk evalEncoded $ \evalData -> do
      r0 <- fmap (Cd.encodeLabels cdc . S.toList . mkR0) trainIO
      feats <- fmap (featSel r0) (SGD.loadData trainData)
      let initial =
            (Md.mkModel (Cd.obMax cdc) (Cd.lbMax cdc) feats) { Md.r0 = r0 }
      trained <- SGD.sgd sgdArgs
        (notify sgdArgs initial trainData evalData)
        (gradOn initial) trainData (Md.values initial)
      return $ CRF cdc (initial { Md.values = trained })
  where
    -- Tell the user how many training elements were rejected, if any.
    reportDiscarded before after =
      when (after < before) $
        putStrLn $ "Discarded "
          ++ show (before - after) ++ "/" ++ show before
          ++ " elements from the training dataset"
-- | Gradient contribution of a single training DAG under the given
-- parameter vector: features present in the DAG contribute positively, and
-- the model's expected features contribute negatively.
gradOn :: Md.Model -> SGD.Para -> DAG a (X, Y) -> SGD.Grad
gradOn model para dag =
  SGD.fromLogList (observed ++ expected)
  where
    -- The model with its parameters replaced by the current SGD point.
    withPara = model { Md.values = para }
    observed =
      [ (Md.featToJustInt withPara feat, L.fromPos val)
      | (feat, val) <- featuresIn dag ]
    expected =
      [ (ix, L.fromNeg val)
      | (Md.FeatIx ix, val) <- I.expectedFeaturesIn withPara (fmap fst dag) ]
-- | SGD progress callback.
--
-- Prints a dot for each iteration that does not complete a new pass over the
-- training data. When the number of completed passes (as computed by
-- 'doneTotal') changes, it instead prints a newline and a statistics report:
--
--   * the min/max of the current parameters (@#@ if there are none);
--   * the training log-likelihood;
--   * the evaluation accuracy (@#@ if the evaluation set is empty).
notify
  :: SGD.SgdArgs -> Md.Model
  -> SGD.Dataset (DAG a (X, Y))
  -> SGD.Dataset (DAG a (X, Y))
  -> SGD.Para -> Int -> IO ()
notify SGD.SgdArgs{..} model trainData evalData para k
  | doneTotal k == doneTotal (k - 1) = putStr "."
  | otherwise = putStrLn "" >> report para
  where
    report paraNow = do
      let crf = model {Md.values = paraNow}
      llh <- show
        . LogFloat.logFromLogFloat
        . P.parLikelihood crf
        <$> SGD.loadData trainData
      acc <-
        if SGD.size evalData > 0
          then show . I.accuracy crf <$> SGD.loadData evalData
          else return "#"
      putStrLn $ "[" ++ show (doneTotal k) ++ "] stats:"
      -- Use the parameters passed to 'report' for every statistic.
      putStrLn $ "min(params) = " ++ vecStat U.minimum paraNow
      putStrLn $ "max(params) = " ++ vecStat U.maximum paraNow
      putStrLn $ "log(likelihood(train)) = " ++ llh
      putStrLn $ "acc(eval) = " ++ acc
    -- 'U.minimum' and 'U.maximum' fail on an empty vector (e.g. a model with
    -- no features); print "#" in that case instead of aborting training.
    vecStat f v
      | U.null v = "#"
      | otherwise = show (f v)
    -- Number of completed passes over the training data after @i@ batches.
    doneTotal :: Int -> Int
    doneTotal = floor . done
    -- Fractional number of passes over the training data after @i@ batches.
    done :: Int -> Double
    done i
      = fromIntegral (i * batchSize)
      / fromIntegral trainSize
    trainSize = SGD.size trainData
-- | Keep only those DAGs for which 'verifyDAG' reports no error; all other
-- DAGs are dropped (order of the kept ones is preserved).
verifyDataset :: [DAG a (X, Y)] -> [DAG a (X, Y)]
verifyDataset dags = [ dag | dag <- dags, isValid dag ]
  where
    isValid dag = case verifyDAG dag of
      Nothing -> True
      Just _ -> False
-- | Reasons for which 'verifyDAG' rejects a DAG.
data Error
  = Malformed
    -- ^ 'DAG.isOK' rejected the DAG.
  | Cyclic
    -- ^ 'DAG.isDAG' rejected the graph (presumably it contains a cycle).
  | SeveralSources [DAG.NodeID]
    -- ^ There is not exactly one node without ingoing edges; the list holds
    -- every such node found (it may be empty).
  | SeveralTargets [DAG.NodeID]
    -- ^ There is not exactly one node without outgoing edges; the list holds
    -- every such node found (it may be empty).
  | WrongBalance [DAG.NodeID]
    -- ^ Nodes whose ingoing and outgoing edge weights do not match.
  deriving (Eq, Ord, Show)
-- | Check that a DAG is usable as a training element.
--
-- Returns 'Nothing' when the DAG passes every check, and otherwise the first
-- failing check, tried in this order:
--
--   1. 'DAG.isOK' must hold ('Malformed');
--   2. 'DAG.isDAG' must hold ('Cyclic');
--   3. exactly one node has no ingoing edges ('SeveralSources');
--   4. exactly one node has no outgoing edges ('SeveralTargets');
--   5. every node is balanced ('WrongBalance').
--
-- The weight of an edge is the sum of the second components of 'C.unY' of
-- its label (presumably label probabilities). A node is balanced when the
-- total weight of its ingoing edges equals, within 1e-9, the total weight of
-- its outgoing edges. The source node counts as receiving weight 1 and the
-- target node counts as emitting weight 1.
verifyDAG :: DAG a (X, Y) -> Maybe Error
verifyDAG dag
  | not (DAG.isOK dag) = Just Malformed
  | not (DAG.isDAG dag) = Just Cyclic
  | length sources /= 1 = Just $ SeveralSources sources
  | length targets /= 1 = Just $ SeveralTargets targets
  -- Any unbalanced node, even a single one, invalidates the DAG.
  | not (null wrong) = Just $ WrongBalance wrong
  | otherwise = Nothing
  where
    -- Nodes without ingoing edges.
    sources = do
      node <- DAG.dagNodes dag
      guard . null $ DAG.ingoingEdges node dag
      return node
    -- Nodes without outgoing edges.
    targets = do
      node <- DAG.dagNodes dag
      guard . null $ DAG.outgoingEdges node dag
      return node
    -- Nodes whose ingoing and outgoing weights differ.
    wrong = do
      node <- DAG.dagNodes dag
      let ing = DAG.ingoingEdges node dag
          out = DAG.outgoingEdges node dag
          ingBalance =
            if node `elem` sources
              then 1
              else sum (map edgeProb ing)
          outBalance =
            if node `elem` targets
              then 1
              else sum (map edgeProb out)
      guard . not $ equal ingBalance outBalance
      return node
    -- Total weight carried by an edge.
    edgeProb edgeID =
      let (_x, y) = DAG.edgeLabel edgeID dag
      in sum . map snd $ C.unY y
    -- Approximate equality, tolerating floating-point rounding.
    equal x y =
      x - eps <= y && x + eps >= y
    eps = 1e-9
-- | The set of labels chosen (keys of the 'E.choice' distribution) for words
-- marked as 'E.unknown' (out-of-vocabulary); other words contribute nothing.
oovChosen :: Ord b => [E.SentL a b] -> S.Set b
oovChosen = collect chosenIfUnknown
  where
    chosenIfUnknown w =
      if E.unknown (E.word w)
        then M.keys (E.unProb (E.choice w))
        else []
-- | The set of labels chosen (keys of the 'E.choice' distribution) for any
-- word in the given sentences.
anyChosen :: Ord b => [E.SentL a b] -> S.Set b
anyChosen = collect chosenLabels
  where
    chosenLabels w = M.keys (E.unProb (E.choice w))
-- | The set of all labels occurring in the sentences: every potential
-- interpretation ('E.lbs' of each word) together with every chosen label
-- (see 'anyChosen').
anyInterps :: Ord b => [E.SentL a b] -> S.Set b
anyInterps sents =
  S.union
    (collect (S.toList . E.lbs . E.word) sents)
    (anyChosen sents)
-- | Apply the per-word extractor to every word of every sentence and gather
-- all the results into a single set.
collect :: Ord b => (E.WordL a b -> [b]) -> [E.SentL a b] -> S.Set b
collect onWord sents =
  S.fromList [ y | sent <- sents, w <- F.toList sent, y <- onWord w ]