{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
module Data.CRF.Chain1.Constrained.DAG.Probs
( probability
, likelihood
, parLikelihood
) where
import Control.Applicative ((<$>))
import Data.Maybe (catMaybes)
import Data.List (maximumBy, sort, sortBy)
import Data.Function (on)
import qualified Data.Set as S
import qualified Data.Array as A
import qualified Data.Vector.Unboxed as U
import qualified Data.Foldable as F
import Control.Parallel.Strategies (rseq, parMap)
import Control.Parallel (par, pseq)
import GHC.Conc (numCapabilities)
import qualified Data.Number.LogFloat as L
import Data.DAG (EdgeID, DAG)
import qualified Data.DAG as DAG
import qualified Data.CRF.Chain1.Constrained.DP as DP
import Data.CRF.Chain1.Constrained.Util (partition)
import qualified Data.CRF.Chain1.Constrained.Model as Md
import Data.CRF.Chain1.Constrained.Core (X, Y, Lb, AVec)
import qualified Data.CRF.Chain1.Constrained.Core as C
import qualified Data.CRF.Chain1.Constrained.Intersect as I
import Data.CRF.Chain1.Constrained.DAG.Feature (featuresIn)
import qualified Data.CRF.Chain1.Constrained.DAG.Inference as Inf
-- | Position of a label within the label vector attached to a single edge
-- (this is the index handed to 'lbOn' and produced by 'lbIxs').
type LbIx = Int
-- | Table of 'L.LogFloat' values addressed by an edge identifier and a label
-- position on that edge; this is the type returned by 'forward'.
type ProbArray = EdgeID -> LbIx -> L.LogFloat
-- | Vector of labels (each paired with its 'Double' value) stored in the 'Y'
-- component of the label of the given edge.
--
-- The model argument is accepted but not consulted here.
lbVec :: Md.Model -> DAG a (X, Y) -> EdgeID -> AVec (Lb, Double)
lbVec _crf dag edgeID = C._unY . snd $ DAG.edgeLabel edgeID dag
{-# INLINE lbVec #-}
-- | Number of entries in the label vector ('lbVec') of the given edge.
lbNum :: Md.Model -> DAG a (X, Y) -> EdgeID -> Int
lbNum crf dag edgeID = U.length (C.unAVec (lbVec crf dag edgeID))
{-# INLINE lbNum #-}
-- | Entry of the label vector ('lbVec') of the given edge found at the given
-- position.  The lookup is performed with 'U.!'.
lbOn :: Md.Model -> DAG a (X, Y) -> EdgeID -> LbIx -> (Lb, Double)
lbOn crf dag edgeID lbIx = C.unAVec (lbVec crf dag edgeID) U.! lbIx
{-# INLINE lbOn #-}
-- | All entries of the label vector ('lbVec') of the given edge, each paired
-- with its position in that vector (counting from 0).
lbIxs :: Md.Model -> DAG a (X, Y) -> EdgeID -> [(LbIx, (Lb, Double))]
lbIxs crf dag edgeID = zip [0 ..] (U.toList labels)
  where
    labels = C.unAVec (lbVec crf dag edgeID)
{-# INLINE lbIxs #-}
-- | Per-label factor table for a single edge.
--
-- For every label position on the edge the result is the product of
--
--   * the 'Double' stored with that label (converted with 'L.logFloat'), and
--
--   * @'Md.valueL' crf featIx@ for every @(lbIx, featIx)@ pair obtained by
--     intersecting @'Md.obIxs' crf ob@ with the labels of the edge, for each
--     observation @ob@ of the edge.
--
-- The function is deliberately defined with three arguments: applying it to
-- the model, the DAG and an edge builds the table once, and the resulting
-- closure only performs array lookups.
computePsi :: Md.Model -> DAG a (X, Y) -> EdgeID -> LbIx -> L.LogFloat
computePsi crf dag edgeID = (psiTable A.!)
  where
    -- Every slot starts at 1 and is multiplied by each factor addressed to it.
    psiTable =
      A.accumArray (*) 1 (0, lbNum crf dag edgeID - 1)
        (labelFactors ++ obsFactors)

    -- Factors taken from the values stored alongside the labels.
    labelFactors =
      [ (lbIx, L.logFloat weight)
      | (lbIx, (_, weight)) <- lbIxs crf dag edgeID ]

    -- Factors contributed by the observations of the edge.
    obsFactors = do
      ob <- C.unX (fst (DAG.edgeLabel edgeID dag))
      (lbIx, featIx) <-
        I.intersect (Md.obIxs crf ob) (xify (lbVec crf dag edgeID))
      return (lbIx, Md.valueL crf featIx)
-- | Forward table over the edges of the DAG, memoised with 'DP.flexible2'.
--
-- The table is indexed by edge and label position.  One extra edge index,
-- @'DAG.maxEdge' dag + 1@, is added past the real edges; its single slot
-- (position 0) holds the sum of the table over all labels of all final
-- edges.
forward :: Md.Model -> DAG a (X, Y) -> ProbArray
forward crf dag = table
  where
    table = DP.flexible2 edgeRange labelRange $ \memo i ->
      withMem (computePsi crf dag i) memo i

    -- Range of the first index: every edge plus the extra sentinel position.
    edgeRange = (DAG.minEdge dag, DAG.maxEdge dag + 1)
    sentinel = snd edgeRange

    -- Range of the second index for a given first index.
    labelRange i
      | i == sentinel = (0, 0)
      | otherwise = (0, lbNum crf dag i - 1)

    -- Edges for which 'DAG.isInitialEdge' holds.
    initialEdges =
      S.fromList (filter (\e -> DAG.isInitialEdge e dag) (DAG.dagEdges dag))

    -- Value of the table at edge @i@, expressed through the memoised table
    -- @memo@ and the factor function @psi@ of that edge.
    withMem psi memo i
      | i == sentinel = const finalTotal
      | S.member i initialEdges = \j ->
          psi j * Md.sgValue crf (labelAt j)
      | otherwise = \j ->
          let x = labelAt j
          -- NOTE(review): 'excluded' sums a selection of the terms summed by
          -- 'total', so the difference is presumably non-negative -- confirm
          -- that this also holds under floating-point rounding.
          in  psi j * ((total - excluded x) + weighted x)
      where
        labelAt = fst . lbOn crf dag i
        predecessors = DAG.prevEdges i dag

        -- Sum of the table over every label of every preceding edge.
        total = safeSum
          [ memo p k
          | p <- predecessors
          , (k, _) <- lbIxs crf dag p ]

        -- Table values of the predecessor labels selected by intersecting
        -- @'Md.prevIxs' crf x@ with the labels of each preceding edge, paired
        -- with the feature index reported by the intersection.
        linked x =
          [ (memo p k, featIx)
          | p <- predecessors
          , (k, featIx) <-
              I.intersect (Md.prevIxs crf x) (xify (lbVec crf dag p)) ]

        -- Plain sum of the selected table values.
        excluded x = safeSum (map fst (linked x))

        -- The same values, each multiplied by @'Md.valueL' crf featIx@.
        weighted x = safeSum
          [ a * Md.valueL crf featIx
          | (a, featIx) <- linked x ]

        -- Sum of the table over every label of every final edge.
        finalTotal = safeSum
          [ memo e k
          | e <- DAG.dagEdges dag
          , DAG.isFinalEdge e dag
          , (k, _) <- lbIxs crf dag e ]
-- | Value of the 'forward' table at the sentinel position
-- (@'DAG.maxEdge' dag + 1@, label position 0) divided by @'Inf.zx'@ computed
-- on the same DAG with only the 'X' component of every edge label kept.
probability :: Md.Model -> DAG a (X, Y) -> L.LogFloat
probability crf dag = numerator / Inf.zx crf (fmap fst dag)
  where
    numerator = forward crf dag (DAG.maxEdge dag + 1) 0
-- | Product of 'probability' over the whole dataset, evaluated in parallel.
--
-- The dataset is handed to 'partition' together with 'numCapabilities'
-- (presumably yielding that many chunks -- see 'partition'), 'likelihood' is
-- computed for every chunk with @'parMap' 'rseq'@, and the per-chunk results
-- are multiplied together.
parLikelihood :: Md.Model -> [DAG a (X, Y)] -> L.LogFloat
parLikelihood crf dataset = L.product chunkProbs
  where
    chunks = partition numCapabilities dataset
    chunkProbs = parMap rseq (likelihood crf) chunks
-- | Product of 'probability' over every DAG of the dataset (sequential).
likelihood :: Md.Model -> [DAG a (X, Y)] -> L.LogFloat
likelihood crf dataset = L.product [ probability crf dag | dag <- dataset ]
-- | Keep only the first component of every pair stored in the vector.
xify :: (U.Unbox x, U.Unbox y) => C.AVec (x, y) -> C.AVec x
xify pairs = C.AVec (U.map fst (C.unAVec pairs))
{-# INLINE xify #-}
-- | Sum of a list of 'L.LogFloat' values.  The empty list yields 0 directly,
-- without going through 'L.sum'.
safeSum :: [L.LogFloat] -> L.LogFloat
safeSum xs
  | null xs = 0
  | otherwise = L.sum xs
{-# INLINE safeSum #-}