{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Rank2Types #-}
module Data.CRF.Chain2.Tiers.DAG.Probs
( probability
, likelihood
, parLikelihood
) where
import GHC.Conc (numCapabilities)
import Control.Applicative ((<$>))
import qualified Control.Arrow as Arr
import qualified Control.Parallel as Par
import qualified Control.Parallel.Strategies as Par
import qualified Data.Number.LogFloat as L
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.Array as A
import qualified Data.Set as S
import Data.Maybe (fromJust, maybeToList)
import qualified Data.MemoCombinators as Memo
import qualified Data.List as List
import Data.Function (on)
import qualified Data.Foldable as F
import Data.DAG (EdgeID, DAG)
import qualified Data.DAG as DAG
import qualified Data.CRF.Chain2.Tiers.Core as C
import Data.CRF.Chain2.Tiers.Core (X, Y, Ob, Cb, CbIx)
import qualified Data.CRF.Chain2.Tiers.Model as Md
import Data.CRF.Chain2.Tiers.Util (partition)
import Data.CRF.Chain2.Tiers.DAG.Feature (EdgeIx(..))
import qualified Data.CRF.Chain2.Tiers.DAG.Feature as Ft
import Data.CRF.Chain2.Tiers.DAG.Inference
(AccF, Pos(..), simplify, complicate, ProbArray)
import qualified Data.CRF.Chain2.Tiers.DAG.Inference as I
import Debug.Trace (trace)
-- | Observations attached to an edge: the 'X' half of the edge label,
-- unpacked with 'C.unX'.
obList :: DAG a (X, Y) -> EdgeID -> [Ob]
obList dag eid =
  let (x, _) = DAG.edgeLabel eid dag
  in  C.unX x
{-# INLINE obList #-}
-- | Vector of @(label, weight)@ pairs stored in the 'Y' half of the label
-- of the given edge (the raw contents exposed by 'C._unY').
lbVec :: DAG a (X, Y) -> EdgeID -> V.Vector (Cb, Double)
lbVec dag eid = C._unY . snd $ DAG.edgeLabel eid dag
{-# INLINE lbVec #-}
-- | Number of entries in the label vector ('lbVec') of the given edge.
lbNum :: DAG a (X, Y) -> EdgeID -> Int
lbNum dag eid = V.length (lbVec dag eid)
{-# INLINE lbNum #-}
-- | Label stored at the given index of the given edge, together with its
-- weight converted by 'L.logToLogFloat'.  Yields 'Nothing' when the index
-- is outside the edge's label vector.
--
-- NOTE(review): the conversion implies the stored 'Double' is in the
-- log domain -- confirm against the definition of 'Y'.
lbOn :: DAG a (X, Y) -> EdgeID -> CbIx -> Maybe (Cb, L.LogFloat)
lbOn dag eid ix =
  case lbVec dag eid V.!? ix of
    Nothing    -> Nothing
    Just entry -> Just (Arr.second L.logToLogFloat entry)
{-# INLINE lbOn #-}
-- | All valid label indices of the given edge, in increasing order
-- starting from @0@.  Empty when the edge carries no labels.
lbIxs :: DAG a (X, Y) -> EdgeID -> [CbIx]
lbIxs dag eid = enumFromTo 0 (count - 1)
  where
    count = lbNum dag eid
{-# INLINE lbIxs #-}
-- | Every @(edge, label index)@ combination available on the given edge,
-- one 'EdgeIx' per index returned by 'lbIxs'.
edgeIxs :: DAG a (X, Y) -> EdgeID -> [EdgeIx]
edgeIxs dag eid = map atIndex (lbIxs dag eid)
  where
    atIndex k = EdgeIx { edgeID = eid, lbIx = k }
-- | Edge/label indices on the edges preceding the given edge.
--
-- The result is the singleton @[Nothing]@ when no edge is given, or when
-- the given edge has no preceding edges.  Otherwise every index of every
-- preceding edge is returned, wrapped in 'Just'.
prevEdgeIxs :: DAG a (X, Y) -> Maybe EdgeID -> [Maybe EdgeIx]
prevEdgeIxs dag mayEdge =
  case maybe [] (\i -> DAG.prevEdges i dag) mayEdge of
    []    -> [Nothing]
    preds -> map Just (concatMap (edgeIxs dag) preds)
-- | Edge/label indices of all edges of the DAG for which
-- 'DAG.isFinalEdge' holds.
finalEdgeIxs :: DAG a (X, Y) -> [EdgeIx]
finalEdgeIxs dag =
  concatMap (edgeIxs dag) (filter isFinal (DAG.dagEdges dag))
  where
    isFinal i = DAG.isFinalEdge i dag
-- | Observation features for the label selected by the given 'EdgeIx':
-- 'C.obFeats' is applied to each observation of the edge, paired with that
-- label.  The result is empty when the label index does not exist on the
-- edge.
obFeatsOn :: DAG a (X, Y) -> EdgeIx -> [C.Feat]
obFeatsOn dag (EdgeIx { edgeID = eid, lbIx = ix }) =
  case lbOn dag eid ix of
    Nothing      -> []
    Just (cb, _) -> concatMap (\ob -> C.obFeats ob cb) (obList dag eid)
{-# INLINE obFeatsOn #-}
-- | Weight attached to the label selected by the given 'EdgeIx', as
-- returned by 'lbOn'; @0@ when the label index does not exist on the edge.
probOn :: DAG a (X, Y) -> EdgeIx -> L.LogFloat
probOn dag (EdgeIx { edgeID = eid, lbIx = ix }) =
  case lbOn dag eid ix of
    Just (_, p) -> p
    Nothing     -> 0
{-# INLINE probOn #-}
-- | Transition features for up to three edge/label indices.
--
-- Each argument is first resolved to a label (via 'lbOn'); an absent
-- argument or an out-of-range label index resolves to 'Nothing'.  The
-- features are then chosen by how many leading arguments resolved:
--
--   * all three: 'C.trFeats3';
--   * the first two only: 'C.trFeats2';
--   * the first only: 'C.trFeats1';
--   * the first did not resolve: no features.
trFeatsOn
  :: DAG a (X, Y)
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> [C.Feat]
trFeatsOn dag u' v' w' =
  case (labelAt u', labelAt v', labelAt w') of
    (Just a, Just b, Just c)  -> C.trFeats3 a b c
    (Just a, Just b, Nothing) -> C.trFeats2 a b
    (Just a, Nothing, _)      -> C.trFeats1 a
    (Nothing, _, _)           -> []
  where
    -- Label under an optional edge/label index, if it exists.
    labelAt mayIx = mayIx >>= labelOf
    labelOf (EdgeIx { edgeID = eid, lbIx = ix }) = fmap fst (lbOn dag eid ix)
{-# INLINE trFeatsOn #-}
-- | Potential of a single edge/label index: the label's own weight
-- ('probOn') multiplied by the product of 'Md.phi' over all of its
-- observation features ('obFeatsOn').
onWord :: Md.Model -> DAG a (X, Y) -> EdgeIx -> L.LogFloat
onWord crf dag ix = labelWeight * featPotential
  where
    labelWeight   = probOn dag ix
    featPotential = L.product [ Md.phi crf ft | ft <- obFeatsOn dag ix ]
{-# INLINE onWord #-}
-- | Potential of a transition: the product of 'Md.phi' over the transition
-- features that 'trFeatsOn' yields for the three (optional) edge/label
-- indices, passed through in the order given.
onTransition
  :: Md.Model
  -> DAG a (X, Y)
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> L.LogFloat
onTransition crf dag a b c =
  L.product [ Md.phi crf ft | ft <- trFeatsOn dag a b c ]
{-# INLINE onTransition #-}
-- | Forward table over pairs of positions, memoised with
-- 'I.memoProbArray'.
--
-- The accumulating function decides how the alternative continuations are
-- combined (e.g. 'L.sum', as used by 'probability').
--
--   * @(Beg, Beg)@ is the base case with value 1.
--   * @(End, End)@ accumulates, over all final edge/label indices @w@, the
--     table value at @(End, Mid w)@ times the corresponding transition
--     potential.
--   * Any other @(p, q)@ accumulates, over the positions @r@ drawn from the
--     edges preceding @q@, the table value at @(q, r)@ times the word
--     potential of @p@ and the transition potential of @(p, q, r)@.
forward :: AccF -> Md.Model -> DAG a (X, Y) -> ProbArray
forward acc crf dag = table
  where
    table = I.memoProbArray dag cell

    cell Beg Beg = 1.0
    cell End End = acc (map closing (finalEdgeIxs dag))
    cell p q     = acc (map (extend p q) (before q))

    -- Contribution of one final edge/label index to the (End, End) cell.
    closing ix =
      table End (Mid ix)
        * onTransition crf dag Nothing Nothing (Just ix)

    -- Contribution of one earlier position @r@ to the cell @(p, q)@.
    extend p q r =
      table q r
        * wordPotential p
        * onTransition crf dag (simplify p) (simplify q) (simplify r)

    -- Positions taken from the edges preceding @q@; a missing predecessor
    -- is turned into 'Beg' by 'complicate'.
    before q =
      map (complicate Beg) (prevEdgeIxs dag (fmap edgeID (simplify q)))

    -- Word-level potential; positions other than 'Mid' contribute 1.
    wordPotential (Mid ix) = wordMemo ix
    wordPotential _        = 1.0

    wordMemo = I.memoEdgeIx dag (onWord crf dag)
-- | Ratio of two quantities computed for the given DAG: the 'forward'
-- table value at @(End, End)@, accumulated with 'L.sum' over the labelled
-- DAG, divided by 'I.zx' applied to the same DAG with the 'Y' components
-- dropped.
--
-- NOTE(review): 'I.zx' is presumably the normalisation factor of the
-- unconstrained model -- confirm in the Inference module.
probability :: Md.Model -> DAG a (X, Y) -> L.LogFloat
probability crf dag =
  let numerator   = forward L.sum crf dag End End
      denominator = I.zx crf (fmap fst dag)
  in  numerator / denominator
-- | Product of 'probability' over every DAG in the dataset.
likelihood :: Md.Model -> [DAG a (X, Y)] -> L.LogFloat
likelihood crf dags = L.product [ probability crf d | d <- dags ]
-- | Variant of 'likelihood' that evaluates chunks in parallel: the dataset
-- is split with 'partition' using 'numCapabilities' as the count, the
-- per-chunk likelihoods are computed with 'Par.parMap', and the results
-- are multiplied together.
parLikelihood :: Md.Model -> [DAG a (X, Y)] -> L.LogFloat
parLikelihood crf dataset =
  L.product (Par.parMap Par.rseq (likelihood crf) chunks)
  where
    chunks = partition numCapabilities dataset