{-# LANGUAGE FlexibleContexts #-}
module Data.CRF.Chain1.Inference
( tag
, marginals
, accuracy
, expectedFeaturesIn
, zx
, zx'
) where
import Control.Applicative ((<$>), (<*>), pure)
import Data.Maybe (catMaybes)
import Data.List (maximumBy)
import Data.Function (on)
import qualified Data.Array as A
import qualified Data.Vector as V
import Control.Parallel.Strategies (rseq, parMap)
import Control.Parallel (par, pseq)
import GHC.Conc (numCapabilities)
import qualified Data.Number.LogFloat as L
import qualified Data.CRF.Chain1.DP as DP
import Data.CRF.Chain1.Util (partition)
import Data.CRF.Chain1.Dataset.Internal
import Data.CRF.Chain1.Model
-- | A dynamic-programming table viewed as a function: given a row index
-- (a sentence position, see 'forward' and 'backward' for the exact meaning
-- in each table) and a label, it yields a 'L.LogFloat' value.
type ProbArray = Int -> Lb -> L.LogFloat
-- | Accumulator that folds the values of alternative paths into one value;
-- the callers visible in this module all pass 'sum'.
type AccF = [L.LogFloat] -> L.LogFloat
-- | Observation potential at position @i@ of the sentence.
--
-- For every label, the result is the product of 'valueL' over all
-- observation features (reached through 'obIxs') that fire at position @i@
-- for that label; labels with no firing feature get the neutral value 1.
-- The array is built once per call and shared by every label lookup.
computePsi :: Model -> Xs -> Int -> Lb -> L.LogFloat
computePsi crf xs i = (table A.!)
  where
    labelRange = (Lb 0, Lb $ lbNum crf - 1)
    table = A.accumArray (*) 1 labelRange contributions
    contributions =
        [ (x, valueL crf ix)
        | ob <- unX (xs V.! i)
        , (x, ix) <- obIxs crf ob ]
-- | Forward (alpha) table of a sentence of length @n@.
--
-- Rows @0 .. n - 1@ range over all labels: entry @(k, x)@ accumulates (with
-- @acc@) the weight of every label prefix ending in label @x@ at position
-- @k@, including the observation potential of position @k@.  Row @n@ has a
-- single entry (label 0) that accumulates the last row, so with
-- @acc = sum@ it holds the normalisation constant (see 'zxAlpha').
--
-- For an empty sentence the table has only row 0, and its single entry is
-- the empty product 1.  This matches 'backward'.  Without this case the
-- position-0 branch would index the empty sentence and crash.
forward :: AccF -> Model -> Xs -> ProbArray
forward acc crf sent = DP.flexible2
    (0, n) wordBounds
    -- 'computePsi' is applied once per row, so its array is shared by
    -- every label of that row.
    (\t k -> withMem (computePsi crf sent k) t k)
  where
    n = V.length sent
    wordBounds k
        | k == n    = (Lb 0, Lb 0)
        | otherwise = (Lb 0, Lb $ lbNum crf - 1)
    withMem psi alpha k x
        -- Empty sentence: only row 0 exists and nothing may be indexed.
        | n == 0    = 1
        -- First position: start weight times observation potential.
        | k == 0    = psi x * sgValue crf x
        -- Final row: fold over every label of the last position.
        | k == n    = acc
            [ alpha (k - 1) y
            | y <- lbSet crf ]
        | otherwise = acc
            [ alpha (k - 1) y * psi x * valueL crf ix
            | (y, ix) <- prevIxs crf x ]
-- | Backward (beta) table of a sentence of length @n@.
--
-- For @1 <= k < n@, entry @(k, y)@ accumulates (with @acc@) the weight of
-- every label suffix covering positions @k .. n - 1@, given label @y@ at
-- position @k - 1@.  Row @n@ is the neutral value 1.  Row 0 has a single
-- entry (label 0) whose successors are the start features, so with
-- @acc = sum@ it holds the normalisation constant (see 'zxBeta').
backward :: AccF -> Model -> Xs -> ProbArray
backward acc crf sent =
    DP.flexible2 (0, n) rowBounds
        (\tbl k -> cell (computePsi crf sent k) tbl k)
  where
    n = V.length sent
    allLabels = (Lb 0, Lb $ lbNum crf - 1)
    rowBounds k = if k == 0 then (Lb 0, Lb 0) else allLabels
    cell psi beta k y
        | k == n    = 1
        | otherwise = acc
            [ beta (k + 1) x * psi x * valueL crf ix
            | (x, ix) <- successors ]
      where
        -- Row 0 starts the sentence; later rows follow transitions from y.
        successors = if k == 0 then sgIxs crf else nextIxs crf y
-- | Normalisation constant read from a backward table: its row-0 entry.
zxBeta :: ProbArray -> L.LogFloat
zxBeta beta = beta firstRow onlyLabel
  where
    firstRow = 0
    onlyLabel = 0
-- | Normalisation constant read from a forward table built for @sent@:
-- the single entry of its final row (index @V.length sent@).
zxAlpha :: Xs -> ProbArray -> L.LogFloat
zxAlpha sent alpha = alpha lastRow 0
  where
    lastRow = V.length sent
-- | Normalisation constant of a sentence, computed with the backward pass.
zx :: Model -> Xs -> L.LogFloat
zx crf sent = zxBeta (backward sum crf sent)
-- | Normalisation constant of a sentence, computed with the forward pass.
zx' :: Model -> Xs -> L.LogFloat
zx' crf sent = zxAlpha sent alpha
  where
    alpha = forward sum crf sent
-- | Element with the largest score, paired with that score.
--
-- Elements whose score is 'Nothing' are ignored.  The result is 'Nothing'
-- when no element has a score.  On ties the later element wins.
argmax :: Ord b => (a -> Maybe b) -> [a] -> Maybe (a, b)
argmax score candidates =
    case [ (c, v) | c <- candidates, Just v <- [score c] ] of
        []     -> Nothing
        scored -> Just (foldl1 keepBetter scored)
  where
    keepBetter best@(_, vBest) next@(_, vNext)
        | vBest > vNext = best
        | otherwise     = next
-- | Highest-scoring label sequence for a sentence.
--
-- The table is filled right-to-left.  Entry @(k, y)@ holds the best label
-- for position @k@, given label @y@ at position @k - 1@, together with the
-- score of the best completion from @k@ onwards.  Row 0 has a single dummy
-- entry and takes its candidates from the start features ('sgIxs').  Later
-- rows take them from 'nextIxs'.  The path is then read off left-to-right.
--
-- Calls 'error' if some step on the path has no scorable candidate.
tag :: Model -> Xs -> [Lb]
tag crf sent = walk 0 0 []
  where
    n = V.length sent

    table = DP.flexible2 (0, n) rowBounds
        (\tbl k -> best (computePsi crf sent k) tbl k)

    rowBounds k
        | k == 0    = (Lb 0, Lb 0)
        | otherwise = (Lb 0, Lb $ lbNum crf - 1)

    best psi tbl k y
        -- Sentinel past the end: no label to pick, neutral score 1.
        | k == n    = Just (-1, 1)
        | otherwise = dropIx <$> argmax score candidates
      where
        candidates = if k == 0 then sgIxs crf else nextIxs crf y
        score (x, ix) =
            (\v -> v * psi x * valueL crf ix) <$> (snd <$> tbl (k + 1) x)
        dropIx ((x, _ix), v) = (x, v)

    walk i prev path
        | i < n = case table i prev of
            Just (h, _) -> walk (i + 1) h (h : path)
            Nothing     -> error "tag.collect: Nothing"
        | otherwise = reverse path
-- | Per-position label marginals.
--
-- The outer list runs over sentence positions.  The inner list pairs every
-- label of 'lbSet' with 'prob1' computed from forward/backward tables built
-- with 'sum'.
marginals :: Model -> Xs -> [[(Lb, L.LogFloat)]]
marginals crf sent = map atPosition [0 .. V.length sent - 1]
  where
    alpha = forward sum crf sent
    beta = backward sum crf sent
    atPosition k = [ (x, prob1 alpha beta k x) | x <- lbSet crf ]
-- | Counts of correctly and incorrectly tagged positions in one sentence.
--
-- The gold label of a position is its highest-weighted entry in the 'Ys'
-- distribution.  Predictions come from 'tag'.  The fold pattern-matches
-- every step, so evaluating the result to WHNF forces all comparisons.
goodAndBad :: Model -> Xs -> Ys -> (Int, Int)
goodAndBad crf sent labels = foldl tally (0, 0) (zip gold predicted)
  where
    gold = map mostProbable (V.toList labels)
    mostProbable = fst . maximumBy (compare `on` snd) . unY
    predicted = tag crf sent
    tally (good, bad) (expected, actual)
        | expected == actual = (good + 1, bad)
        | otherwise          = (good, bad + 1)
-- | Summed good/bad position counts over a list of sentences.
goodAndBad' :: Model -> [(Xs, Ys)] -> (Int, Int)
goodAndBad' crf dataset = foldl merge (0, 0) perSentence
  where
    perSentence = map (uncurry (goodAndBad crf)) dataset
    merge (g1, b1) (g2, b2) = (g1 + g2, b1 + b2)
-- | Fraction of positions whose predicted label matches the gold label.
--
-- The dataset is split into 'numCapabilities' parts that are scored with
-- 'parMap'.  If there are no positions at all, the result is 0/0, i.e. NaN.
accuracy :: Model -> [(Xs, Ys)] -> Double
accuracy crf dataset = fromIntegral good / fromIntegral (good + bad)
  where
    chunks = partition numCapabilities dataset
    perChunk = parMap rseq (goodAndBad' crf) chunks
    (good, bad) = foldl plus (0, 0) perChunk
    plus (g1, b1) (g2, b2) = (g1 + g2, b1 + b2)
-- | Normalised weight of the transition from label @y@ (position @k - 1@)
-- to label @x@ (position @k@) through transition feature @ix@.
-- @psi@ is the observation potential of position @k@.  The denominator is
-- read from the backward table via 'zxBeta'.
prob2 :: Model -> ProbArray -> ProbArray -> Int -> (Lb -> L.LogFloat)
      -> Lb -> Lb -> FeatIx -> L.LogFloat
prob2 crf alpha beta k psi x y ix = numerator / zxBeta beta
  where
    numerator = alpha (k - 1) y * beta (k + 1) x * psi x * valueL crf ix
-- | Normalised weight of label @x@ at position @k@: the forward entry at
-- @k@ times the backward entry at @k + 1@, divided by 'zxBeta'.
prob1 :: ProbArray -> ProbArray -> Int -> Lb -> L.LogFloat
prob1 alpha beta k x = joint / zxBeta beta
  where
    joint = alpha k x * beta (k + 1) x
-- | Expected values of the features that can fire at position @k@.
--
-- Transition features come first.  At @k == 0@ these are the start
-- features, weighted by 'prob1'.  Elsewhere they are the 'prevIxs'
-- transitions, weighted by 'prob2'.  They are followed by the observation
-- features of position @k@, weighted by 'prob1'.
expectedFeaturesOn
    :: Model -> ProbArray -> ProbArray -> Xs
    -> Int -> [(FeatIx, L.LogFloat)]
expectedFeaturesOn crf alpha beta sent k = transitional ++ observational
  where
    psi = computePsi crf sent k
    unary x = prob1 alpha beta k x
    observational =
        [ (ix, unary x)
        | o <- unX (sent V.! k)
        , (x, ix) <- obIxs crf o ]
    transitional
        | k == 0    = [ (ix, unary x) | (x, ix) <- sgIxs crf ]
        | otherwise =
            [ (ix, prob2 crf alpha beta k psi x y ix)
            | x <- lbSet crf
            , (y, ix) <- prevIxs crf x ]
-- | Expected feature values over a whole sentence.
--
-- Per-position results are concatenated, not merged, so the same feature
-- index may appear several times.  Before the list is produced, the forward
-- table's constant is sparked with 'par' while the backward table's
-- constant is evaluated; both are then forced with 'pseq'.
expectedFeaturesIn :: Model -> Xs -> [(FeatIx, L.LogFloat)]
expectedFeaturesIn crf sent =
    zxF `par` (zxB `pseq` (zxF `pseq` features))
  where
    alpha = forward sum crf sent
    beta = backward sum crf sent
    zxF = zxAlpha sent alpha
    zxB = zxBeta beta
    features = concatMap (expectedFeaturesOn crf alpha beta sent)
                         [0 .. V.length sent - 1]