{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Rank2Types #-}


module Data.CRF.Chain2.Tiers.DAG.Probs
( probability
, likelihood
, parLikelihood
) where


import           GHC.Conc (numCapabilities)

import           Control.Applicative ((<$>))
import qualified Control.Arrow as Arr
import qualified Control.Parallel as Par
import qualified Control.Parallel.Strategies as Par

import qualified Data.Number.LogFloat as L
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.Array as A
import qualified Data.Set as S
import           Data.Maybe (fromJust, maybeToList)
import qualified Data.MemoCombinators as Memo
import qualified Data.List as List
import           Data.Function (on)
import qualified Data.Foldable as F

import           Data.DAG (EdgeID, DAG)
import qualified Data.DAG as DAG

import qualified Data.CRF.Chain2.Tiers.Core as C
import           Data.CRF.Chain2.Tiers.Core (X, Y, Ob, Cb, CbIx)
import qualified Data.CRF.Chain2.Tiers.Model as Md
import           Data.CRF.Chain2.Tiers.Util (partition)
import           Data.CRF.Chain2.Tiers.DAG.Feature (EdgeIx(..))
import qualified Data.CRF.Chain2.Tiers.DAG.Feature as Ft

import           Data.CRF.Chain2.Tiers.DAG.Inference
                 (AccF, Pos(..), simplify, complicate, ProbArray)
import qualified Data.CRF.Chain2.Tiers.DAG.Inference as I

import Debug.Trace (trace)


--------------------------------------------
-- Indexing
---------------------------------------------


-- | Observations attached to the edge @i@ of the sentence DAG, i.e. the
-- contents of the 'X' component of that edge's label.
obList :: DAG a (X, Y) -> EdgeID -> [Ob]
obList dag i =
  let (x, _) = DAG.edgeLabel i dag
  in  C.unX x
{-# INLINE obList #-}


-- | Raw vector stored in the 'Y' component of the label of edge @i@: each
-- entry pairs a potential label with a 'Double' (which 'lbOn' interprets as a
-- log-domain probability).
lbVec :: DAG a (X, Y) -> EdgeID -> V.Vector (Cb, Double)
lbVec dag i = C._unY . snd $ DAG.edgeLabel i dag
{-# INLINE lbVec #-}


-- | How many potential labels are stored on the edge @i@ (the length of
-- 'lbVec' for that edge).
lbNum :: DAG a (X, Y) -> EdgeID -> Int
lbNum dag i = V.length $ lbVec dag i
{-# INLINE lbNum #-}


-- | Look up the label stored at index @k@ on edge @i@, together with its
-- probability converted to 'L.LogFloat' via 'L.logToLogFloat'. Yields
-- 'Nothing' when @k@ is outside the bounds of the edge's label vector.
lbOn :: DAG a (X, Y) -> EdgeID -> CbIx -> Maybe (Cb, L.LogFloat)
lbOn dag i k = toLogFloat <$> (lbVec dag i V.!? k)
  where
    -- Only the second component (the stored 'Double') is converted.
    toLogFloat = Arr.second L.logToLogFloat
{-# INLINE lbOn #-}


-- | All valid label indices on the edge @i@, in ascending order starting
-- from 0. Empty when the edge carries no labels.
lbIxs :: DAG a (X, Y) -> EdgeID -> [CbIx]
lbIxs dag i = take (lbNum dag i) [0 ..]
{-# INLINE lbIxs #-}


-- | Every 'EdgeIx' that can be formed on the edge @i@: one per label index
-- of that edge.
edgeIxs :: DAG a (X, Y) -> EdgeID -> [EdgeIx]
edgeIxs dag i = map onThisEdge (lbIxs dag i)
  where
    onThisEdge u = EdgeIx {edgeID = i, lbIx = u}


--------------------------------------------
-- Indexing advanced
---------------------------------------------


-- | The 'EdgeIx's lying on the edges which directly precede the given edge.
--
-- When no edge is given ('Nothing'), or when the given edge has no preceding
-- edges, the result is the singleton @[Nothing]@, where 'Nothing' stands for a
-- special out-of-bounds 'EdgeIx'. Otherwise every label index of every
-- preceding edge is returned, wrapped in 'Just'.
prevEdgeIxs :: DAG a (X, Y) -> Maybe EdgeID -> [Maybe EdgeIx]
prevEdgeIxs dag mayEdge =
  case maybe [] (\i -> DAG.prevEdges i dag) mayEdge of
    [] -> [Nothing]
    preds ->
      [ Just (EdgeIx {edgeID = j, lbIx = u})
      | j <- preds
      , u <- lbIxs dag j ]


-- | All 'EdgeIx's located on those edges of the DAG for which
-- 'DAG.isFinalEdge' holds.
finalEdgeIxs :: DAG a (X, Y) -> [EdgeIx]
finalEdgeIxs dag =
  concatMap (edgeIxs dag) (filter isFinal (DAG.dagEdges dag))
  where
    isFinal i = DAG.isFinalEdge i dag


---------------------------------------------
-- Feature alternatives
---------------------------------------------


-- | Observation features for the label identified by the given 'EdgeIx',
-- collected over all observations of the corresponding edge. The result is
-- empty when the label index does not exist on that edge.
obFeatsOn :: DAG a (X, Y) -> EdgeIx -> [C.Feat]
obFeatsOn dag ix =
  case lbOn dag (edgeID ix) (lbIx ix) of
    Nothing      -> []
    -- The probability attached to the label plays no role here.
    Just (cb, _) -> concatMap (\o -> C.obFeats o cb) (obList dag (edgeID ix))
{-# INLINE obFeatsOn #-}


-- | Probability attached to the label identified by the given 'EdgeIx'.
--
-- WARNING: when the index does not point to any label in the DAG the result is
-- 0, and other code relies on this behavior!
probOn :: DAG a (X, Y) -> EdgeIx -> L.LogFloat
probOn dag ix = maybe 0 snd (lbOn dag (edgeID ix) (lbIx ix))
{-# INLINE probOn #-}


-- | Transition features for the given position, determined by (the indices
-- of) the current label and up to two preceding labels.
--
-- The order of the features depends on how many of the leading labels can be
-- resolved: all three give 'C.trFeats3', the first two give 'C.trFeats2', the
-- first alone gives 'C.trFeats1'. If the current label cannot be resolved,
-- there are no features at all.
--
-- TODO: almost the same as `Feature.trFeatsOn`.
trFeatsOn
  :: DAG a (X, Y)
  -> Maybe EdgeIx -- ^ Current EdgeIx
  -> Maybe EdgeIx -- ^ Previous EdgeIx
  -> Maybe EdgeIx -- ^ One before the previous EdgeIx
  -> [C.Feat]
trFeatsOn dag u' v' w' =
  case (labelAt u', labelAt v', labelAt w') of
    (Nothing, _      , _      ) -> []
    (Just u , Nothing, _      ) -> C.trFeats1 u
    (Just u , Just v , Nothing) -> C.trFeats2 u v
    (Just u , Just v , Just w ) -> C.trFeats3 u v w
  where
    -- Resolve an optional index to the label it points to; both a missing
    -- index and an index absent from the DAG give 'Nothing'.
    labelAt mayIx = do
      ix <- mayIx
      fst <$> lbOn dag (edgeID ix) (lbIx ix)
{-# INLINE trFeatsOn #-}


----------------------------------------------------
-- Potential
----------------------------------------------------


-- | Observation potential of the label identified by the given 'EdgeIx':
-- the product of the model potentials ('Md.phi') of all its observation
-- features, multiplied by the probability of the label itself ('probOn').
onWord :: Md.Model -> DAG a (X, Y) -> EdgeIx -> L.LogFloat
onWord crf dag ix = probOn dag ix * featPotential
  where
    featPotential = L.product [ Md.phi crf ft | ft <- obFeatsOn dag ix ]
{-# INLINE onWord #-}


-- | Transition potential for the given labels (identified by indices): the
-- product of the model potentials ('Md.phi') of the transition features
-- produced by 'trFeatsOn'. The arguments are passed to 'trFeatsOn' in the
-- same order: current, previous, and the one before the previous.
onTransition
  :: Md.Model
  -> DAG a (X, Y)
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> Maybe EdgeIx
  -> L.LogFloat
onTransition crf dag cur prev prev2 =
  L.product [ Md.phi crf ft | ft <- trFeatsOn dag cur prev prev2 ]
{-# INLINE onTransition #-}


---------------------------------------------
-- A bit more complex stuff
---------------------------------------------


-- | Compute the forward table for the given model and sentence DAG.
--
-- The accumulation function decides how the alternative ways of reaching a
-- given pair of positions are combined (e.g. 'probability' passes 'L.sum').
forward :: AccF -> Md.Model -> DAG a (X, Y) -> ProbArray
forward acc crf dag = table
  where
    -- The memoized table and the recurrence that defines it refer to each
    -- other, so recursive look-ups go through the memoized version.
    table = I.memoProbArray dag step

    -- The recurrence proper; the first two equations are the boundary cases.
    step Beg Beg = 1.0
    step End End = acc $ map finalTerm (finalEdgeIxs dag)
    step cur prev = acc $ map (innerTerm cur prev) (beforePositions prev)

    -- Summand for finishing the sentence on the final index `w`. Below,
    -- onTransition equals currently to 1; in general, there could be some
    -- related transition features, though.
    finalTerm w =
      table End (Mid w)
        * onTransition crf dag Nothing Nothing (Just w)

    -- Summand for being at `cur`, preceded by `prev`, preceded by `prev2`.
    innerTerm cur prev prev2 =
      table prev prev2 * wordPot cur
        * onTransition crf dag (simplify cur) (simplify prev) (simplify prev2)

    -- Positions that may directly precede `prev`.
    beforePositions prev =
      complicate Beg <$> prevEdgeIxs dag (edgeID <$> simplify prev)

    -- Observation potential; positions other than `Mid` contribute 1.
    wordPot (Mid x) = edgePot x
    wordPot _       = 1.0

    -- Memoized version of 'onWord'.
    edgePot = I.memoEdgeIx dag $ onWord crf dag


-- | Probability of the given (labeled) DAG in the given model: the value of
-- the forward table at the final position, divided by the normalization
-- factor which 'I.zx' computes for the DAG stripped of its 'Y' components.
probability :: Md.Model -> DAG a (X, Y) -> L.LogFloat
probability crf dag = numerator / denominator
  where
    numerator   = forward L.sum crf dag End End
    denominator = I.zx crf (fst <$> dag)


-- | Likelihood of the given dataset, i.e. the product of the probabilities
-- of the individual DAGs, represented as a 'L.LogFloat' (no
-- parallelization).
likelihood :: Md.Model -> [DAG a (X, Y)] -> L.LogFloat
likelihood crf dags = L.product [ probability crf dag | dag <- dags ]


-- | Likelihood of the given dataset, represented as a 'L.LogFloat'
-- (parallelized version).
--
-- The dataset is divided with 'partition' according to 'numCapabilities', the
-- likelihood of every part is computed in parallel, and the partial results
-- are multiplied together.
parLikelihood :: Md.Model -> [DAG a (X, Y)] -> L.LogFloat
parLikelihood crf dataset = L.product partProbs
  where
    parts     = partition numCapabilities dataset
    partProbs = Par.parMap Par.rseq (likelihood crf) parts