{-# LANGUAGE FlexibleContexts #-}

-- Inference with CRFs.

module Data.CRF.Chain1.Inference
( tag
, accuracy
, expectedFeaturesIn
, zx
, zx'
) where

import Data.List (maximumBy)
import Data.Function (on)
import qualified Data.Array as A
import qualified Data.Vector as V
import Control.Parallel.Strategies (rseq, parMap)
import Control.Parallel (par, pseq)
import GHC.Conc (numCapabilities)
import qualified Data.Number.LogFloat as L

import qualified Data.CRF.Chain1.DP as DP
import Data.CRF.Chain1.Util (partition)
import Data.CRF.Chain1.Dataset.Internal
import Data.CRF.Chain1.Model

-- | A dynamic-programming table indexed by sentence position and label.
type ProbArray = Int -> Lb -> L.LogFloat

-- | The accumulation function used to combine the scores of alternative
-- paths (e.g. 'sum' for the normalization factor).
type AccF = [L.LogFloat] -> L.LogFloat

-- | Compute the table of potential products associated with
-- observation features for the given sentence position.
--
-- The function is deliberately defined with three arguments only: a
-- partial application @computePsi crf xs i@ builds the underlying array
-- once and the resulting closure shares it between all label lookups.
computePsi :: Model -> Xs -> Int -> Lb -> L.LogFloat
computePsi crf xs i = (table A.!)
  where
    -- One cell per label; every observation feature firing at position
    -- @i@ multiplies its potential into the cell of its label.
    table = A.accumArray (*) 1 bounds
        [ (x, valueL crf ix)
        | ob <- unX (xs V.! i)
        , (x, ix) <- obIxs crf ob ]
    bounds = (Lb 0, Lb $ lbNum crf - 1)

-- | Forward table computation.
--
-- The table has one column per sentence position plus an additional,
-- single-cell column at position @V.length sent@ which accumulates the
-- values of the last regular column (see 'zxAlpha').
--
-- For the empty sentence the only cell of the table is defined to be 1,
-- which agrees with the value computed by the backward pass.
forward :: AccF -> Model -> Xs -> ProbArray
forward acc crf sent = DP.flexible2 (0, V.length sent) wordBounds
    (\t k -> withMem (computePsi crf sent k) t k)
  where
    wordBounds k
        | k == V.length sent = (Lb 0, Lb 0)
        | otherwise = (Lb 0, Lb $ lbNum crf - 1)
    -- Forward table equation, where k is current position, x is a label
    -- on current position and psi is a psi table computed for current
    -- position.
    withMem psi alpha k x
        -- Empty sentence: positions 0 and @V.length sent@ coincide and
        -- there is no word to compute psi for.  Without this guard the
        -- @k == 0@ equation would index the empty vector.
        | V.null sent = 1
        | k == 0 = psi x * sgValue crf x
        -- psi is not forced here, so the out-of-range position
        -- @V.length sent@ is never looked up in the sentence.
        | k == V.length sent = acc
            [ alpha (k - 1) y
            | y <- lbSet crf ]
        | otherwise = acc
            [ alpha (k - 1) y * psi x * valueL crf ix
            | (y, ix) <- prevIxs crf x ]

-- | Backward table computation.
backward :: AccF -> Model -> Xs -> ProbArray
backward acc crf sent = DP.flexible2 (0, V.length sent) wordBounds
    (\t k -> withMem (computePsi crf sent k) t k)
  where
    -- Position 0 has a single cell only, since there is no previous
    -- label to condition on.
    wordBounds k
        | k == 0 = (Lb 0, Lb 0)
        | otherwise = (Lb 0, Lb $ lbNum crf - 1)
    -- Backward table equation, where k is current position, y is a label
    -- on previous, k-1, position and psi is a psi table computed for
    -- current position.
    withMem psi beta k y
        -- Checked first, so psi is never forced for the position past
        -- the end of the sentence (this also covers the empty sentence).
        | k == V.length sent = 1
        | k == 0 = acc
            [ beta (k + 1) x * psi x * valueL crf ix
            | (x, ix) <- sgIxs crf ]
        | otherwise = acc
            [ beta (k + 1) x * psi x * valueL crf ix
            | (x, ix) <- nextIxs crf y ]

-- | Read the accumulated value off the single cell at position 0 of
-- a backward table.
zxBeta :: ProbArray -> L.LogFloat
zxBeta beta = beta 0 0

-- | Read the accumulated value off the single cell at position
-- @V.length sent@ of a forward table.
zxAlpha :: Xs -> ProbArray -> L.LogFloat
zxAlpha sent alpha = alpha (V.length sent) 0

-- | Normalization factor computed for the 'Xs' sentence using the
-- backward computation.
zx :: Model -> Xs -> L.LogFloat
zx crf = zxBeta . backward sum crf

-- | Normalization factor computed for the 'Xs' sentence using the
-- forward computation.
zx' :: Model -> Xs -> L.LogFloat
zx' crf sent = zxAlpha sent (forward sum crf sent)

--------------------------------------------------------------

-- | Pair the element which maximizes the given function with the
-- maximal value.  When several elements share the maximal value, the
-- last one of them is chosen.  Calls 'error' on the empty list.
argmax :: Ord b => (a -> b) -> [a] -> (a, b)
argmax _ [] = error "argmax: null list"
argmax f xs =
    foldl1 choice $ map (\x -> (x, f x)) xs
  where
    choice (x1, v1) (x2, v2)
        | v1 > v2 = (x1, v1)
        | otherwise = (x2, v2)

-- | Determine the most probable label sequence given the context of the
-- CRF model and the sentence.
tag :: Model -> Xs -> [Lb]
tag crf sent = collectMaxArg (0, 0) [] $ DP.flexible2
    (0, V.length sent) wordBounds
    (\t k -> withMem (computePsi crf sent k) t k)
  where
    wordBounds k
        | k == 0 = (Lb 0, Lb 0)
        | otherwise = (Lb 0, Lb $ lbNum crf - 1)

    -- Each cell stores the best label for position k (given label y on
    -- position k-1) together with the score of the best continuation.
    -- The label -1 marks the end of the sentence.
    withMem psi mem k y
        | k == V.length sent = (-1, 1)
        | k == 0 = prune . argmax eval $ sgIxs crf
        | otherwise = prune . argmax eval $ nextIxs crf y
      where
        eval (x, ix) = (snd $ mem (k + 1) x) * psi x * valueL crf ix
        prune ((x, _ix), v) = (x, v)

    -- Follow the stored best labels from the initial cell until the
    -- end-of-sentence marker is reached.
    collectMaxArg (i, j) acc mem =
        collect (mem i j)
      where
        collect (h, _)
            | h == -1 = reverse acc
            | otherwise = collectMaxArg (i + 1, h) (h:acc) mem

-- tagProbs :: Sent s => Model -> s -> [[Double]]
-- tagProbs crf sent =
--     let alpha = forward maximum crf sent
--         beta = backward maximum crf sent
--         normalize vs =
--             let d = - logSum vs
--             in map (+d) vs
--         m1 k x = alpha k x + beta (k + 1) x
--     in  [ map exp $ normalize [m1 i k | k <- interpIxs sent i]
--         | i <- [0 .. V.length sent - 1] ]
--
-- -- tag probabilities with respect to
-- -- marginal distributions
-- tagProbs' :: Sent s => Model -> s -> [[Double]]
-- tagProbs' crf sent =
--     let alpha = forward logSum crf sent
--         beta = backward logSum crf sent
--     in  [ [ exp $ prob1 crf alpha beta sent i k
--           | k <- interpIxs sent i ]
--         | i <- [0 .. V.length sent - 1] ]

-- Strictly sum a list of (good, bad) counters.  Both components are
-- forced at every step, so that no chain of suspended additions is
-- built and the counters are fully evaluated wherever the resulting
-- pair is evaluated (in particular, inside the parallel sparks created
-- in 'accuracy').
sumCounts :: [(Int, Int)] -> (Int, Int)
sumCounts = go 0 0
  where
    go good bad [] = (good, bad)
    go good bad ((g, b) : rest) =
        let good' = good + g
            bad' = bad + b
        in  good' `seq` bad' `seq` go good' bad' rest

-- | Count the positions of the sentence on which the label chosen by
-- 'tag' agrees (first component) and disagrees (second component) with
-- the reference label.  The reference label of a position is the first
-- component of the pair with the maximal second component.
goodAndBad :: Model -> Xs -> Ys -> (Int, Int)
goodAndBad crf sent labels = sumCounts
    [ if x == y then (1, 0) else (0, 1)
    | (x, y) <- zip gold predicted ]
  where
    gold =
        [ fst . maximumBy (compare `on` snd) $ unY ys
        | ys <- V.toList labels ]
    predicted = tag crf sent

-- | Sum the results of 'goodAndBad' over all sentences of the dataset.
goodAndBad' :: Model -> [(Xs, Ys)] -> (Int, Int)
goodAndBad' crf dataset =
    sumCounts [goodAndBad crf x y | (x, y) <- dataset]

-- | Compute the accuracy of the model with respect to the labeled dataset.
--
-- The dataset is split into 'numCapabilities' parts which are processed
-- in parallel.  When there is no position to score, the result is the
-- outcome of the floating-point division @0 / 0@, i.e. NaN.
accuracy :: Model -> [(Xs, Ys)] -> Double
accuracy crf dataset =
    fromIntegral good / fromIntegral (good + bad)
  where
    parts = partition numCapabilities dataset
    (good, bad) = sumCounts (parMap rseq (goodAndBad' crf) parts)

--------------------------------------------------------------

-- prob :: L.Vect t Int => Model -> Sent Int t -> Double
-- prob crf sent =
--     sum [ phiOn crf sent k
--         | k <- [0 .. (length sent) - 1] ]
--     - zx' crf sent
--
-- -- TODO: Take the "Regularization Variance" into account!
-- cll :: Model -> [Sentence] -> Double
-- cll crf dataset = sum [prob crf sent | sent <- dataset]

-- prob2 :: SentR s => Model -> ProbArray -> ProbArray -> s
--       -> Int -> Lb -> Lb -> Double
-- prob2 crf alpha beta sent k x y
--     = alpha (k - 1) y + beta (k + 1) x
--     + phi crf (observationsOn sent k) a b
--     - zxBeta beta
--   where
--     a = interp sent k x
--     b = interp sent (k - 1) y

-- | Probability of label @y@ on position @k - 1@ followed by label @x@
-- on position @k@, where @psi@ is the psi table of position @k@ and
-- @ix@ is the index of the corresponding transition feature.
prob2
    :: Model -> ProbArray -> ProbArray -> Int -> (Lb -> L.LogFloat)
    -> Lb -> Lb -> FeatIx -> L.LogFloat
prob2 crf alpha beta k psi x y ix
    = alpha (k - 1) y * beta (k + 1) x
    * psi x * valueL crf ix / zxBeta beta

-- prob1 :: SentR s => Model -> ProbArray -> ProbArray
--       -> s -> Int -> Label -> Double
-- prob1 crf alpha beta sent k x = logSum
--     [ prob2 crf alpha beta sent k x y
--     | y <- interpIxs sent (k - 1) ]

-- | Probability of label @x@ on position @k@.
prob1 :: ProbArray -> ProbArray -> Int -> Lb -> L.LogFloat
prob1 alpha beta k x =
    alpha k x * beta (k + 1) x / zxBeta beta

-- | Features defined on the given sentence position, each paired with
-- its probability: transition (or, on position 0, initial) features
-- first, observation features next.
expectedFeaturesOn
    :: Model -> ProbArray -> ProbArray -> Xs
    -> Int -> [(FeatIx, L.LogFloat)]
expectedFeaturesOn crf alpha beta sent k =
    tFeats ++ oFeats
  where
    psi = computePsi crf sent k
    pr1 = prob1 alpha beta k
    pr2 = prob2 crf alpha beta k psi

    oFeats = [ (ix, pr1 x)
             | o <- unX (sent V.! k)
             , (x, ix) <- obIxs crf o ]

    tFeats
        | k == 0 = [ (ix, pr1 x)
                   | (x, ix) <- sgIxs crf ]
        | otherwise = [ (ix, pr2 x y ix)
                      | x <- lbSet crf
                      , (y, ix) <- prevIxs crf x ]

-- | A list of features (represented by feature indices) defined within
-- the context of the sentence accompanied by expected probabilities
-- determined on the basis of the model.
--
-- One feature can occur multiple times in the output list.
expectedFeaturesIn :: Model -> Xs -> [(FeatIx, L.LogFloat)]
expectedFeaturesIn crf sent =
    -- Spark the forward normalization factor, evaluate the backward
    -- one in the meantime and wait for both before the tables are
    -- consumed.
    zxF `par` zxB `pseq` zxF `pseq` concat
        [ expectedOn k
        | k <- [0 .. V.length sent - 1] ]
  where
    expectedOn = expectedFeaturesOn crf alpha beta sent
    alpha = forward sum crf sent
    beta = backward sum crf sent
    zxF = zxAlpha sent alpha
    zxB = zxBeta beta