{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
module Data.CRF.Chain1.Constrained.DAG.Inference
( tag
, tagK
, marginals
, accuracy
, expectedFeaturesIn
, zx
, zx'
, computePsi
) where
import Control.Applicative ((<$>))
import Data.Maybe (catMaybes)
import Data.List (maximumBy, sortBy)
import Data.Function (on)
import qualified Data.Set as S
import qualified Data.Array as A
import qualified Data.Vector.Unboxed as U
import qualified Data.Foldable as F
import Control.Parallel.Strategies (rseq, parMap)
import Control.Parallel (par, pseq)
import GHC.Conc (numCapabilities)
import qualified Data.Number.LogFloat as L
import Data.DAG (EdgeID, DAG)
import qualified Data.DAG as DAG
import qualified Data.CRF.Chain1.Constrained.DP as DP
import Data.CRF.Chain1.Constrained.Util (partition)
import qualified Data.CRF.Chain1.Constrained.Model as Md
import Data.CRF.Chain1.Constrained.Core (X, Y, Lb, AVec)
import qualified Data.CRF.Chain1.Constrained.Core as C
import qualified Data.CRF.Chain1.Constrained.Intersect as I
import Debug.Trace (trace)
-- | Position of a label within the label vector of a particular edge.
type LbIx = Int
-- | A function from an edge and a label index on that edge to a
-- 'L.LogFloat' value; the type of the tables produced by 'forward'
-- and 'backward'.
type ProbArray = EdgeID -> LbIx -> L.LogFloat
-- | Sum a list of 'L.LogFloat' values.  The empty list is handled
-- explicitly and yields 0 rather than being passed on to 'L.sum'.
safeSum :: [L.LogFloat] -> L.LogFloat
safeSum xs
  | null xs   = 0
  | otherwise = L.sum xs
{-# INLINE safeSum #-}
-- | Vector of labels attached to the given edge: the model-wide vector
-- 'Md.r0' when the edge label is built with 'C.X', or the vector stored
-- in the edge label itself when it is built with 'C.R'.
lbVec :: Md.Model -> DAG a X -> EdgeID -> AVec Lb
lbVec crf dag edgeID = labelsOf (DAG.edgeLabel edgeID dag)
  where
    labelsOf (C.X _)   = Md.r0 crf
    labelsOf (C.R _ r) = r
{-# INLINE lbVec #-}
-- | Number of labels in the label vector ('lbVec') of the given edge.
lbNum :: Md.Model -> DAG a X -> EdgeID -> Int
lbNum crf dag = \edgeID -> U.length (C.unAVec (lbVec crf dag edgeID))
{-# INLINE lbNum #-}
-- | Label found at the given index for the given edge label.  A 'C.X'
-- label is looked up in the model-wide vector 'Md.r0', a 'C.R' label in
-- the vector it carries.
lbOn :: Md.Model -> X -> LbIx -> Lb
lbOn crf x =
    case x of
      C.X _   -> indexInto (Md.r0 crf)
      C.R _ r -> indexInto r
  where
    indexInto v = (C.unAVec v U.!)
{-# INLINE lbOn #-}
-- | Labels of the given edge paired with their 0-based indices in the
-- edge's label vector.
lbIxs :: Md.Model -> DAG a X -> EdgeID -> [(LbIx, Lb)]
lbIxs crf dag = \edgeID ->
  zip [0 ..] (U.toList (C.unAVec (lbVec crf dag edgeID)))
{-# INLINE lbIxs #-}
-- | Observation potentials of a single edge, indexed by label index.
--
-- Every cell starts at 1.  For each observation of the edge, the pairs
-- (label index, feature index) returned by 'I.intersect' on the
-- observation's 'Md.obIxs' and the edge's label vector are multiplied
-- into the corresponding cell using 'Md.valueL'.
computePsi :: Md.Model -> DAG a X -> EdgeID -> LbIx -> L.LogFloat
computePsi crf dag i = (potentials A.!)
  where
    labels = lbVec crf dag i
    potentials = A.accumArray (*) 1 (0, lbNum crf dag i - 1) contributions
    contributions = do
      ob <- C.unX (DAG.edgeLabel i dag)
      (k, ix) <- I.intersect (Md.obIxs crf ob) labels
      return (k, Md.valueL crf ix)
-- | Variant of 'computePsi' backed by an array over all edges of the
-- DAG (from 'DAG.minEdge' to 'DAG.maxEdge'), so that the potentials of
-- an edge are shared between lookups made through the same partial
-- application @computePsi' crf dag@.
computePsi'
  :: Md.Model -> DAG a X
  -> EdgeID -> LbIx
  -> L.LogFloat
computePsi' crf dag = \i -> table A.! i
  where
    edgeRange = (DAG.minEdge dag, DAG.maxEdge dag)
    -- Elements are listed in 'A.range' order, which is the order
    -- 'A.listArray' assigns them to indices.
    table = A.listArray edgeRange
      (map (computePsi crf dag) (A.range edgeRange))
-- | Forward table, memoised with 'DP.flexible2'.
--
-- The table covers the edges of the DAG plus one extra position,
-- @DAG.maxEdge dag + 1@, with a single cell (index 0) holding the sum of
-- the table over all labels of all final edges.
--
-- For an initial edge the value is the edge potential times
-- 'Md.sgValue' of the label.  For any other edge it is the edge
-- potential times @(u - v x) + w x@, where @u@ sums the table over all
-- labels of all preceding edges, @v x@ sums it over the preceding
-- labels returned by 'I.intersect' on 'Md.prevIxs', and @w x@ sums over
-- the same labels with each term multiplied by 'Md.valueL'.
forward :: Md.Model -> DAG a X -> ProbArray
forward crf dag =
    DP.flexible2 edgeRange labelRange step
  where
    -- Extra position placed right after the last edge.
    sentinel = DAG.maxEdge dag + 1
    edgeRange = (DAG.minEdge dag, sentinel)
    labelRange e
      | e == sentinel = (0, 0)
      | otherwise     = (0, lbNum crf dag e - 1)
    initials = S.fromList
      (filter (`DAG.isInitialEdge` dag) (DAG.dagEdges dag))
    -- The potentials are passed lazily; they are not consulted at the
    -- sentinel position.
    step memo e = recurrence (computePsi crf dag e) memo e
    recurrence psi memo e
      | e == sentinel = const total
      | S.member e initials = \j ->
          psi j * Md.sgValue crf (lbOn crf (DAG.edgeLabel e dag) j)
      | otherwise = \j ->
          let x = lbOn crf (DAG.edgeLabel e dag) j
          in  psi j * ((allPrev - linkedPrev x) + weightedPrev x)
      where
        preds = DAG.prevEdges e dag
        -- (label index, feature index) pairs linking label x with the
        -- labels of the preceding edge p.
        links x p = I.intersect (Md.prevIxs crf x) (lbVec crf dag p)
        -- Independent of the current label, hence shared by all of them.
        allPrev = safeSum
          [ memo p k
          | p <- preds
          , (k, _) <- lbIxs crf dag p ]
        linkedPrev x = safeSum
          [ memo p k
          | p <- preds
          , (k, _) <- links x p ]
        weightedPrev x = safeSum
          [ memo p k * Md.valueL crf ix
          | p <- preds
          , (k, ix) <- links x p ]
        total = safeSum
          [ memo p k
          | p <- DAG.dagEdges dag
          , DAG.isFinalEdge p dag
          , (k, _) <- lbIxs crf dag p ]
-- | Backward table, memoised with 'DP.flexible2'.
--
-- The table covers the edges of the DAG plus one extra position,
-- @DAG.minEdge dag - 1@, with a single cell (index 0) holding the sum,
-- over all labels of all initial edges, of the table value times the
-- edge potential times 'Md.sgValue' of the label.
--
-- For a final edge the value is 1.  For any other edge it is
-- @(u - v y) + w y@, where every term is the table value times the
-- potential of a label on a following edge: @u@ ranges over all such
-- labels, @v y@ over those returned by 'I.intersect' on 'Md.nextIxs',
-- and @w y@ over the same labels with each term multiplied by
-- 'Md.valueL'.
backward :: Md.Model -> DAG a X -> ProbArray
backward crf dag =
    DP.flexible2 edgeRange labelRange recurrence
  where
    -- Extra position placed right before the first edge.
    sentinel = DAG.minEdge dag - 1
    edgeRange = (sentinel, DAG.maxEdge dag)
    labelRange e
      | e == sentinel = (0, 0)
      | otherwise     = (0, lbNum crf dag e - 1)
    psi = computePsi' crf dag
    finals = S.fromList
      (filter (`DAG.isFinalEdge` dag) (DAG.dagEdges dag))
    recurrence memo e
      | S.member e finals = const 1
      | e == sentinel = const $ safeSum
          [ memo n k * psi n k
              * Md.sgValue crf (lbOn crf (DAG.edgeLabel n dag) k)
          | n <- DAG.dagEdges dag
          , DAG.isInitialEdge n dag
          , (k, _) <- lbIxs crf dag n ]
      | otherwise = \j ->
          let y = lbOn crf (DAG.edgeLabel e dag) j
          in  (allNext - linkedNext y) + weightedNext y
      where
        succs = DAG.nextEdges e dag
        -- Table value of a following label times its edge potential.
        scored n k = memo n k * psi n k
        -- (label index, feature index) pairs linking label y with the
        -- labels of the following edge n.
        links y n = I.intersect (Md.nextIxs crf y) (lbVec crf dag n)
        -- Independent of the current label, hence shared by all of them.
        allNext = safeSum
          [ scored n k
          | n <- succs
          , (k, _) <- lbIxs crf dag n ]
        linkedNext y = safeSum
          [ scored n k
          | n <- succs
          , (k, _) <- links y n ]
        weightedNext y = safeSum
          [ scored n k * Md.valueL crf ix
          | n <- succs
          , (k, ix) <- links y n ]
-- | Normalization factor read off the 'forward' table.
zx' :: Md.Model -> DAG a X -> L.LogFloat
zx' crf dag = zxAlpha dag alpha
  where
    alpha = forward crf dag
-- | Value of the given table at the position right after the last edge
-- (@DAG.maxEdge dag + 1@), label index 0.
zxAlpha :: DAG a b -> ProbArray -> L.LogFloat
zxAlpha dag alpha = alpha afterLast 0
  where
    afterLast = DAG.maxEdge dag + 1
-- | Normalization factor read off the 'backward' table.
zx :: Md.Model -> DAG a X -> L.LogFloat
zx crf dag = zxBeta dag beta
  where
    beta = backward crf dag
-- | Value of the given table at the position right before the first
-- edge (@DAG.minEdge dag - 1@), label index 0.
zxBeta :: DAG a b -> ProbArray -> L.LogFloat
zxBeta dag beta = beta beforeFirst 0
  where
    beforeFirst = DAG.minEdge dag - 1
-- | Forward value times backward value of the given edge and label
-- index, divided by the normalization factor taken from the backward
-- table.
--
-- Calls 'error' when the logarithm of the forward value, the backward
-- value or the normalization factor is infinite.
edgeProb1
  :: DAG a b
  -> ProbArray  -- ^ forward table
  -> ProbArray  -- ^ backward table
  -> EdgeID
  -> LbIx
  -> L.LogFloat
edgeProb1 dag alpha beta k x =
    if any infinite [fwd, bwd, norm]
      then error $
        "edgeProb1: infinite -- " ++ show [fwd, bwd, norm, normFwd]
      else fwd * bwd / norm
  where
    infinite p = isInfinite (L.logFromLogFloat p :: Double)
    fwd = alpha k x
    bwd = beta k x
    norm = zxBeta dag beta
    -- Only reported in the error message.
    normFwd = zxAlpha dag alpha
{-# INLINE edgeProb1 #-}
-- | Product of the forward value of the first (edge, label index)
-- pair, the backward value and the potential of the second pair, and
-- 'Md.valueL' of the given feature index, divided by the normalization
-- factor taken from the backward table.
--
-- Calls 'error' when the logarithm of any of the four factors or of the
-- normalization factor is infinite.
edgeProb2
  :: Md.Model
  -> DAG a b
  -> ProbArray                       -- ^ forward table
  -> ProbArray                       -- ^ backward table
  -> (EdgeID -> LbIx -> L.LogFloat)  -- ^ edge potentials
  -> (EdgeID, LbIx)
  -> (EdgeID, LbIx)
  -> Md.FeatIx
  -> L.LogFloat
edgeProb2 crf dag alpha beta psi (kEdgeID, xLbIx) (lEdgeID, yLbIx) ix =
    if any infinite [fwd, bwd, potential, weight, norm]
      then error $
        "edgeProb2: infinite -- "
          ++ show [fwd, bwd, potential, weight, norm, normFwd]
      else fwd * bwd * potential * weight / norm
  where
    infinite p = isInfinite (L.logFromLogFloat p :: Double)
    fwd = alpha kEdgeID xLbIx
    bwd = beta lEdgeID yLbIx
    potential = psi lEdgeID yLbIx
    weight = Md.valueL crf ix
    norm = zxBeta dag beta
    -- Only reported in the error message.
    normFwd = zxAlpha dag alpha
{-# INLINE edgeProb2 #-}
-- | Replace the label of every edge with the list of the edge's labels
-- paired with their 'edgeProb1' values.
--
-- When the normalization factors obtained from the forward and the
-- backward tables are not 'almostEq', a warning is emitted with 'trace'
-- and the result is returned anyway.
marginals :: Md.Model -> DAG a X -> DAG a [(Lb, L.LogFloat)]
marginals crf dag =
    if zx1 `almostEq` zx2
      then margs
      else trace warning margs
  where
    alpha = forward crf dag
    beta = backward crf dag
    zx1 = zxAlpha dag alpha
    zx2 = zxBeta dag beta
    prob1 = edgeProb1 dag alpha beta
    label edgeID _ =
      [ (lab, prob1 edgeID labID)
      | (labID, lab) <- lbIxs crf dag edgeID ]
    margs = DAG.mapE label dag
    warning =
      "[marginals] normalization factors differ significantly: "
        ++ show (L.logFromLogFloat zx1, L.logFromLogFloat zx2)
-- | For every edge, the (at most) @k@ labels with the highest
-- 'marginals' values, in decreasing order of value.
tagK :: Int -> Md.Model -> DAG a X -> DAG a [(Lb, L.LogFloat)]
tagK k crf dag = topK <$> marginals crf dag
  where
    -- Sort in increasing order, then reverse.
    topK ps = take k (reverse (sortBy (\a b -> compare (snd a) (snd b)) ps))
-- | For every edge, the label with the highest 'marginals' value
-- (the single result of @'tagK' 1@).
--
-- Calls 'error' with a descriptive message, instead of failing inside
-- 'head', when an edge has no candidate labels at all.
tag :: Md.Model -> DAG a X -> DAG a Lb
tag crf = fmap best . tagK 1 crf
  where
    best ((lb, _) : _) = lb
    best [] = error
      "Data.CRF.Chain1.Constrained.DAG.Inference.tag: edge with no candidate labels"
-- | Features active on the given edge, each paired with a probability.
--
-- Observation features: for each observation of the edge, the pairs
-- returned by 'I.intersect' on 'Md.obIxs' and the edge's label vector,
-- weighted with 'edgeProb1' of the corresponding label.
--
-- Transition features: on an initial edge, the 'C.SFeature' of every
-- label known to the model ('Md.featToIx'), weighted with 'edgeProb1';
-- on any other edge, the pairs returned by 'I.intersect' on
-- 'Md.prevIxs' and the label vector of every preceding edge, weighted
-- with 'edgeProb2'.
expectedFeaturesOn
  :: Md.Model
  -> DAG a X
  -> ProbArray  -- ^ forward table
  -> ProbArray  -- ^ backward table
  -> EdgeID
  -> [(Md.FeatIx, L.LogFloat)]
expectedFeaturesOn crf dag alpha beta iEdgeID =
    tFeats ++ oFeats
  where
    prob1 = edgeProb1 dag alpha beta iEdgeID
    oFeats =
      [ (ix, prob1 k)
      | ob <- C.unX (DAG.edgeLabel iEdgeID dag)
      , (k, ix) <- I.intersect (Md.obIxs crf ob) (lbVec crf dag iEdgeID) ]
    -- Potentials of the current edge only.  'edgeProb2' consults the
    -- potentials solely at its second edge argument, which below is
    -- always 'iEdgeID', so there is no need to build a table over every
    -- edge of the DAG on each call.
    psiHere = computePsi crf dag iEdgeID
    prob2 = edgeProb2 crf dag alpha beta (const psiHere)
    tFeats
      | DAG.isInitialEdge iEdgeID dag = catMaybes
          [ (, prob1 k) <$> Md.featToIx crf (C.SFeature x)
          | (k, x) <- lbIxs crf dag iEdgeID ]
      | otherwise =
          [ (ix, prob2 (iMinus1, l) (iEdgeID, k) ix)
          | (k, x) <- lbIxs crf dag iEdgeID
          , iMinus1 <- DAG.prevEdges iEdgeID dag
          , (l, ix) <- I.intersect (Md.prevIxs crf x) (lbVec crf dag iMinus1) ]
-- | Concatenation of 'expectedFeaturesOn' over all edges of the DAG.
--
-- Before the list is returned, the normalization factor of the forward
-- table is sparked with 'par' while the one of the backward table is
-- evaluated; both are then forced with 'pseq'.
expectedFeaturesIn
  :: Md.Model
  -> DAG a X
  -> [(Md.FeatIx, L.LogFloat)]
expectedFeaturesIn crf dag =
    zxF `par` (zxB `pseq` (zxF `pseq` feats))
  where
    alpha = forward crf dag
    beta = backward crf dag
    zxF = zxAlpha dag alpha
    zxB = zxBeta dag beta
    feats = concatMap (expectedFeaturesOn crf dag alpha beta) (DAG.dagEdges dag)
-- | Count the edges on which the label chosen by 'tag' agrees (first
-- component) or disagrees (second component) with the reference label.
--
-- The reference label of an edge is the highest-valued entry of its
-- 'C.unY' list.  An edge whose list is empty has no reference label
-- and therefore always counts as a disagreement.
goodAndBad :: Md.Model -> DAG a (X, Y) -> (Int, Int)
goodAndBad crf dag =
    F.foldl' count (0, 0) (DAG.zipE gold predicted)
  where
    gold = fmap (pickBest . C.unY . snd) dag
    predicted = Just <$> tag crf (fmap fst dag)
    pickBest zs
      | null zs   = Nothing
      | otherwise = Just (fst (maximumBy (\a b -> compare (snd a) (snd b)) zs))
    count (good, bad) (x, y) =
      if x == y
        then (good + 1, bad)
        else (good, bad + 1)
-- | Component-wise sum of 'goodAndBad' over a list of DAGs.
goodAndBad' :: Md.Model -> [DAG a (X, Y)] -> (Int, Int)
goodAndBad' crf dataset =
    F.foldl' plus (0, 0) (map (goodAndBad crf) dataset)
  where
    plus (g, b) (g', b') = (g + g', b + b')
-- | Fraction of edges in the dataset on which 'tag' agrees with the
-- reference label (see 'goodAndBad').
--
-- The dataset is handed to 'partition' together with 'numCapabilities'
-- (presumably splitting it into that many parts -- confirm against
-- "Data.CRF.Chain1.Constrained.Util"), and the parts are counted in
-- parallel with 'parMap'.
accuracy :: Md.Model -> [DAG a (X, Y)] -> Double
accuracy crf dataset =
    fromIntegral good / fromIntegral (good + bad)
  where
    chunks = partition numCapabilities dataset
    counts = parMap rseq (goodAndBad' crf) chunks
    (good, bad) = F.foldl' plus (0, 0) counts
    plus (g, b) (g', b') = (g + g', b + b')
-- | Approximate equality of two 'L.LogFloat' values, decided on their
-- logarithms: either both logarithms are 'isZero', or their ratio lies
-- strictly within 'eps' of 1.
almostEq :: L.LogFloat -> L.LogFloat -> Bool
almostEq x0 y0 =
    (isZero lx && isZero ly) || (1.0 - eps < ratio && ratio < 1.0 + eps)
  where
    lx = L.logFromLogFloat x0
    ly = L.logFromLogFloat y0
    ratio = lx / ly
-- | Whether the absolute value of the argument is strictly below 'eps'.
isZero :: (Fractional t, Ord t) => t -> Bool
isZero = (< eps) . abs
-- | Tolerance used by 'isZero' and 'almostEq'.
eps :: Fractional t => t
eps = 1.0e-6