{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}


-- | Inference with CRFs.


module Data.CRF.Chain1.Constrained.DAG.Inference
( tag
, tagK
, marginals
, accuracy
, expectedFeaturesIn
, zx
, zx'

-- , probability
-- , likelihood

-- * Internals
, computePsi
) where


import Control.Applicative ((<$>))
import Data.Maybe (catMaybes)
import Data.List (maximumBy, sortBy)
import Data.Function (on)
import qualified Data.Set as S
import qualified Data.Array as A
-- import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.Foldable as F

import Control.Parallel.Strategies (rseq, parMap)
import Control.Parallel (par, pseq)
import GHC.Conc (numCapabilities)
import qualified Data.Number.LogFloat as L

-- import           Data.CRF.Chain1.Constrained.DAG.Dataset.Internal (EdgeID, DAG)
-- import qualified Data.CRF.Chain1.Constrained.DAG.Dataset.Internal as DAG
import           Data.DAG (EdgeID, DAG)
import qualified Data.DAG as DAG

import qualified Data.CRF.Chain1.Constrained.DP as DP
import           Data.CRF.Chain1.Constrained.Util (partition)
import qualified Data.CRF.Chain1.Constrained.Model as Md

import           Data.CRF.Chain1.Constrained.Core (X, Y, Lb, AVec)
import qualified Data.CRF.Chain1.Constrained.Core as C
import qualified Data.CRF.Chain1.Constrained.Intersect as I

-- import           Data.CRF.Chain1.Constrained.DAG.Feature (featuresIn)

import Debug.Trace (trace)


---------------------------------------------
-- Util Types
---------------------------------------------


-- | Label index: the position of a label within the vector of potential
-- labels of an edge (see `lbVec`, `lbOn` and `lbIxs`).
type LbIx       = Int


-- | The probability array assigns some probability to each label (represented
-- by its index) which can be assigned to a given edge (represented by its
-- `EdgeID`).  The values are log-domain numbers (`L.LogFloat`).
--
-- The tables built by `forward` and `backward` additionally contain one
-- position which does not correspond to any edge (one past the maximal edge
-- and one before the minimal edge, respectively); it has a single cell at
-- label index @0@.
type ProbArray  = EdgeID -> LbIx -> L.LogFloat


---------------------------------------------
-- Summing
---------------------------------------------


