{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{- | Private functions used to implement CPT

-}
module Bayes.Factor.PrivateCPT(
   -- * Functions
     cptFactorProjectOutWith
   -- ** Private functions
   , FactorElement(..)
   , DecisionFactor(..)
   , _emptyFactor
   , _factorVariables
   , _isScalarFactor
   , _containsVariable
   , _factorDimension
   , _variablePosition
   , _factorFromScalar
   , _factorWithVariables
   , _factorToList
   , _factorNorm
   , _factorScale 
   , _factorValue
   , _evidenceFrom
   , _factorProduct
   , _isUsingSameVariablesAs
   , CPT
   , MAXCPT
   , PrivateCPT(..)
   , PossibleInstantiations
   , convertToMaxFactor
   , convertToNormalFactor 
   , decisionFactor
   , debugCPT
   , privateFactorValue
  ) where

import qualified Data.Vector.Generic as GV

import qualified Data.Vector.Unboxed as V
import Bayes.PrivateTypes
import qualified Data.IntMap as IM
import Bayes.Factor
import Data.Vector.Generic((!))
import Data.Maybe(fromJust,mapMaybe,isJust)
import qualified Data.List as L
import Data.List(partition)
import qualified Data.Vector as NV


{-
 
CPT used for normal CPT and MAXCPT

-}
-- | Mainly used for conditional probability tables like p(A B | C D E), but the normalization
-- to 1 is not imposed, and the conditioned variables are not distinguished from the conditioning ones.
-- The dimensions of each variable are listed.
-- Which variables stand on the left or on the right of the condition bar is not tracked. What matters
-- is that the structure encodes a function of several variables.
-- Marginalization of variables is computed from the Bayesian graph, which is where
-- the knowledge of the dependencies lives.
-- So this same structure is also used for a plain probability (conditioned on nothing).
data PrivateCPT v a = Table {
                      dimensions :: ![DV] -- ^ Dimensions for all variables
                    , mapping :: !(IM.IntMap Int) -- ^ Mapping from vertex number to position in dimensions
                    , values :: !(v a) -- ^ Table of values
                    }
                    | Scalar !a  deriving(Eq)

-- | Print the content of a factor on standard output (debugging helper).
--
-- A scalar factor prints a @SCALAR CPT@ header followed by its value.
-- A table prints a @CPT@ header followed by its variables, its
-- vertex-to-position mapping and its values, each followed by an empty line.
debugCPT :: (Show a, Show (v a)) => PrivateCPT v a -> IO ()
debugCPT (Scalar d) = do
    putStrLn "SCALAR CPT"
    print d
    putStrLn ""
debugCPT (Table d m v) = do
    putStrLn "CPT"
    print d
    putStrLn ""
    print m
    putStrLn ""
    print v
    putStrLn ""

-- | Turn a plain factor into a 'MAXCPT' so that inference algorithms use
-- the max behavior. Every value is paired with an empty list of
-- instantiations; variables and mapping are kept unchanged.
convertToMaxFactor :: CPT -> MAXCPT
convertToMaxFactor cpt = case cpt of
    Scalar x    -> Scalar (x, [])
    Table d m v -> Table d m (NV.fromList [ (x, []) | x <- V.toList v ])

-- | Factor whose cells each hold one variable instantiation ('DVI'),
-- stored in a boxed vector. Produced by 'decisionFactor'.
type DecisionFactor = PrivateCPT (NV.Vector) DVI 




-- | Build the table of instantiations carried by a 'MAXCPT'. Used to compute
-- decisions in an influence diagram.
--
-- For each cell only the first element of the first instantiation list is
-- kept: in an influence diagram there is only ONE instantiated variable, so
-- the list of lists (one list per variable) is not needed. The function is
-- partial: a cell with no instantiation fails when that element is demanded.
decisionFactor :: MAXCPT -> DecisionFactor
decisionFactor factor = case factor of
    Scalar (_, insts) -> Scalar (firstChoice insts)
    Table d m v       -> Table d m (NV.fromList (map (firstChoice . snd) (NV.toList v)))
  where
    firstChoice = head . head



-- | Drop the instantiations carried by a 'MAXCPT' and keep only the 'Double'
-- of each cell, giving back a plain 'CPT' over the same variables.
convertToNormalFactor :: MAXCPT -> CPT 
convertToNormalFactor maxFactor = case maxFactor of
    Scalar cell -> Scalar (fst cell)
    Table d m v -> Table d m (V.fromList (map fst (NV.toList v)))

-- | Factor of 'Double' values stored in an unboxed vector.
type CPT = PrivateCPT (V.Vector) Double

-- | List of instantiation sets carried along with a value in a 'MAXCPT'.
type PossibleInstantiations = [DVISet]

-- | Factor whose cells pair a 'Double' with its 'PossibleInstantiations',
-- stored in a boxed vector.
type MAXCPT = PrivateCPT (NV.Vector) (Double,PossibleInstantiations)

-- | Operations required from the type of the cells stored in a 'PrivateCPT'.
class FactorElement a where 
    -- | 'Double' carried by an element (used by '_factorToList', '_factorNorm' and '_factorValue').
    doubleValue :: a -> Double 
    -- | Build an element from a 'Double'.
    mkValue :: Double -> a
    -- | Combine a 'Double' factor with an element (used by '_factorScale').
    scale :: Double -> a -> a
    -- | Binary operation on elements; presumably the product -- confirm against the instances.
    multiply :: a -> a -> a
    -- | Binary operation on elements; presumably the quotient -- confirm against the instances.
    divide :: a -> a -> a
    -- | Binary operation on elements; presumably the sum -- confirm against the instances.
    elementSum ::  a -> a -> a
    -- | Element used by '_factorProduct' as the result of an empty product.
    elementUnit :: a

-- | Test whether two factors are defined over the same list of variables
-- (same variables in the same order).
--
-- A 'Scalar' factor has no variables, so it only matches another 'Scalar'
-- or a 'Table' with an empty variable list. The previous definition only
-- handled the 'Table' / 'Table' case and failed with a pattern-match error
-- as soon as one argument was a 'Scalar'.
_isUsingSameVariablesAs :: PrivateCPT v a -> PrivateCPT v a -> Bool
_isUsingSameVariablesAs fa fb = variablesOf fa == variablesOf fb
  where
    variablesOf (Table vars _ _) = vars
    variablesOf (Scalar _)       = []

-- | Table factor with no variable, an empty mapping and no value.
_emptyFactor :: GV.Vector v a => PrivateCPT v a
_emptyFactor = Table
    { dimensions = []
    , mapping    = IM.empty
    , values     = GV.empty
    }

-- | Variables of a factor, in table order. A scalar factor has none.
_factorVariables :: PrivateCPT v a -> [DV]
_factorVariables factor = case factor of
    Scalar _       -> []
    Table vars _ _ -> vars

-- | 'Double' value of every cell of a factor, in storage order.
-- A scalar factor gives a one-element list.
_factorToList :: (FactorElement a, GV.Vector v a) => PrivateCPT v a -> [Double]
_factorToList factor = case factor of
    Scalar x        -> [doubleValue x]
    Table _ _ cells -> [ doubleValue c | c <- GV.toList cells ]

-- | 'True' only when the factor is built with the 'Scalar' constructor.
_isScalarFactor :: PrivateCPT v a -> Bool 
_isScalarFactor factor = case factor of
    Scalar _ -> True
    Table{}  -> False

-- | Scalar factor whose single element is built from the given 'Double'
-- with 'mkValue'.
_factorFromScalar :: FactorElement a => Double -> PrivateCPT v a
_factorFromScalar = Scalar . mkValue

-- | Product of the dimensions of the variables of a table factor.
-- A scalar factor has dimension 1.
_factorDimension :: Factor (PrivateCPT v a) => PrivateCPT v a -> Int
_factorDimension factor = case factor of
    Table{} -> product (map dimension (factorVariables factor))
    _       -> 1

-- | Test whether the vertex of a variable is a key of the factor mapping.
-- A scalar factor contains no variable.
_containsVariable :: PrivateCPT v a -> DV  -> Bool
_containsVariable (Scalar _) _ = False
_containsVariable (Table _ positions _) (DV (Vertex vertexId) _) =
    vertexId `IM.member` positions

-- | Build a factor from its variables and the 'Double' of each cell.
-- The result is whatever 'createCPTWithDims' returns for the converted
-- values, so it is 'Nothing' when that function rejects them.
_factorWithVariables :: (FactorElement a, GV.Vector v a) => [DV] -> [Double] -> Maybe (PrivateCPT v a)
_factorWithVariables dv v = createCPTWithDims dv cells
  where
    cells = GV.fromList [ mkValue x | x <- v ]

-- | Sum of the 'Double' value of the cells of a factor. The cells of a table
-- are visited through every index returned by 'indicesForDomain' for the
-- table variables. For a scalar factor this is the value of the scalar.
_factorNorm :: (FactorElement a, GV.Vector v a) => PrivateCPT v a -> Double 
_factorNorm (Scalar x) = doubleValue x
_factorNorm (Table d _ vals) = sum (map cellValue (indicesForDomain domain))
  where
    domain = DVSet d
    strides = indexStrides domain
    cellValue idx = doubleValue (vals ! indexPosition strides idx)


-- | Position recorded in the factor mapping for the vertex of a variable,
-- or 'Nothing' when the vertex is not a key of the mapping. A scalar factor
-- always gives 'Nothing'.
_variablePosition :: PrivateCPT v a -> DV -> Maybe Int
_variablePosition (Scalar _) _ = Nothing
_variablePosition (Table _ positions _) (DV (Vertex vertexId) _) =
    IM.lookup vertexId positions

-- | Apply 'scale' with the given 'Double' to every cell of a factor.
-- For a table the cells are read through the indices of its variables and
-- a new table over 'factorVariables' is built with 'createCPTWithDims'.
_factorScale :: (FactorElement a, GV.Vector v a, Factor (PrivateCPT v a)) => Double -> PrivateCPT v a -> PrivateCPT v a
_factorScale s (Scalar x) = Scalar (s `scale` x)
_factorScale s f@(Table d _ vals) =
    fromJust (createCPTWithDims (factorVariables f) scaled)
  where
    domain = DVSet d
    strides = indexStrides domain
    scaledAt idx = scale s (vals ! indexPosition strides idx)
    scaled = GV.force (GV.fromList (map scaledAt (indicesForDomain domain)))

-- | Element of a factor selected by a list of variable instantiations.
-- A scalar factor ignores the instantiations and returns its element.
-- For a table, the instantiations are turned into variables and a
-- multi-index by 'instantiationDetails', and the table is read with strides
-- computed for those variables.
privateFactorValue :: (GV.Vector v a) => PrivateCPT v a  -> [DVI] -> a
{-# INLINE privateFactorValue #-}
privateFactorValue (Scalar x) _ = x
privateFactorValue (Table d _ cells) i = cells ! indexPosition strides pos
  where
    (instVars, pos) = instantiationDetails i
    strides = indexStridesFor (DVSet d) instVars

-- | Build an evidence factor from a list of instantiations: the cell whose
-- index equals the index of the instantiations holds @mkValue 1.0@, every
-- other cell holds @mkValue 0.0@. An empty list gives 'Nothing'.
_evidenceFrom :: (FactorElement a, GV.Vector v a) => [DVI] -> Maybe (PrivateCPT v a)
_evidenceFrom [] = Nothing 
_evidenceFrom l = createCPTWithDims plainVars indicator
  where
    (observedVars, observedIndex) = instantiationDetails l
    DVSet plainVars = observedVars
    indicatorAt idx
        | idx == observedIndex = mkValue 1.0
        | otherwise            = mkValue 0.0
    indicator = GV.force (GV.fromList (map indicatorAt (indicesForDomain observedVars)))

-- | 'Double' value of the element selected by 'privateFactorValue'.
_factorValue :: (FactorElement a, GV.Vector v a) => PrivateCPT v a  -> [DVI] -> Double
{-# INLINE _factorValue #-}
_factorValue f d = doubleValue selected
  where
    selected = privateFactorValue f d

-- | Combine a list of factors cell by cell with the given binary operation.
--
-- * An empty list gives the scalar factor built from @1.0@.
-- * The scalar factors are folded together first ('elementUnit' when there
--   is none).
-- * When there is no table the result is that scalar; otherwise the result
--   is a table over the union of all the variables, each cell being the
--   scalar part combined with the fold of the matching cell of every table.
_factorProduct :: (FactorElement a, GV.Vector v a, Factor (PrivateCPT v a)) 
               => (a -> a -> a)
               -> [PrivateCPT v a] 
               -> PrivateCPT v a
_factorProduct _ [] = factorFromScalar 1.0
_factorProduct op factors
    | L.null tables = Scalar scalarPart
    | otherwise     = fromJust (createCPTWithDims unionVars productValues)
  where
    unionVars = L.foldr1 union (map factorVariables factors)
    domain = DVSet unionVars
    (scalarFactors, tables) = partition isScalarFactor factors

    -- Fold a list of elements with the operation; empty list gives the unit.
    combine [] = elementUnit
    combine xs = foldr1 op xs

    -- A scalar factor ignores the instantiation, so it is never evaluated.
    scalarPart = combine (map (\s -> privateFactorValue s undefined) scalarFactors)

    -- Strides to read each table with an index over the union of variables.
    stridesOf (Table d _ _) = indexStridesFor (DVSet d) domain
    tableStrides = map stridesOf tables

    valueAt idx (strides, Table _ _ vals) = vals ! indexPosition strides idx
    productAt idx = combine (map (valueAt idx) (zip tableStrides tables))
    productValues = GV.force (GV.fromList [ op scalarPart (productAt idx) | idx <- indicesForDomain domain ])


-- | Generic projection function which can use any function to project out
-- some variables of a CPT. Used to implement the MAP queries.
--
-- For every index of the kept variables, the cells reached by all the
-- indices of the projected variables are collected, each paired with the
-- instantiation of the projected variables, and handed to the
-- summing/maximizing function. Only the 'Table' constructor is handled here.
cptFactorProjectOutWith :: (FactorElement a, GV.Vector v a)
                        => ([(a,DVISet)] -> a) -- ^ Summing or maximizing function
                        -> [DV] -- ^ Variables to sum or maximize
                        -> PrivateCPT v a 
                        -> PrivateCPT v a
cptFactorProjectOutWith sumF s (Table d _ v) =
    fromJust (createCPTWithDims kept projected)
  where
    kept = d `difference` s
    keptDomain = DVSet kept
    projectedDomain = DVSet s

    -- Strides to read the table with a (kept ++ projected) multi-index.
    strides = indexStridesFor (DVSet d) (DVSet (kept ++ s))

    cell keepIdx projIdx =
        v ! indexPosition strides (combineIndex strides keepIdx projIdx)

    slice keepIdx =
        [ (cell keepIdx projIdx, instantiation projectedDomain projIdx)
        | projIdx <- indicesForDomain projectedDomain
        ]

    projected = GV.force (GV.fromList [ sumF (slice keepIdx) | keepIdx <- indicesForDomain keptDomain ])

-- | Create a CPT given some variables and a vector of values.
-- Returns 'Nothing' if the lengths are not coherent.
--
-- * With no variable the result is the scalar factor holding the first
--   value. An empty vector now gives 'Nothing' instead of failing inside
--   'GV.head'. Extra values are still ignored, as before.
-- * Otherwise the number of values must be the product of the dimensions
--   of the variables; the mapping from vertex number to position is built
--   from the order of the variables.
createCPTWithDims :: GV.Vector v a => [DV] -> v a -> Maybe (PrivateCPT v a)
createCPTWithDims [] v
    | GV.null v = Nothing
    | otherwise = Just (Scalar (GV.head v))
createCPTWithDims dims values
    | GV.length values == expectedLength =
        -- Force the mapping before the table is handed to the caller.
        Just $! positions `seq` Table dims positions (GV.force values)
    | otherwise = Nothing
  where
    vertexPosition (DV (Vertex v) _) i = (v, i)
    positions = IM.fromList (zipWith vertexPosition dims [(0 :: Int) ..])
    expectedLength = product (map dimension dims)



{-

Stride management

-}


-- | Used to combine the keep and sum indices in the factor projection:
-- the components of the first index are followed by the components of the
-- second one. The strides argument is not inspected.
combineIndex :: Strides s'' -> MultiIndex s -> MultiIndex s' -> MultiIndex s''
{-# INLINE combineIndex #-}
combineIndex _ (MultiIndex keepPart) (MultiIndex sumPart) =
    MultiIndex ((V.++) keepPart sumPart)



-- | Unboxed vector of 'Int' strides, multiplied component-wise with a
-- 'MultiIndex' by 'indexPosition' to get a position in a table of values.
-- The type parameter only appears in signatures.
newtype Strides s = Strides (V.Vector Int) deriving(Eq,Show)

-- | Generate strides to read the first CPT using an index having meaning in
-- the second CPT. Each variable of the second set gets the stride it has in
-- the first set, or 0 when it does not belong to the first set.
indexStridesFor :: DVSet s -- ^ DVSet to be read
                -> DVSet s' -- ^ DVSet to interpret the index
                -> Strides s'
indexStridesFor readSet@(DVSet readVars) (DVSet indexVars) =
    Strides (V.force (V.fromList (map strideOf indexVars)))
  where
    Strides nativeStrides = indexStrides readSet
    strideTable = zip readVars (V.toList nativeStrides)
    strideOf dv = case lookup dv strideTable of
        Just stride -> stride
        Nothing     -> 0
    

-- | Generate the strides to read a given factor using a multiindex
-- using the same order as the factor variables.
--
-- The stride of a variable is the product of the dimensions of the
-- variables that follow it, so the last variable has stride 1.
-- An empty set of variables gives empty strides; the previous definition
-- used 'tail' and failed on that input.
indexStrides :: DVSet s -> Strides s
indexStrides (DVSet dvars) =
    Strides (V.force (V.fromList strides))
  where
    strides = case map dimension dvars of
        []         -> []
        (_ : rest) -> scanr (*) (1 :: Int) rest

-- | Conversion of a multiindex to its
-- position inside of the data vector of a 'CPT': the sum of the
-- component-wise products of the strides and the index.
--
-- The stray equation for the misspelled name @indexPositions@ has been
-- removed: it defined a separate, unused and partial top-level function.
indexPosition :: Strides s -> MultiIndex s -> Int
{-# INLINE indexPosition #-}
indexPosition (Strides pos') (MultiIndex pos) = V.sum (V.zipWith (*) pos' pos)