{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Rank2Types #-}
module Data.CRF.Chain2.Tiers.DAG.Inference
( tag
, tag'
, tagK
, fastTag
, fastTag'
, marginals
, marginals'
, ProbType (..)
, probs
, probs'
, accuracy
, expectedFeaturesIn
, zx
, zx'
, AccF
, ProbArray
, Pos (..)
, simplify
, complicate
, memoProbArray
, memoEdgeIx
) where
import GHC.Conc (numCapabilities)
import Control.Applicative ((<$>))
import qualified Control.Parallel as Par
import qualified Control.Parallel.Strategies as Par
import qualified Data.Number.LogFloat as L
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust)
import qualified Data.MemoCombinators as Memo
import qualified Data.List as List
import Data.Function (on)
import qualified Data.Foldable as F
import Data.DAG (EdgeID, DAG)
import qualified Data.DAG as DAG
import qualified Data.CRF.Chain2.Tiers.Core as C
import Data.CRF.Chain2.Tiers.Core (X, Y, Cb, CbIx)
import qualified Data.CRF.Chain2.Tiers.Model as Md
import Data.CRF.Chain2.Tiers.Util (partition)
import Data.CRF.Chain2.Tiers.DAG.Feature (EdgeIx(..))
import qualified Data.CRF.Chain2.Tiers.DAG.Feature as Ft
-- | Accumulation function: collapses a list of 'L.LogFloat' values into a
-- single one.  Within this module it is instantiated with 'L.sum' and with
-- 'maximum'.
type AccF = [L.LogFloat] -> L.LogFloat
-- | A position used to index the probability tables: either one of the two
-- artificial boundary positions or a concrete edge/label index.
--
-- The order of the constructors determines the derived 'Ord' instance.
data Pos
  = Beg
    -- ^ Boundary used in place of a missing predecessor.
  | Mid EdgeIx
    -- ^ A concrete edge index.
  | End
    -- ^ Boundary used in place of a missing successor.
  deriving (Show, Eq, Ord)
-- | Project a position onto its edge index: 'Mid' yields the wrapped index,
-- both boundary positions ('Beg' and 'End') yield 'Nothing'.
simplify :: Pos -> Maybe EdgeIx
simplify pos = case pos of
  Mid ix -> Just ix
  Beg -> Nothing
  End -> Nothing
-- | Inverse of 'simplify', given the position to use for 'Nothing':
-- @complicate df Nothing == df@ and @complicate df (Just x) == Mid x@.
complicate :: Pos -> Maybe EdgeIx -> Pos
complicate df = maybe df Mid
-- | A probability table indexed by a pair of positions.  In this module the
-- tables are built by 'forward' and 'backward' and memoised with
-- 'memoProbArray'.
type ProbArray = Pos -> Pos -> L.LogFloat
-- | Memoise a two-argument probability table over the positions of the given
-- DAG.  Both arguments are memoised with 'memoPos'.
memoProbArray :: DAG a b -> ProbArray -> ProbArray
memoProbArray dag = Memo.memo2 (memoPos dag) (memoPos dag)
-- | Memoiser for functions over 'Pos'.  The values at 'Beg' and 'End' are
-- kept as two shared (lazy) cells; the 'Mid' case is delegated to
-- 'memoEdgeIx'.
memoPos :: DAG a b -> Memo.Memo Pos
memoPos dag f = lookupPos
  where
    atBeg = f Beg
    atMid = memoEdgeIx dag (f . Mid)
    atEnd = f End
    lookupPos pos = case pos of
      Beg -> atBeg
      Mid x -> atMid x
      End -> atEnd
-- | Memoiser for functions over 'EdgeIx'.  An 'EdgeIx' is viewed as a pair of
-- its two fields; the first component is memoised with an array ranging from
-- 'DAG.minEdge' to 'DAG.maxEdge' of the given DAG, the second with
-- 'Memo.integral'.
--
-- NOTE(review): the array memoiser is the @unsafe@ variant, so edge IDs
-- outside that range are presumably not supported -- confirm against the
-- data-memocombinators documentation.
memoEdgeIx :: DAG a b -> Memo.Memo EdgeIx
memoEdgeIx dag = Memo.wrap pack unpack pairMemo
  where
    pack (e, l) = EdgeIx e l
    unpack (EdgeIx e l) = (e, l)
    edgeMemo = Memo.unsafeArrayRange (DAG.minEdge dag, DAG.maxEdge dag)
    pairMemo = Memo.pair edgeMemo Memo.integral