-- -- | Numerically safer summing.
-- safeSum :: (Ord a, Num a) => [a] -> a
-- -- safeSum = sum . sort
-- safeSum = sum
-- {-#INLINE safeSum #-}


-- | Sum a list of log-domain numbers.  The empty list is handled explicitly
-- and sums to @0@; any other list is delegated to `L.sum`.
safeSum :: [L.LogFloat] -> L.LogFloat
safeSum xs
  | null xs   = 0
  | otherwise = L.sum xs
{-#INLINE safeSum #-}


---------------------------------------------
-- Some basic functions.
---------------------------------------------


-- | The vector of labels which can be assigned to the given edge of the
-- sentence DAG.  An edge labelled with the `C.X` constructor takes the vector
-- `Md.r0` of the model, while an edge labelled with `C.R` carries its own
-- vector.
lbVec :: Md.Model -> DAG a X -> EdgeID -> AVec Lb
lbVec crf dag edgeID = pick (DAG.edgeLabel edgeID dag)
  where
    pick (C.X _)       = Md.r0 crf
    pick (C.R _ own)   = own
{-# INLINE lbVec #-}


-- | How many labels can be assigned to the given edge of the sentence DAG,
-- i.e. the length of the vector returned by `lbVec`.
lbNum :: Md.Model -> DAG a X -> EdgeID -> Int
lbNum crf dag edgeID = U.length (C.unAVec (lbVec crf dag edgeID))
{-# INLINE lbNum #-}


-- | The label stored at the given position of the label vector of an edge
-- value.  The vector is chosen as in `lbVec`: `Md.r0` of the model for `C.X`,
-- the vector carried by the value for `C.R`.
lbOn :: Md.Model -> X -> LbIx -> Lb
lbOn crf x = case x of
  C.X _     -> labelAt (Md.r0 crf)
  C.R _ own -> labelAt own
  where
    -- index into the unboxed vector underlying an `AVec`
    labelAt avec ix = C.unAVec avec U.! ix
{-# INLINE lbOn #-}


-- | The labels which can be assigned to the given edge (the contents of the
-- vector returned by `lbVec`), each paired with its position in that vector.
-- Positions are counted from @0@.
lbIxs :: Md.Model -> DAG a X -> EdgeID -> [(LbIx, Lb)]
lbIxs crf dag edgeID =
  zip [0..] (U.toList (C.unAVec (lbVec crf dag edgeID)))
{-# INLINE lbIxs #-}


---------------------------------------------
-- A bit more complex stuff.
---------------------------------------------


-- | Table of products of observation-feature potentials for the given edge,
-- indexed by the positions of the labels of that edge.
--
-- For every observation of the edge, the feature indices returned by
-- `Md.obIxs` are intersected with the label vector of the edge, and the
-- values (`Md.valueL`) of the matching features are multiplied together per
-- label position.  A position with no matching feature keeps the value @1@.
computePsi :: Md.Model -> DAG a X -> EdgeID -> LbIx -> L.LogFloat
computePsi crf dag i = (table A.!)
  where
    labels = lbVec crf dag i
    lastIx = lbNum crf dag i - 1
    table = A.accumArray (*) 1 (0, lastIx) $ do
      ob <- C.unX (DAG.edgeLabel i dag)
      (k, ix) <- I.intersect (Md.obIxs crf ob) labels
      return (k, Md.valueL crf ix)


-- | Same values as `computePsi`, but the per-edge tables are additionally
-- stored in an array spanning all edge IDs from `DAG.minEdge` to
-- `DAG.maxEdge`, so that the table of an edge is shared between look-ups made
-- through one partial application to the model and the DAG.
computePsi'
  :: Md.Model -> DAG a X
  -> EdgeID -> LbIx
  -> L.LogFloat
computePsi' crf dag = \edgeID -> memo A.! edgeID
  where
    edgeRange = (DAG.minEdge dag, DAG.maxEdge dag)
    memo = A.listArray edgeRange
      (map (computePsi crf dag) (A.range edgeRange))


-- | Forward table computation.
--
-- The table is indexed by an edge ID and by the position of a label on that
-- edge.  One extra position, @DAG.maxEdge dag + 1@, does not refer to any
-- existing edge; it has a single cell (label index @0@) which holds the sum of
-- the forward values over all labels of all final edges (this is the value
-- read by `zxAlpha`).
forward :: Md.Model -> DAG a X -> ProbArray
forward crf dag = alpha where
  -- NOTE(review): `t` is presumably the memoized table handed back by
  -- `DP.flexible2`, so that `withMem` can look up previously computed cells
  -- -- confirm against the `DP` module.
  alpha = DP.flexible2 bounds boundsOn
    (\t i -> withMem (computePsi crf dag i) t i)
  -- range of the first index: all edge IDs plus one artificial position
  bounds = (DAG.minEdge dag, DAG.maxEdge dag + 1)
  -- range of the second index (label positions) for the given first index
  boundsOn i
    | i == snd bounds = (0, 0)
    | otherwise = (0, lbNum crf dag i - 1)
  -- set of initial edges
  initialSet = S.fromList
    [ i
    | i <- DAG.dagEdges dag
    , DAG.isInitialEdge i dag ]
  -- `psi` is the observation potential table of edge `i`, `alpha'` is the
  -- table used for look-ups of the preceding edges.  The artificial position
  -- is checked first, so `psi` and the edge label are never demanded for it.
  withMem psi alpha' i
    -- artificial final position: sum over the final edges
    | i == snd bounds = const u'
    -- initial edge: observation potential times `Md.sgValue` of the label
    | i `S.member` initialSet = \j ->
        let x = lbOn crf (DAG.edgeLabel i dag) j
        in  psi j * Md.sgValue crf x
    -- any other edge: observation potential times the sum, over all labels of
    -- all preceding edges, of the forward value weighted by the feature value
    -- (`w x`) where `I.intersect` reports a feature, and by 1 otherwise
    -- (`u - v x`)
    | otherwise = \j ->
        let x = lbOn crf (DAG.edgeLabel i dag) j
        in  psi j * ((u - v x) + w x)
    where
      -- forward values summed over all labels of all preceding edges
      u = safeSum
        [ alpha' iMinus1 k
        | iMinus1 <- DAG.prevEdges i dag
        , (k, _) <- lbIxs crf dag iMinus1 ]
      -- the part of `u` coming from the preceding labels which `I.intersect`
      -- matches against `Md.prevIxs crf x`
      v x = safeSum
        [ alpha' iMinus1 k
        | iMinus1 <- DAG.prevEdges i dag
        , (k, _) <- I.intersect (Md.prevIxs crf x) (lbVec crf dag iMinus1) ]
      -- the same labels as in `v`, each weighted by the value of the matched
      -- feature
      w x = safeSum
        [ alpha' iMinus1 k * Md.valueL crf ix
        | iMinus1 <- DAG.prevEdges i dag
        , (k, ix) <- I.intersect (Md.prevIxs crf x) (lbVec crf dag iMinus1) ]
      -- Note that if `i == snd bounds` then `i` does not refer to any existing
      -- edge, hence the need to introduce `u'` which does almost the same thing
      -- as `u`.
      u' = safeSum
        [ alpha' iMinus1 k
        | iMinus1 <- DAG.dagEdges dag
        , DAG.isFinalEdge iMinus1 dag
        , (k, _) <- lbIxs crf dag iMinus1 ]


-- | Backward table computation.
--
-- The table is indexed by an edge ID and by the position of a label on that
-- edge.  One extra position, @DAG.minEdge dag - 1@, is not the ID of any edge
-- handled by the edge-specific branches below; it has a single cell (label
-- index @0@) which holds the sum, over all labels of all initial edges, of the
-- backward value times the observation potential times `Md.sgValue` of the
-- label (this is the value read by `zxBeta`).
backward :: Md.Model -> DAG a X -> ProbArray
backward crf dag = beta where
  -- NOTE(review): the first argument passed to `withMem` is presumably the
  -- memoized table itself -- confirm against the `DP` module.
  beta = DP.flexible2 bounds boundsOn withMem
  -- range of the first index: one artificial position plus all edge IDs
  bounds = (DAG.minEdge dag - 1, DAG.maxEdge dag)
  -- range of the second index (label positions) for the given first index
  boundsOn i
    | i == fst bounds = (0, 0)
    | otherwise = (0, lbNum crf dag i - 1)
  -- observation potentials, memoized per edge
  psi = computePsi' crf dag
  -- set of final edges
  finalSet = S.fromList
    [ i
    | i <- DAG.dagEdges dag
    , DAG.isFinalEdge i dag ]
  withMem beta' i
    -- final edge: the backward value of every label is 1
    | i `S.member` finalSet = const 1
    -- artificial initial position: sum over the labels of the initial edges
    | i == fst bounds = const $ safeSum
      [ beta' iPlus1 k * psi iPlus1 k
        * Md.sgValue crf (lbOn crf (DAG.edgeLabel iPlus1 dag) k)
      | iPlus1 <- DAG.dagEdges dag
      , DAG.isInitialEdge iPlus1 dag
      , (k, _) <- lbIxs crf dag iPlus1 ]
    -- any other edge: sum, over all labels of all succeeding edges, of the
    -- backward value times the observation potential, weighted by the feature
    -- value (`w y`) where `I.intersect` reports a feature, and by 1 otherwise
    -- (`u - v y`)
    | otherwise = \j ->
        let y = lbOn crf (DAG.edgeLabel i dag) j
        in  (u - v y) + w y
    where
      -- Note that here `i` is an identifier of the current DAG edge.
      -- Instead of simply adding `1` to `i` (i.e., `i + 1`),
      -- we need to find the identifiers of the succeeding edges.
      u = safeSum
        [ beta' iPlus1 k * psi iPlus1 k
        | iPlus1 <- DAG.nextEdges i dag
        , (k, _ ) <- lbIxs crf dag iPlus1 ]
      -- `y` is the label on position `i`, we are looking for
      -- matching labels on the position `i+1`.
      -- This is the part of `u` coming from the matching labels.
      v y = safeSum
        [ beta' iPlus1 k * psi iPlus1 k
        | iPlus1 <- DAG.nextEdges i dag
        , (k, _ ) <- I.intersect (Md.nextIxs crf y) (lbVec crf dag iPlus1) ]
      -- `y` is the label on position `i`, we are looking for
      -- matching labels on the position `i+1`.
      -- The same labels as in `v`, each weighted by the value of the matched
      -- feature.
      w y = safeSum
        [ beta' iPlus1 k * psi iPlus1 k * Md.valueL crf ix
        | iPlus1 <- DAG.nextEdges i dag
        , (k, ix) <- I.intersect (Md.nextIxs crf y) (lbVec crf dag iPlus1) ]


-- | Normalization factor of the given sentence DAG, read off the table
-- produced by the forward computation.
zx' :: Md.Model -> DAG a X -> L.LogFloat
zx' crf dag = zxAlpha dag fwdTable
  where
    fwdTable = forward crf dag

-- | Normalization factor stored in a forward table: the single cell (label
-- index @0@) at the position one past the maximal edge ID of the DAG.
zxAlpha :: DAG a b -> ProbArray -> L.LogFloat
zxAlpha dag alpha = alpha sink 0
  where
    sink = DAG.maxEdge dag + 1


-- | Normalization factor of the given sentence DAG, read off the table
-- produced by the backward computation.
zx :: Md.Model -> DAG a X -> L.LogFloat
zx crf dag = zxBeta dag bwdTable
  where
    bwdTable = backward crf dag

-- | Normalization factor stored in a backward table: the single cell (label
-- index @0@) at the position one before the minimal edge ID of the DAG.
zxBeta :: DAG a b -> ProbArray -> L.LogFloat
zxBeta dag beta = beta source 0
  where
    source = DAG.minEdge dag - 1


-- prob1 :: ProbArray -> ProbArray -> Int -> LbIx -> L.LogFloat
-- prob1 alpha beta k x =
--     alpha k x * beta (k + 1) x / zxBeta beta
-- {-# INLINE prob1 #-}

-- | Probability of choosing the given edge together with the given label:
-- the forward value times the backward value of the (edge, label) cell,
-- divided by the normalization factor taken from the backward table.
--
-- Raises an error if the logarithm of any of these three numbers is infinite.
edgeProb1
  :: DAG a b
  -- ^ The underlying sentence DAG
  -> ProbArray
  -- ^ Forward probability table
  -> ProbArray
  -- ^ Backward probability table
  -> EdgeID
  -- ^ ID of the edge in the underlying DAG
  -> LbIx
  -- ^ Index of the label of the edge represented by the `EdgeID`
  -> L.LogFloat
edgeProb1 dag alpha beta k x =
  if any notFinite [fwd, bwd, norm]
    then error $
      "edgeProb1: infinite -- " ++ show [fwd, bwd, norm, normFwd]
    else fwd * bwd / norm
  where
    notFinite p = isInfinite (L.logFromLogFloat p :: Double)
    fwd = alpha k x
    bwd = beta k x
    norm = zxBeta dag beta
    -- forward normalization factor; only reported in the error message
    normFwd = zxAlpha dag alpha
{-# INLINE edgeProb1 #-}


-- | Probability of choosing the given pair of edges together with the given
-- labels: the forward value of the first (edge, label) pair, times the
-- backward value and the observation potential of the succeeding pair, times
-- the value of the given feature, divided by the normalization factor taken
-- from the backward table.
--
-- Raises an error if the logarithm of any of these five numbers is infinite.
edgeProb2
  :: Md.Model
  -- ^ CRF model
  -> DAG a b
  -- ^ The underlying sentence DAG
  -> ProbArray
  -- ^ Forward computation table
  -> ProbArray
  -- ^ Backward computation table
  -> (EdgeID -> LbIx -> L.LogFloat)
  -- ^ Psi computation (only applied to the succeeding edge and its label)
  -> (EdgeID, LbIx)
  -- ^ First edge and the corresponding label index
  -> (EdgeID, LbIx)
  -- ^ Succeeding edge and the corresponding label index
  -> Md.FeatIx
  -- ^ Index of the feature whose value (`Md.valueL`) enters the product;
  -- presumably the transition feature linking the two labels (to be confirmed)
  -> L.LogFloat
edgeProb2 crf dag alpha beta psi (kEdgeID, xLbIx) (lEdgeID, yLbIx) ix =
  if any notFinite [fwd, bwd, obs, trans, norm]
    then error $
      "edgeProb2: infinite -- " ++ show [fwd, bwd, obs, trans, norm, normFwd]
    else fwd * bwd * obs * trans / norm
  where
    notFinite p = isInfinite (L.logFromLogFloat p :: Double)
    fwd = alpha kEdgeID xLbIx
    bwd = beta lEdgeID yLbIx
    obs = psi lEdgeID yLbIx
    trans = Md.valueL crf ix
    norm = zxBeta dag beta
    -- forward normalization factor; only reported in the error message
    normFwd = zxAlpha dag alpha
{-# INLINE edgeProb2 #-}


-- prob2 :: Model -> ProbArray -> ProbArray -> Int -> (LbIx -> L.LogFloat)
--       -> LbIx -> LbIx -> FeatIx -> L.LogFloat
-- prob2 crf alpha beta k psi x y ix
--     = alpha (k - 1) y * beta (k + 1) x
--     * psi x * valueL crf ix / zxBeta beta
-- {-# INLINE prob2 #-}


-- | Replace the value of each edge with the list of its potential labels,
-- each paired with its marginal probability (see `edgeProb1`).
--
-- If the normalization factors obtained from the forward and the backward
-- tables are not `almostEq`, a warning is emitted with `trace`; the result is
-- returned regardless.
marginals :: Md.Model -> DAG a X -> DAG a [(Lb, L.LogFloat)]
marginals crf dag =
  if zxFwd `almostEq` zxBwd
    then tagged
    else trace msg tagged
  where
    fwdTable = forward crf dag
    bwdTable = backward crf dag
    zxFwd = zxAlpha dag fwdTable
    zxBwd = zxBeta dag bwdTable
    msg =
      "[marginals] normalization factors differ significantly: "
      ++ show (L.logFromLogFloat zxFwd, L.logFromLogFloat zxBwd)
    tagged = DAG.mapE distribution dag
    -- the previous value of the edge is ignored
    distribution edgeID _ =
      map (\(labID, lab) -> (lab, edgeProb1 dag fwdTable bwdTable edgeID labID))
          (lbIxs crf dag edgeID)


-- | For each edge, keep (at most) the @k@ labels with the highest marginal
-- probabilities (see `marginals`), listed in descending order of probability.
--
-- TODO: Tagging with respect to marginal distributions might not be the best
-- idea.  Think of some more elegant method.
tagK :: Int -> Md.Model -> DAG a X -> DAG a [(Lb, L.LogFloat)]
tagK k crf dag = kBest <$> marginals crf dag
  where
    -- ascending stable sort followed by a reversal
    kBest = take k . reverse . sortBy (compare `on` snd)


-- | Find the most probable label sequence (with probabilities of individual
-- labels determined with respect to marginal distributions) satisfying the
-- constraints imposed over label values.
--
-- Each edge receives the first label returned for it by @`tagK` 1@.  Since the
-- result must contain a label for every edge, an edge without any potential
-- label cannot be tagged; in that case an error naming this function is
-- raised (instead of the anonymous failure of `head`).
tag :: Md.Model -> DAG a X -> DAG a Lb
tag crf = fmap best . tagK 1 crf
  where
    best ((lb, _) : _) = lb
    best [] = error "tag: an edge has no potential labels to choose from"


-- | Features (represented by their indices) defined on the given edge of the
-- DAG, each paired with its expected probability with respect to the model.
-- The result lists the transition-like features first and the observation
-- features afterwards; one feature can occur several times.
--
--   * Observation features: for every observation of the edge, every feature
--     matched by `I.intersect` against the labels of the edge, with the
--     probability of the edge and the matched label (`edgeProb1`).
--
--   * On an initial edge: for every label of the edge, the `C.SFeature` of
--     that label (if the model knows it), with the probability of the edge
--     and the label.
--
--   * On any other edge: for every label of the edge and every preceding
--     edge, every feature matched by `I.intersect` against the labels of the
--     preceding edge, with the probability of the two (edge, label) pairs
--     (`edgeProb2`).
expectedFeaturesOn
  :: Md.Model
  -- ^ CRF model
  -> DAG a X
  -- ^ The underlying sentence DAG
  -> ProbArray
  -- ^ Forward computation table
  -> ProbArray
  -- ^ Backward computation table
  -> EdgeID
  -- ^ ID of an edge of the underlying DAG
  -> [(Md.FeatIx, L.LogFloat)]
expectedFeaturesOn crf dag alpha beta iEdgeID =
  tFeats ++ oFeats
  where
    prob1 = edgeProb1 dag alpha beta iEdgeID
    oFeats = [ (ix, prob1 k)
             | ob <- C.unX (DAG.edgeLabel iEdgeID dag)
             , (k, ix) <- I.intersect (Md.obIxs crf ob) (lbVec crf dag iEdgeID) ]

    -- Observation potentials.  `prob2` below only ever asks for the
    -- potentials of the current edge, so its table is computed once and
    -- shared between all the look-ups.  Building the memo table over all
    -- edges (`computePsi'`) here would allocate a table of the size of the
    -- whole DAG on every call, i.e. for every single edge.
    psiHere = computePsi crf dag iEdgeID
    psi edgeID
        | edgeID == iEdgeID = psiHere
        | otherwise = computePsi crf dag edgeID
    prob2 = edgeProb2 crf dag alpha beta psi
    tFeats
        | DAG.isInitialEdge iEdgeID dag = catMaybes
          [ (, prob1 k) <$> Md.featToIx crf (C.SFeature x)
          | (k, x) <- lbIxs crf dag iEdgeID ]
        | otherwise =
          [ (ix, prob2 (iMinus1, l) (iEdgeID, k) ix)
          | (k,  x) <- lbIxs crf dag iEdgeID
          , iMinus1 <- DAG.prevEdges iEdgeID dag
          , (l, ix) <- I.intersect (Md.prevIxs crf x) (lbVec crf dag iMinus1) ]


-- | Features (represented by feature indices) defined within the context of
-- the sentence DAG, each accompanied by its expected probability determined
-- on the basis of the model.  The list is the concatenation of the results of
-- `expectedFeaturesOn` over all edges of the DAG.
--
-- One feature can occur multiple times in the output list.
--
-- The forward and the backward normalization factors are evaluated (the
-- forward one sparked in parallel) before the list is produced.
expectedFeaturesIn
  :: Md.Model
  -> DAG a X
  -> [(Md.FeatIx, L.LogFloat)]
expectedFeaturesIn crf dag =
    zxF `par` zxB `pseq` zxF `pseq`
      concatMap expectedOn (DAG.dagEdges dag)
  where
    fwdTable = forward crf dag
    bwdTable = backward crf dag
    zxF = zxAlpha dag fwdTable
    zxB = zxBeta dag bwdTable
    expectedOn = expectedFeaturesOn crf dag fwdTable bwdTable


-- | Count the edges of a labeled DAG which the model tags correctly and
-- incorrectly, respectively.
--
-- The reference label of an edge is the element of `C.unY` with the highest
-- second component (`Nothing` if there is none); the predicted label comes
-- from `tag`.  As the prediction is always a `Just`, an edge without a
-- reference label is counted as incorrect.
goodAndBad :: Md.Model -> DAG a (X, Y) -> (Int, Int)
goodAndBad crf dag =
    F.foldl' tally (0, 0) (DAG.zipE reference predicted)
  where
    reference = fmap (pickBest . C.unY . snd) dag
    predicted = fmap Just (tag crf (fmap fst dag))
    pickBest zs
        | null zs   = Nothing
        | otherwise = Just (fst (maximumBy (compare `on` snd) zs))
    tally (good, bad) (ref, hyp)
        | ref == hyp = (good + 1, bad)
        | otherwise  = (good, bad + 1)


-- | Sum the `goodAndBad` counts over a list of labeled DAGs.
goodAndBad' :: Md.Model -> [DAG a (X, Y)] -> (Int, Int)
goodAndBad' crf dataset =
    F.foldl' plus (0, 0) (map (goodAndBad crf) dataset)
  where
    plus (g, b) (g', b') = (g + g', b + b')


-- | Compute the accuracy of the model with respect to the labeled dataset:
-- the number of correctly tagged edges divided by the number of all counted
-- edges (see `goodAndBad`).
--
-- The dataset is split with `partition` into `numCapabilities` parts which
-- are evaluated in parallel.  When no edge is counted at all, the result is
-- the @0 / 0@ of `Double`, i.e. NaN.
accuracy :: Md.Model -> [DAG a (X, Y)] -> Double
accuracy crf dataset = fromIntegral good / fromIntegral (good + bad)
  where
    chunks = partition numCapabilities dataset
    perChunk = parMap rseq (goodAndBad' crf) chunks
    (good, bad) = F.foldl' plus (0, 0) perChunk
    plus (g, b) (g', b') = (g + g', b + b')


---------------------------------------------
-- Probability and likelihood
---------------------------------------------


-- -- | Log-likelihood of the given dataset.
-- likelihood :: Md.Model -> [DAG a (X, Y)] -> L.LogFloat
-- -- likelihood crf = L.product . map (probability crf)
-- -- likelihood crf = probability crf . head
-- likelihood crf = maximum . map (probability crf)
--
--
-- -- | The conditional probability of the dag in log-domain.
-- probability :: Md.Model -> DAG a (X, Y) -> L.LogFloat
-- probability crf dag = normFactor
-- --   | potential > normFactor =
-- --       error $ "[probability] potential greater than normFactor: "
-- --       ++ show (potential, normFactor)
-- --   | otherwise = potential / normFactor
--   where
--     potential = L.product
--       [ Md.valueL crf (Md.featToJustIx crf feat)
--       | (feat, _val) <- featuresIn dag ]
--     normFactor = zx crf (fmap fst dag)


-- -- | Features w.r.t. a given edge.
-- features
--   :: Md.Model
--   -> EdgeID  -- ^ ID of an edge of the DAG
--   -> DAG a (X, Y)
--   -> [Md.FeatIx]
-- features crf edgeID dag =
--   where
--     oFeats = [ (ix, prob1 k)
--              | ob <- C.unX (DAG.edgeLabel iEdgeID dag)
--              , (k, ix) <- I.intersect (Md.obIxs crf ob) (lbVec crf dag iEdgeID) ]


---------------------------------------------
-- Utils
---------------------------------------------


-- | Check whether two log-domain numbers are (almost) equal.
--
-- The comparison is carried out on the logarithms of the arguments
-- (`L.logFromLogFloat`):
--
--   * identical logarithms are equal; this also covers two equal infinite
--     logarithms, for which the ratio used below would be NaN and the numbers
--     would wrongly be reported as different;
--
--   * two logarithms which are both closer to @0@ than `eps` are equal;
--
--   * otherwise, the numbers are equal if the ratio of their logarithms
--     differs from @1@ by less than `eps`.
almostEq :: L.LogFloat -> L.LogFloat -> Bool
almostEq x0 y0
  | x == y = True
  | isZero x && isZero y = True
  | otherwise = 1.0 - eps < z && z < 1.0 + eps
  where
    x = L.logFromLogFloat x0
    y = L.logFromLogFloat y0
    z = x / y


-- | Is the absolute value of the number smaller than `eps`?
isZero :: (Fractional t, Ord t) => t -> Bool
isZero = (< eps) . abs


-- | A very small number (one millionth), used as the tolerance of `isZero`
-- and `almostEq`.
eps :: Fractional t => t
eps = 1.0e-6