-- | Inference routines of this module: best-label tagging ('tag', 'tagK'),
-- per-position label probabilities ('marginals'), tagging 'accuracy' on a
-- dataset, expected feature probabilities ('expectedFeaturesIn') and the
-- normalizing quantity read from the backward ('zx') or forward ('zx'')
-- table.
module Data.CRF.Chain1.Constrained.Inference
( tag
, tagK
, marginals
, accuracy
, expectedFeaturesIn
, zx
, zx'
) where
import Control.Applicative ((<$>))
import Data.Maybe (catMaybes)
import Data.List (maximumBy, sortBy)
import Data.Function (on)
import qualified Data.Array as A
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import Control.Parallel.Strategies (rseq, parMap)
import Control.Parallel (par, pseq)
import GHC.Conc (numCapabilities)
import qualified Data.Number.LogFloat as L
import qualified Data.CRF.Chain1.Constrained.DP as DP
import Data.CRF.Chain1.Constrained.Util (partition)
import Data.CRF.Chain1.Constrained.Dataset.Internal
import Data.CRF.Chain1.Constrained.Feature (Feature(..))
import Data.CRF.Chain1.Constrained.Model
import Data.CRF.Chain1.Constrained.Intersect
-- | Index of a label inside the label vector of a sentence position
-- ('lbOn' turns such an index into the label itself).
type LbIx = Int

-- | A table of values addressed by a position index and a label index.
type ProbArray = Int -> LbIx -> L.LogFloat
-- | Label vector attached to position @k@ of the sentence: the model's
-- 'r0' vector for an 'X' word, or the vector carried by an 'R' word.
lbVec :: Model -> Xs -> Int -> AVec Lb
lbVec crf xs k = labelsOf (xs V.! k)
  where
    labelsOf (X _)     = r0 crf
    labelsOf (R _ own) = own
-- | Number of labels in the label vector of position @k@ (see 'lbVec').
lbNum :: Model -> Xs -> Int -> Int
lbNum crf xs k = U.length (unAVec (lbVec crf xs k))
-- | Label stored under index @j@ in the label vector of the given word:
-- the model's 'r0' vector for an 'X' word, the word's own vector for an
-- 'R' word.  Indexing uses 'U.!', so @j@ must be within the vector.
lbOn :: Model -> X -> Int -> Lb
lbOn crf word j = unAVec labels U.! j
  where
    labels = case word of
        X _     -> r0 crf
        R _ own -> own
-- | Labels of position @k@ paired with their indices (counted from 0) in
-- the position's label vector.
lbIxs :: Model -> Xs -> Int -> [(Int, Lb)]
lbIxs crf xs k = zip [0 ..] (U.toList (unAVec (lbVec crf xs k)))
-- | Observation potential at position @i@, as a function of the label index.
--
-- For every observation of the word and every label index @k@ yielded by
-- intersecting the observation's feature list ('obIxs') with the position's
-- label vector, the feature value ('valueL') is multiplied into cell @k@.
-- Cells that receive no feature keep the initial value 1.
--
-- The array is built once per @(crf, xs, i)@; the returned function only
-- indexes it, so partially applying 'computePsi' to three arguments shares
-- the work between all label indices.
computePsi :: Model -> Xs -> Int -> LbIx -> L.LogFloat
computePsi crf xs i = (psiArr A.!)
  where
    psiArr = A.accumArray (*) 1 bounds
        [ (k, valueL crf ix)
        | ob <- unX (xs V.! i)
        , (k, ix) <- intersect (obIxs crf ob) (lbVec crf xs i) ]
    -- Label indices run from 0 up to (number of labels - 1).
    bounds = (0, lbNum crf xs i - 1)
-- | Forward table for the sentence @xs@.
--
-- The table covers positions @0 .. V.length xs@ and is built through
-- 'DP.flexible2' (presumably a memoizing combinator -- confirm in the DP
-- module).  Column @i < V.length xs@ has one cell per label index of
-- position @i@; the extra column @V.length xs@ has the single cell @0@,
-- which holds the sum of the last real column (read by 'zxAlpha').
--
-- Cell values:
--
--   * position 0: @psi j * sgValue crf x@;
--
--   * other positions: @psi j * ((u - v x) + w x)@, where @u@ sums the whole
--     previous column, @v x@ sums the previous cells that have a transition
--     feature into label @x@ ('prevIxs'), and @w x@ sums those same cells
--     weighted by the feature value.  Previous cells without a transition
--     feature therefore contribute with weight 1.
forward :: Model -> Xs -> ProbArray
forward crf xs = alpha
  where
    n = V.length xs
    alpha = DP.flexible2 (0, n) bounds
        (\t i -> withMem (computePsi crf xs i) t i)
    bounds i
        | i == n    = (0, 0)
        | otherwise = (0, lbNum crf xs i - 1)
    -- @tbl@ is the table under construction, used to read column @i - 1@.
    withMem psi tbl i
        | i == n = const u
        | i == 0 = \j ->
            let x = lbOn crf (xs V.! i) j
            in  psi j * sgValue crf x
        | otherwise = \j ->
            let x = lbOn crf (xs V.! i) j
            in  psi j * ((u - v x) + w x)
      where
        -- Sum of the entire previous column; shared by all @j@.
        u = sum
            [ tbl (i - 1) k
            | (k, _) <- lbIxs crf xs (i - 1) ]
        -- Previous cells that have an explicit transition feature into @x@.
        v x = sum
            [ tbl (i - 1) k
            | (k, _) <- intersect (prevIxs crf x) (lbVec crf xs (i - 1)) ]
        -- The same cells, weighted by the transition feature value.
        w x = sum
            [ tbl (i - 1) k * valueL crf ix
            | (k, ix) <- intersect (prevIxs crf x) (lbVec crf xs (i - 1)) ]
-- | Backward table for the sentence @xs@.
--
-- The table covers positions @0 .. V.length xs@ and is built through
-- 'DP.flexible2' (presumably a memoizing combinator -- confirm in the DP
-- module).  Column @i > 0@ is indexed by the label indices of position
-- @i - 1@; column 0 has the single cell @0@, which holds the total over
-- the whole sentence (read by 'zxBeta').
--
-- Cell values:
--
--   * position @V.length xs@: the constant 1.  This is the multiplicative
--     identity that seeds the recursion; any other seed (in particular 0)
--     would scale or annihilate every other cell.  @psi@ is never forced
--     here, so 'computePsi' is not evaluated past the end of the sentence;
--
--   * position 0: the sum over labels @k@ of position 0 of
--     @beta 1 k * psi k * sgValue crf label@;
--
--   * other positions: @(u - v y) + w y@, where @y@ is the label at
--     position @i - 1@, @u@ sums @beta (i+1) k * psi k@ over all labels of
--     position @i@, @v y@ sums the terms that have a transition feature
--     from @y@ ('nextIxs'), and @w y@ sums those same terms weighted by the
--     feature value.
backward :: Model -> Xs -> ProbArray
backward crf xs = beta
  where
    n = V.length xs
    beta = DP.flexible2 (0, n) bounds
        (\t i -> withMem (computePsi crf xs i) t i)
    bounds i
        | i == 0    = (0, 0)
        | otherwise = (0, lbNum crf xs (i - 1) - 1)
    -- @tbl@ is the table under construction, used to read column @i + 1@.
    withMem psi tbl i
        | i == n = const 1
        | i == 0 = const $ sum
            [ tbl (i + 1) k * psi k
            * sgValue crf (lbOn crf (xs V.! i) k)
            | (k, _) <- lbIxs crf xs i ]
        | otherwise = \j ->
            let y = lbOn crf (xs V.! (i - 1)) j
            in  (u - v y) + w y
      where
        -- Total over all labels of position @i@; shared by all @j@.
        u = sum
            [ tbl (i + 1) k * psi k
            | (k, _) <- lbIxs crf xs i ]
        -- Terms that have an explicit transition feature out of @y@.
        v y = sum
            [ tbl (i + 1) k * psi k
            | (k, _) <- intersect (nextIxs crf y) (lbVec crf xs i) ]
        -- The same terms, weighted by the transition feature value.
        w y = sum
            [ tbl (i + 1) k * psi k * valueL crf ix
            | (k, ix) <- intersect (nextIxs crf y) (lbVec crf xs i) ]
-- | Normalizing value read from a backward table: its cell at position 0,
-- label index 0.
zxBeta :: ProbArray -> L.LogFloat
zxBeta beta = beta firstPos onlyCell
  where
    firstPos = 0
    onlyCell = 0
-- | Normalizing value read from a forward table: its cell at position
-- @V.length xs@ (one past the last word), label index 0.
zxAlpha :: Xs -> ProbArray -> L.LogFloat
zxAlpha xs alpha = alpha endPos 0
  where
    endPos = V.length xs
-- | Normalizing value of a sentence, obtained from the 'backward' table.
zx :: Model -> Xs -> L.LogFloat
zx crf xs = zxBeta (backward crf xs)
-- | Normalizing value of a sentence, obtained from the 'forward' table.
zx' :: Model -> Xs -> L.LogFloat
zx' crf sent =
    let alpha = forward crf sent
    in  zxAlpha sent alpha
-- | For every position of the sentence (in order), the labels of that
-- position paired with their probability as computed by 'prob1' from the
-- forward and backward tables.  An empty sentence yields an empty list.
marginals :: Model -> Xs -> [[(Lb, L.LogFloat)]]
marginals crf xs =
    [ [ (x, prob1 alpha beta i k)
      | (k, x) <- lbIxs crf xs i ]
    | i <- [0 .. V.length xs - 1] ]
  where
    -- Both tables are built once and shared by all positions.
    alpha = forward crf xs
    beta  = backward crf xs
-- | For every position, the @k@ labels with the highest 'marginals'
-- probability, most probable first.
tagK :: Int -> Model -> Xs -> [[(Lb, L.LogFloat)]]
tagK k crf xs =
    [ take k (descending choices)
    | choices <- marginals crf xs ]
  where
    -- Ascending sort followed by reversal.
    descending = reverse . sortBy (compare `on` snd)
-- | The single most probable label for every position of the sentence,
-- taken from @'tagK' 1@.  Uses 'head', so a position for which 'tagK'
-- returns no candidates is an error.
tag :: Model -> Xs -> [Lb]
tag crf xs = [ fst (head ranked) | ranked <- tagK 1 crf xs ]
-- | Probability of label index @x@ at position @k@: the product of the
-- forward cell @(k, x)@ and the backward cell @(k + 1, x)@, divided by the
-- normalizing value of the backward table.
prob1 :: ProbArray -> ProbArray -> Int -> LbIx -> L.LogFloat
prob1 alpha beta k x = unnormalized / zxBeta beta
  where
    unnormalized = alpha k x * beta (k + 1) x
-- | Probability of the label-index pair (@y@ at position @k - 1@, @x@ at
-- position @k@) connected by the transition feature @ix@: forward cell
-- @(k - 1, y)@ times backward cell @(k + 1, x)@ times the observation
-- potential @psi x@ of position @k@ times the feature value, divided by
-- the normalizing value of the backward table.
prob2 :: Model -> ProbArray -> ProbArray -> Int -> (LbIx -> L.LogFloat)
      -> LbIx -> LbIx -> FeatIx -> L.LogFloat
prob2 crf alpha beta k psi x y ix
    = alpha (k - 1) y * beta (k + 1) x
    * psi x * valueL crf ix / zxBeta beta
-- | Number of correctly and incorrectly tagged positions of one sentence,
-- as a @(good, bad)@ pair.
--
-- The reference label of a position is the highest-scored entry of its
-- 'unY' list ('Nothing' when that list is empty); it is compared with the
-- label chosen by 'tag'.  A position without a reference label therefore
-- always counts as bad.  Positions are paired with 'zipWith', so
-- surplus positions on either side are ignored.
goodAndBad :: Model -> Xs -> Ys -> (Int, Int)
goodAndBad crf xs ys = (good, bad)
  where
    gold = [ best (unY y) | y <- V.toList ys ]
    best zs
        | null zs   = Nothing
        | otherwise = Just . fst $ maximumBy (compare `on` snd) zs
    predicted = map Just (tag crf xs)
    -- One Bool per compared position: True when the labels agree.
    hits = zipWith (==) gold predicted
    good = length (filter id hits)
    bad  = length hits - good
-- | Component-wise sum of the 'goodAndBad' counts over all sentences of
-- the dataset.
goodAndBad' :: Model -> [(Xs, Ys)] -> (Int, Int)
goodAndBad' crf dataset = (sum goods, sum bads)
  where
    (goods, bads) = unzip (map (uncurry (goodAndBad crf)) dataset)
-- | Fraction of correctly tagged positions in the dataset.
--
-- The dataset is split into 'numCapabilities' parts with 'partition', each
-- part is counted with 'goodAndBad'' under @'parMap' 'rseq'@, and the
-- counts are added up.
accuracy :: Model -> [(Xs, Ys)] -> Double
accuracy crf dataset = fromIntegral good / fromIntegral (good + bad)
  where
    chunks      = partition numCapabilities dataset
    perChunk    = parMap rseq (goodAndBad' crf) chunks
    (good, bad) = foldl plus (0, 0) perChunk
    plus (g, b) (g', b') = (g + g', b + b')
-- | Features relevant to position @i@ of the sentence, each paired with a
-- probability computed from the given forward (@alpha@) and backward
-- (@beta@) tables.  Transition-type features come first, observation
-- features second.
--
--   * Observation features: for every observation of the word and every
--     matching label index @k@ ('obIxs' intersected with the position's
--     label vector), the feature index with @prob1@ of @k@.
--
--   * Position 0: for every label @x@ whose 'SFeature' is known to the
--     model ('featToIx'), that feature index with @prob1@ of the label.
--
--   * Other positions: for every label @x@ of position @i@ and every
--     matching label index @l@ of position @i - 1@ ('prevIxs'), the
--     transition feature index with @prob2@ of the pair.
expectedFeaturesOn
    :: Model -> ProbArray -> ProbArray -> Xs
    -> Int -> [(FeatIx, L.LogFloat)]
expectedFeaturesOn crf alpha beta xs i =
    tFeats ++ oFeats
  where
    psi = computePsi crf xs i
    pr1 = prob1 alpha beta i
    pr2 = prob2 crf alpha beta i psi

    oFeats = [ (ix, pr1 k)
             | o <- unX (xs V.! i)
             , (k, ix) <- intersect (obIxs crf o) (lbVec crf xs i) ]

    tFeats
        | i == 0 = catMaybes
            -- Explicit lambda instead of a tuple section, so the block
            -- does not depend on the TupleSections extension.
            [ fmap (\ix -> (ix, pr1 k)) (featToIx crf (SFeature x))
            | (k, x) <- lbIxs crf xs i ]
        | otherwise =
            [ (ix, pr2 k l ix)
            | (k,  x) <- lbIxs crf xs i
            , (l, ix) <- intersect (prevIxs crf x) (lbVec crf xs (i - 1)) ]
-- | 'expectedFeaturesOn' results for every position of the sentence,
-- concatenated in position order.
--
-- Before the list is produced, the normalizing values of the forward and
-- backward tables are demanded: the forward one is sparked with 'par'
-- while the backward one is evaluated with 'pseq', then the forward one is
-- awaited.  The intent is presumably to fill both tables in parallel.
expectedFeaturesIn :: Model -> Xs -> [(FeatIx, L.LogFloat)]
expectedFeaturesIn crf xs = zxF `par` zxB `pseq` zxF `pseq`
    concat [ expectedOn k | k <- [0 .. V.length xs - 1] ]
  where
    expectedOn = expectedFeaturesOn crf alpha beta xs
    alpha = forward crf xs
    beta  = backward crf xs
    zxF   = zxAlpha xs alpha
    zxB   = zxBeta beta