-- | Potential of the observation features on the given edge index: the
-- product of 'Md.phi' over the features returned by 'Ft.obFeatsOn'.
onWord :: Md.Model -> DAG a X -> EdgeIx -> L.LogFloat
onWord crf dag = \ix ->
  L.product [Md.phi crf ft | ft <- Ft.obFeatsOn dag ix]
{-# INLINE onWord #-}
-- | Potential of the transition features on the given triple of (optional)
-- edge indices: the product of 'Md.phi' over the features returned by
-- 'Ft.trFeatsOn'.  The three indices are passed to 'Ft.trFeatsOn' in the
-- order in which they are received.
onTransition
  :: Md.Model
  -> DAG a X
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> L.LogFloat
onTransition crf dag u w v =
  L.product [Md.phi crf ft | ft <- Ft.trFeatsOn dag u w v]
{-# INLINE onTransition #-}
-- | Label the edges of the DAG using the forward table computed with
-- 'maximum' as the accumulator, followed by 'rewind'.  An edge that was not
-- selected by 'rewind' is labelled with 'Nothing'.
fastTag :: Md.Model -> DAG a X -> DAG a (Maybe CbIx)
fastTag crf dag = DAG.mapE pickLabel dag
  where
    maxAlpha = forward maximum crf dag
    chosen = rewind dag maxAlpha
    pickLabel eid _ = M.lookup eid chosen
-- | Like 'fastTag', but each selected label index is resolved to the actual
-- label with 'C.lbAt' applied to the observation stored on the same edge.
fastTag' :: Md.Model -> DAG a X -> DAG a (Maybe Cb)
fastTag' crf dag = fmap decode (DAG.zipE dag (fastTag crf dag))
  where
    decode (x, mayIx) = fmap (C.lbAt x) mayIx
-- | Walk the table backwards from 'End', at every step moving to the
-- predecessor with the highest table value, and collect the visited edge
-- indices as a map from edge ID to label index.  The walk stops as soon as
-- the chosen predecessor is 'Beg'.
rewind
  :: DAG a X
  -> ProbArray
  -> M.Map EdgeID CbIx
rewind dag alpha = go M.empty End
  where
    -- Extend the accumulated selection with the best predecessor of `pos`.
    go acc pos =
      case argmax Beg [(cand, alpha pos cand) | cand <- preds pos] of
        Mid ix -> go (M.insert (edgeID ix) (lbIx ix) acc) (Mid ix)
        Beg -> acc
        _ -> error "DAG.Inference.rewind: impossible 2 happened"
    -- Candidate predecessors of a position.
    preds End = Mid <$> Ft.finalEdgeIxs dag
    preds (Mid ix) = complicate Beg <$> Ft.prevEdgeIxs dag (Just $ edgeID ix)
    preds _ = error "DAG.Inference.rewind: impossible 1 happened"
-- | Key paired with the greatest value.  When several entries share the
-- greatest value, the first of them wins.  The first argument is returned
-- for an empty list.
argmax :: Ord v => k -> [(k, v)] -> k
argmax def [] = def
argmax _ (p : ps) = fst (List.foldl' keepBest p ps)
  where
    -- Keep the current champion unless the challenger is strictly better.
    keepBest cur@(_, v) new@(_, v')
      | v >= v' = cur
      | otherwise = new
{-# INLINE argmax #-}
-- | Forward table.  @forward acc crf dag u v@ is indexed by a position @u@
-- and its predecessor @v@; the value is obtained by accumulating (with
-- @acc@) over the predecessors @w@ of @v@.  All recursive look-ups go
-- through the memoised table.
forward :: AccF -> Md.Model -> DAG a X -> ProbArray
forward acc crf dag = table
  where
    -- Memoised view of the recurrence below.
    table = memoProbArray dag step
    step Beg Beg = 1.0
    step End End = acc
      [ table End (Mid w)
          * onTransition crf dag Nothing Nothing (Just w)
      | w <- Ft.finalEdgeIxs dag ]
    step u v = acc
      [ table v w * obs
          * onTransition crf dag mu mv (simplify w)
      | w <- complicate Beg <$> Ft.prevEdgeIxs dag (edgeID <$> mv) ]
      where
        mu = simplify u
        mv = simplify v
        -- Observation potential; boundary positions contribute 1.
        obs = maybe 1.0 wordPot mu
    -- Memoised observation potential per edge index.
    wordPot = memoEdgeIx dag $ onWord crf dag
-- | Backward table.  @backward acc crf dag v w@ is indexed by a position @v@
-- and its predecessor @w@; the value is obtained by accumulating (with
-- @acc@) over the successors @u@ of @v@.  All recursive look-ups go through
-- the memoised table.
backward :: AccF -> Md.Model -> DAG a X -> ProbArray
backward acc crf dag = table
  where
    -- Memoised view of the recurrence below.
    table = memoProbArray dag step
    step End End = 1.0
    step Beg Beg = acc
      [ table (Mid u) Beg * wordPot u
          * onTransition crf dag (Just u) Nothing Nothing
      | u <- Ft.initialEdgeIxs dag ]
    step v w = acc
      [ table u v * obs u
          * onTransition crf dag (simplify u) mv mw
      | u <- complicate End <$> Ft.nextEdgeIxs dag (edgeID <$> mv) ]
      where
        mv = simplify v
        mw = simplify w
    -- Observation potential; boundary positions contribute 1.
    obs pos = maybe 1.0 wordPot (simplify pos)
    -- Memoised observation potential per edge index.
    wordPot = memoEdgeIx dag $ onWord crf dag
-- | Normalisation factor read off the forward table computed with 'L.sum'.
zx :: Md.Model -> DAG a X -> L.LogFloat
zx crf dag = zxAlpha (forward L.sum crf dag)
-- | Normalisation factor stored in a forward table: its @(End, End)@ entry.
zxAlpha :: ProbArray -> L.LogFloat
zxAlpha alpha = alpha End End
-- | Normalisation factor read off the backward table computed with 'L.sum'.
zx' :: Md.Model -> DAG a X -> L.LogFloat
zx' crf dag = zxBeta (backward L.sum crf dag)
-- | Normalisation factor stored in a backward table: its @(Beg, Beg)@ entry.
zxBeta :: ProbArray -> L.LogFloat
zxBeta beta = beta Beg Beg
-- | Probability of a triple of consecutive edge indices @(u0, v0, w0)@:
-- forward value at @(v, w)@ times backward value at @(u, v)@ times the
-- observation potential of @u0@ times the transition potential of the
-- triple, divided by the normalisation factor taken from the backward table.
-- A missing @v0@ or @w0@ is treated as the 'Beg' position.
edgeProb3
  :: Md.Model
  -> DAG a X
  -> (EdgeIx -> L.LogFloat)
  -- ^ Observation potential
  -> ProbArray
  -- ^ Forward table
  -> ProbArray
  -- ^ Backward table
  -> EdgeIx
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> L.LogFloat
edgeProb3 crf dag psi alpha beta u0 v0 w0 =
  unnormalised / zxBeta beta
  where
    uPos = Mid u0
    vPos = complicate Beg v0
    wPos = complicate Beg w0
    unnormalised =
      alpha vPos wPos
        * beta uPos vPos
        * psi u0
        * onTransition crf dag (Just u0) v0 w0
-- | Probability of a pair of consecutive edge indices @(u0, v0)@: the
-- product of the forward and the backward values at that pair, divided by
-- the normalisation factor taken from the forward table.  A missing @v0@ is
-- treated as the 'Beg' position.
edgeProb2
  :: ProbArray
  -- ^ Forward table
  -> ProbArray
  -- ^ Backward table
  -> EdgeIx
  -> Maybe EdgeIx
  -> L.LogFloat
edgeProb2 alpha beta u0 v0 =
  alpha cur prev * beta cur prev / zxAlpha alpha
  where
    cur = Mid u0
    prev = complicate Beg v0
-- | Probability of a single edge index: the pair probabilities
-- ('edgeProb2') of the index with each of its predecessors, combined with
-- the given accumulator.
edgeProb1
  :: AccF
  -> DAG a X
  -> ProbArray
  -- ^ Forward table
  -> ProbArray
  -- ^ Backward table
  -> EdgeIx
  -> L.LogFloat
edgeProb1 acc dag alpha beta u =
  acc (map (edgeProb2 alpha beta u) predecessors)
  where
    predecessors = Ft.prevEdgeIxs dag (Just $ edgeID u)
-- | For every edge of the DAG, the list of its label indices paired with
-- their probabilities, computed from forward and backward tables that both
-- use 'L.sum' as the accumulator.
marginals :: Md.Model -> DAG a X -> DAG a [(CbIx, L.LogFloat)]
marginals crf dag = DAG.mapE scored dag
  where
    fwd = forward L.sum crf dag
    bwd = backward L.sum crf dag
    scoreOf = edgeProb1 L.sum dag fwd bwd
    scored eid _ =
      map (\ix -> (Ft.lbIx ix, scoreOf ix)) (Ft.edgeIxs dag eid)
-- | Like 'marginals', but with label indices resolved to the actual labels
-- (see 'mergeProbs').
marginals' :: Md.Model -> DAG a X -> DAG a [(Cb, L.LogFloat)]
marginals' crf dag = mergeProbs dag indexed
  where
    indexed = marginals crf dag
-- | Selects the accumulator used by 'probs'.
data ProbType
  = Marginals
    -- ^ Accumulate with 'L.sum'.
  | MaxProbs
    -- ^ Accumulate with 'maximum'.
-- | For every edge of the DAG, the list of its label indices paired with
-- their probabilities.  The forward table, the backward table and the final
-- per-edge combination all use the accumulator selected by the 'ProbType'.
probs :: ProbType -> Md.Model -> DAG a X -> DAG a [(CbIx, L.LogFloat)]
probs probTyp crf dag = DAG.mapE scored dag
  where
    acc :: AccF
    acc = case probTyp of
      Marginals -> L.sum
      MaxProbs -> maximum
    fwd = forward acc crf dag
    bwd = backward acc crf dag
    scoreOf = edgeProb1 acc dag fwd bwd
    scored eid _ =
      map (\ix -> (Ft.lbIx ix, scoreOf ix)) (Ft.edgeIxs dag eid)
-- | Like 'probs', but with label indices resolved to the actual labels (see
-- 'mergeProbs').
probs' :: ProbType -> Md.Model -> DAG a X -> DAG a [(Cb, L.LogFloat)]
probs' typ crf dag = mergeProbs dag indexed
  where
    indexed = probs typ crf dag
-- | Pair every edge of the first DAG with the corresponding edge of the
-- second one and replace each label index by the label that 'C.lbAt' yields
-- for the observation on that edge.  Probabilities are left untouched.
mergeProbs :: DAG a X -> DAG a [(CbIx, L.LogFloat)] -> DAG a [(Cb, L.LogFloat)]
mergeProbs dag scored = fmap resolve (DAG.zipE dag scored)
  where
    resolve (x, ys) = map (\(cbIx, pr) -> (C.lbAt x cbIx, pr)) ys
-- | For every edge, the @k@ label indices with the highest marginal
-- probabilities, in decreasing order of probability.
tagK :: Int -> Md.Model -> DAG a X -> DAG a [(CbIx, L.LogFloat)]
tagK k crf dag = fmap topK (marginals crf dag)
  where
    -- Ascending (stable) sort, then reverse, then cut.
    topK scored = take k (reverse (List.sortBy (compare `on` snd) scored))
-- | For every edge, the label index with the highest marginal probability
-- (the single element returned by @'tagK' 1@).
--
-- Raises an error if some edge has no label candidates at all, since no
-- label index can then be chosen for it.
tag :: Md.Model -> DAG a X -> DAG a CbIx
tag crf = fmap top . tagK 1 crf
  where
    top ((cbIx, _) : _) = cbIx
    -- Explicit failure instead of the anonymous "empty list" error of `head`.
    top [] = error "DAG.Inference.tag: edge with no label candidates"
-- | Like 'tag', but with the selected label index resolved to the actual
-- label via 'C.lbAt' applied to the observation on the same edge.
tag' :: Md.Model -> DAG a X -> DAG a Cb
tag' crf dag = uncurry C.lbAt <$> DAG.zipE dag (tag crf dag)
-- | Features occurring on the given edge, each paired with the probability
-- of the configuration it occurs in:
--
--   * observation features of every edge index on the edge, weighted with
--     the single-index probability ('edgeProb1' with 'L.sum');
--
--   * transition features of every triple formed by an edge index on the
--     edge, one of its predecessors and one of that predecessor's
--     predecessors, weighted with the triple probability ('edgeProb3').
--
-- The observation features come first in the result.
expectedFeaturesOn
  :: Md.Model
  -> DAG a X
  -> ProbArray
  -- ^ Forward table
  -> ProbArray
  -- ^ Backward table
  -> EdgeID
  -> [(C.Feat, L.LogFloat)]
expectedFeaturesOn crf dag alpha beta eid =
  obsExpect ++ trExpect
  where
    psi = memoEdgeIx dag $ onWord crf dag
    prob1 = edgeProb1 L.sum dag alpha beta
    prob3 = edgeProb3 crf dag psi alpha beta
    -- Edge indices on the edge in question.
    curIxs = Ft.edgeIxs dag eid
    obsExpect =
      [ (ft, prob)
      | u <- curIxs
      , let prob = prob1 u
      , ft <- Ft.obFeatsOn dag u ]
    -- `u` stays a plain index and is wrapped in `Just` only where an optional
    -- index is expected, so no partial unwrapping is needed.
    trExpect =
      [ (ft, prob)
      | u <- curIxs
      , v <- Ft.prevEdgeIxs dag (Just (Ft.edgeID u))
      , w <- Ft.prevEdgeIxs dag (Ft.edgeID <$> v)
      , let prob = prob3 u v w
      , ft <- Ft.trFeatsOn dag (Just u) v w ]
-- | Features of the whole DAG paired with their probabilities: the
-- concatenation of 'expectedFeaturesOn' over all edges of the DAG.
--
-- Before the result is produced, the normalisation factor of the forward
-- table is sparked for parallel evaluation while the one of the backward
-- table is evaluated, so that both tables get forced up front.
expectedFeaturesIn
  :: Md.Model
  -> DAG a X
  -> [(C.Feat, L.LogFloat)]
expectedFeaturesIn crf dag =
  Par.par zxF (Par.pseq zxB (Par.pseq zxF perEdge))
  where
    alpha = forward L.sum crf dag
    beta = backward L.sum crf dag
    zxF = zxAlpha alpha
    zxB = zxBeta beta
    perEdge =
      concatMap (expectedFeaturesOn crf dag alpha beta) (DAG.dagEdges dag)
-- | Count the edges on which the model agrees (first component) and
-- disagrees (second component) with the reference labelling.
--
-- For both the reference and the model output, an edge is reduced to the set
-- of its labels whose probability is within @eps@ of the highest one, or to
-- 'Nothing' when it has no labels or the highest probability is below @eps@.
-- The two sides agree when both are 'Nothing' or when the two sets share at
-- least one label.
goodAndBad :: Md.Model -> DAG a (X, Y) -> (Int, Int)
goodAndBad crf dag =
  F.foldl' tally (0, 0) (DAG.zipE gold predicted)
  where
    tally (good, bad) pair
      | agree pair = (good + 1, bad)
      | otherwise = (good, bad + 1)
    agree (Just xs, Just ys) = not (S.null (S.intersection xs ys))
    agree (Nothing, Nothing) = True
    agree _ = False
    -- Model output, obtained with the 'MaxProbs' variant of 'probs''.
    predicted = fmap best $ probs' MaxProbs crf (fmap fst dag)
    -- Reference labelling stored in the data.
    gold = fmap (best . C.unY) (fmap snd dag)
    best zs
      | null zs = Nothing
      | maxProb < eps = Nothing
      | otherwise =
          Just
            . S.fromList
            . map fst
            . filter ((>= maxProb - eps) . snd)
            $ zs
      where
        maxProb = maximum (map snd zs)
    eps = 1.0e-9
-- | Component-wise sum of 'goodAndBad' over all elements of the dataset.
goodAndBad' :: Md.Model -> [DAG a (X, Y)] -> (Int, Int)
goodAndBad' crf = F.foldl' plus (0, 0) . map (goodAndBad crf)
  where
    plus (g, b) (g', b') = (g + g', b + b')
-- | Fraction of edges on which the model agrees with the reference
-- labelling.  The dataset is split into 'numCapabilities' parts with
-- 'partition' and the parts are processed with 'Par.parMap'.
--
-- NOTE(review): when there is nothing to count (e.g. an empty dataset) the
-- result is @0 / 0@, i.e. NaN.
accuracy :: Md.Model -> [DAG a (X, Y)] -> Double
accuracy crf dataset =
  fromIntegral good / fromIntegral (good + bad)
  where
    chunks = partition numCapabilities dataset
    counts = Par.parMap Par.rseq (goodAndBad' crf) chunks
    (good, bad) = F.foldl' plus (0, 0) counts
    plus (g, b) (g', b') = (g + g', b + b')