```{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{- | Conditional probability table

Conditional Probability Tables and Probability tables

-}
module Bayes.Factor(
-- * Factor
Factor(..)
, isomorphicFactor
, normedFactor
-- * Set of variables
, Set(..)
, BayesianDiscreteVariable(..)
-- * Implementation
, Vertex(..)
-- ** Discrete variables and instantiations
, DV
--, DVSet(..)
, DVI
, setDVValue
, instantiationValue
, instantiationVariable
, variableVertex
, (=:)
, forAllInstantiations
, factorFromInstantiation
, changeVariableOrder
-- ** Factor
, CPT
-- * Tests
, testProductProject_prop
, testAssocProduct_prop
, testScale_prop
, testProjectCommut_prop
, testScalarProduct_prop
, testProjectionToScalar_prop
) where

import qualified Data.Vector.Unboxed as V
import Data.Vector.Unboxed((!))
import Data.Maybe(fromJust,mapMaybe,isJust)
import qualified Data.List as L
import Text.PrettyPrint.Boxes hiding((//))
import Test.QuickCheck
import Test.QuickCheck.Arbitrary
import qualified Data.IntMap as IM
import System.Random(Random)
import Data.List(partition)
import Bayes.PrivateTypes

--import Debug.Trace

--debug a = trace ("\nDEBUG\n" ++ show a ++ "\n") a

-- | Values that carry a graph 'Vertex' together with some other piece of
-- information (variable dimension, variable value ...).
class LabeledVertex l where
  -- | Extract the vertex the value is attached to.
  variableVertex :: l -> Vertex

-- | Convert a variable instantiation to a factor over that single variable.
-- The resulting factor is 1.0 at the instantiated level and 0.0 elsewhere,
-- which is useful to create evidence factors.
factorFromInstantiation :: Factor f => DVI Int -> f
factorFromInstantiation (DVI dv observed) =
  fromJust (factorWithVariables [dv] indicator)
  where
    -- One value per level of the variable, in increasing level order.
    indicator =
      [ if level == observed then 1.0 else 0.0
      | level <- [0 .. dimension dv - 1]
      ]

-- | The vertex of an instantiation is the vertex of the instantiated variable.
instance LabeledVertex (DVI a) where
  variableVertex inst = case inst of
    DVI dv _ -> variableVertex dv

-- | The vertex of a discrete variable is the first field of 'DV'.
instance LabeledVertex DV where
  variableVertex dv = case dv of
    DV vertex _ -> vertex

-- | Norm the factor: divide every value by the factor norm ('factorNorm').
normedFactor :: Factor f => f -> f
normedFactor f = f `factorDivide` total
  where
    total = factorNorm f

-- | A factor as used in graphical models.
-- It may or may not be a probability distribution, so it has no reason to be
-- normalized to 1.
class Factor f where
  -- | When all variables of a factor have been summed out, we have a scalar.
  isScalarFactor :: f -> Bool

  -- | An empty factor with no variable and no values.
  emptyFactor :: f

  -- | Check if a given discrete variable is contained in a factor.
  containsVariable :: f -> DV -> Bool

  -- | Give the list of discrete variables used by the factor.
  factorVariables :: f -> [DV]

  -- | Return A in P(A | C D ...). It only makes sense if the factor is a
  -- conditional probability table. It must always be in the vertex
  -- corresponding to A in the bayesian graph.
  --
  -- The default implementation returns the first variable of
  -- 'factorVariables' and calls 'error' when there is none.
  factorMainVariable :: f -> DV
  factorMainVariable f = case factorVariables f of
    [] -> error "Can't get the main variable of a scalar factor"
    (h : _) -> h

  -- | Create a new factor with a given set of variables and a list of values
  -- for initialization. The creation may fail if the number of values is not
  -- coherent with the variables and their levels.
  -- For boolean variables ABC, the values must be given in the order
  -- FFF, FFT, FTF, FTT ...
  factorWithVariables :: [DV] -> [Double] -> Maybe f

  -- | Value of the factor for a given set of variable instantiations.
  -- The variable instantiation is like a multi-dimensional index.
  factorValue :: f -> [DVI Int] -> Double

  -- | Position of a discrete variable in the factor (p(AB) is different from
  -- p(BA) since values are not organized in the same order in memory).
  variablePosition :: f -> DV -> Maybe Int

  -- | Dimension of the factor (number of floating point values).
  factorDimension :: f -> Int

  -- | Norm of the factor = sum of its values.
  factorNorm :: f -> Double

  -- | Scale the factor values by a given scaling factor.
  factorScale :: Double -> f -> f

  -- | Create a scalar factor with no variables.
  factorFromScalar :: Double -> f

  -- | Create an evidence factor from an instantiation.
  -- If the instantiation is empty then we get nothing.
  evidenceFrom :: [DVI Int] -> Maybe f

  -- | Divide all the factor values by the given number.
  -- The default implementation scales by the reciprocal.
  factorDivide :: f -> Double -> f
  factorDivide f d = factorScale (1.0 / d) f

  -- | List the values of the factor.
  factorToList :: f -> [Double]

  -- | Multiply factors.
  factorProduct :: [f] -> f

  -- | Project out a factor. The given variables are summed out.
  factorProjectOut :: [DV] -> f -> f

  -- | Project to. The given variables are kept and the other variables are
  -- removed. The default implementation projects out every variable of the
  -- factor that is not in the given list.
  factorProjectTo :: [DV] -> f -> f
  factorProjectTo s f = factorProjectOut removed f
    where
      removed = factorVariables f `difference` s

-- | Change the layout of values in the factor to correspond to a new
-- variable order. Used to import external files.
--
-- A temporary 'CPT' is built from the old order and old values, then read
-- back once per instantiation of the new order.
changeVariableOrder
  :: DVSet s   -- ^ Old order
  -> DVSet s'  -- ^ New order
  -> [Double]  -- ^ Old values
  -> [Double]  -- ^ New values
changeVariableOrder (DVSet oldOrder) newOrder oldValues =
  map (factorValue source) (forAllInstantiations newOrder)
  where
    source :: CPT
    source = fromJust (factorWithVariables oldOrder oldValues)

-- | Mainly used for conditional probability tables like p(A B | C D E), but
-- normalization to 1 is not imposed, and the conditioned variables are not
-- distinguished from the conditioning ones.
-- Which variables are on the left or right of the condition bar is not
-- tracked: what matters is that the table encodes a function of several
-- variables. Marginalization of variables is computed from the bayesian
-- graph, where the knowledge of the dependencies is.
-- So this same structure is also used for a plain probability (conditioned
-- on nothing).
data CPT
  = CPT
      { dimensions :: ![DV]                -- ^ Dimensions for all variables
      , mapping    :: !(IM.IntMap Int)     -- ^ Mapping from vertex number to position in dimensions
      , values     :: !(V.Vector Double)   -- ^ Table of values
      }
  | Scalar !Double                         -- ^ A single value with no variable

-- | Print the raw content of a 'CPT' on standard output (debugging helper).
-- Each printed field is followed by an empty line.
debugCPT :: CPT -> IO ()
debugCPT cpt = case cpt of
  Scalar d -> do
    putStrLn "SCALAR CPT"
    print d
    blank
  CPT d m v -> do
    putStrLn "CPT"
    print d
    blank
    print m
    blank
    print v
    blank
  where
    blank = putStrLn ""
{-

A CPT can't contain the same vertex several times with different sizes.
But arbitrary CPT generation would generate several vertices with the same vertex id
and different vertex sizes.

So, we introduce a function mapping a vertex ID to a vertex size: vertex sizes are hard coded.

-}

-- | Hard coded size of a vertex, as a function of its vertex id, used when
-- generating arbitrary factors. Every vertex currently gets size 2.
quickCheckVertexSize :: Int -> Int
quickCheckVertexSize vertexId = case vertexId of
  0 -> 2
  1 -> 2
  2 -> 2
  _ -> 2

-- | Run the generator again and again until it produces a value which is
-- not already present in the list.
whileIn :: (Arbitrary a, Eq a) => [a] -> Gen a -> Gen a
whileIn l m = retry
  where
    retry =
      m >>= \candidate ->
        if candidate `elem` l
          then retry
          else return candidate

-- | Generate a random list of n elements without replacement (no duplicate).
-- May loop if the range is too small !
-- A size below 1 gives the empty list.
generateWithoutReplacement
  :: (Random a, Arbitrary a, Eq a)
  => Int    -- ^ Vector size
  -> (a, a) -- ^ Bounds
  -> Gen [a]
generateWithoutReplacement n b
  | n < 1 = return []
  | n == 1 = do
      single <- choose b
      return [single]
  | otherwise = do
      -- Build the shorter list first, then draw until a fresh value shows up.
      rest <- generateWithoutReplacement (n - 1) b
      fresh <- whileIn rest (choose b)
      return (fresh : rest)

-- | Random factors over 1 to 4 distinct vertices (ids drawn from 0 to 50),
-- with sizes given by 'quickCheckVertexSize' and values drawn from [0,1].
instance Arbitrary CPT where
  arbitrary = do
    nbVertex <- choose (1, 4) :: Gen Int
    vertexNumbers <- generateWithoutReplacement nbVertex (0, 50)
    let vars = [DV (Vertex i) (quickCheckVertexSize i) | i <- vertexNumbers]
        nbValues = product (map dimension vars)
    rndValues <- vectorOf nbValues (choose (0.0, 1.0) :: Gen Double)
    return (fromJust (factorWithVariables vars rndValues))

-- NOTE: the property testing a product followed by a projection, when the
-- factors have no common variables, is 'testProductProject_prop' below.

-- | Approximate comparison of floating point numbers.
--
-- * Exactly equal values (infinities included) are equal.
-- * When the product of both values is zero, the absolute difference must be
--   below epsilon squared.
-- * Otherwise the difference relative to the sum of magnitudes must be
--   below epsilon (2e-5).
nearlyEqual :: Double -> Double -> Bool
nearlyEqual a b
  | a == b = True -- handle infinities
  | a * b == 0 = gap < epsilon * epsilon
  | otherwise = gap / (abs a + abs b) < epsilon
  where
    gap = abs (a - b)
    epsilon = 2e-5

-- | The norm of a scaled factor is the scaled norm of the factor.
testScale_prop :: Double -> CPT -> Bool
testScale_prop s f = scaledNorm `nearlyEqual` (s * factorNorm f)
  where
    scaledNorm = factorNorm (factorScale s f)

-- | Test a product followed by a projection when the factors have no
-- common variables: summing the second factor out of the product and
-- dividing by its norm must give back the first factor.
testProductProject_prop :: CPT -> CPT -> Property
testProductProject_prop fa fb =
  isEmpty (factorVariables fa `intersection` factorVariables fb)
    ==> recovered `isomorphicFactor` fa
  where
    marginal = factorProjectOut (factorVariables fb) (factorProduct [fa, fb])
    recovered = marginal `factorDivide` factorNorm fb

-- | Multiplying by a scalar factor is the same as scaling.
testScalarProduct_prop :: Double -> CPT -> Bool
testScalarProduct_prop v f = viaProduct `isomorphicFactor` viaScale
  where
    viaProduct = factorProduct [Scalar v, f]
    viaScale = factorScale v f

-- | The factor product is associative, and the product of three factors at
-- once agrees with the nested products.
testAssocProduct_prop :: CPT -> CPT -> CPT -> Bool
testAssocProduct_prop a b c =
  and
    [ leftNested `isomorphicFactor` factorProduct [a, factorProduct [b, c]]
    , factorProduct [a, b, c] `isomorphicFactor` leftNested
    ]
  where
    leftNested = factorProduct [factorProduct [a, b], c]

-- | Summing out every variable of a factor gives the scalar factor holding
-- its norm.
testProjectionToScalar_prop :: CPT -> Bool
testProjectionToScalar_prop f = summedOut `isomorphicFactor` expected
  where
    summedOut = factorProjectOut (factorVariables f) f
    expected = factorFromScalar (factorNorm f)

-- | Projecting out the first then the second variable gives the same factor
-- as projecting out the second then the first. Only checked for factors
-- having at least 3 variables.
testProjectCommut_prop :: CPT -> Property
testProjectCommut_prop f =
  length vars >= 3 ==> firstLast `isomorphicFactor` secondLast
  where
    vars = factorVariables f
    firstVar = take 1 vars
    secondVar = take 1 (drop 1 vars)
    firstLast = factorProjectOut firstVar (factorProjectOut secondVar f)
    secondLast = factorProjectOut secondVar (factorProjectOut firstVar f)

-- | Test equality of two factors taking into account the fact
-- that the variables may be in a different order.
-- In case there is a distinction between conditioned variables and
-- conditioning variables (imposed from the exterior) then this
-- comparison may not make sense. It is a comparison of
-- functions of several variables, with no special interpretation of the
-- meaning of the variables according to their position.
--
-- The checks are done in order and stop at the first failure:
-- same set of variables ('equal'), same 'factorDimension', then
-- 'nearlyEqual' values for every instantiation of the first factor's
-- variables.
isomorphicFactor :: Factor f => f -> f -> Bool
isomorphicFactor fa fb =
  sameVariables && sameDimension && sameValues
  where
    sa = factorVariables fa
    sameVariables = sa `equal` factorVariables fb
    sameDimension = factorDimension fa == factorDimension fb
    -- Both factors are read with the same instantiations, so the order of
    -- the variables inside each factor does not matter.
    sameValues =
      and
        [ factorValue fa ia `nearlyEqual` factorValue fb ia
        | ia <- forAllInstantiations (DVSet sa)
        ]

{-

Following functions are used to typeset the factor when displaying it

-}
-- | Display a variable name (@v@ followed by the given number) and the
-- value of its instantiation, like @v3=1@.
vname :: Int -> DVI Int -> Box
vname vc i = text label
  where
    label = "v" ++ show vc ++ "=" ++ show (instantiationValue i)

-- | Typeset the values of a factor as nested boxes.
--
-- The last argument lists the variables still to be expanded: each one
-- produces a row of sub-tables, one per instantiation, headed by 'vname'.
-- When no variable is left, a column of values is produced, one per
-- instantiation of the variable given as second argument; the third
-- argument accumulates the instantiations chosen so far (most recent first).
dispFactor :: Factor f => f -> DV -> [DVI Int] -> [DV] -> Box
dispFactor cpt h c [] =
  vsep 0 center1
    [ text (show (factorValue cpt (inst : conditioning)))
    | inst <- allInstantiationsForOneVariable h
    ]
  where
    -- Accumulated instantiations are stored most recent first.
    conditioning = reverse c
dispFactor cpt dst c (h@(DV (Vertex vc) _) : l) =
  hsep 1 top
    [ vcat center1 [vname vc inst, dispFactor cpt dst (inst : c) l]
    | inst <- allInstantiationsForOneVariable h
    ]

-- | Text rendering of a factor: scalar value, empty marker, or the list of
-- variables followed by a table typeset with 'dispFactor'.
instance Show CPT where
  show (Scalar v) = "\nScalar Factor:\n" ++ show v
  show (CPT [] _ _) = "\nEmpty CPT:\n"
  show c@(CPT d@(h@(DV (Vertex vc) _) : others) _ _) =
    "\n" ++ show d ++ "\n" ++ render (hsep 1 top [dstColumn, table])
    where
      table = dispFactor c h [] others
      dstIndexes = map head (forAllInstantiations (DVSet [h]))
      -- In P(A | B ...), the dst column contains the possible values for the
      -- variable A, with a header made of empty lines to be aligned with the
      -- other part of the table. In the other part of the table, this header
      -- contains the variable values for the other variables.
      dstColumn =
        vcat center1 (replicate (length others) (text "") ++ map (vname vc) dstIndexes)

-- | 'CPT' implementation of the 'Factor' interface.
--
-- NOTE(review): several methods rely on 'indicesForDomain' and
-- 'instantiationDetails', which are defined outside this file; the comments
-- below assume 'indicesForDomain' enumerates every multi-index of a variable
-- set — confirm against "Bayes.PrivateTypes".
instance Factor CPT where
  factorToList (Scalar v) = [v]
  factorToList (CPT _ _ v) = V.toList v

  emptyFactor = emptyCPT

  isScalarFactor (Scalar _) = True
  isScalarFactor _ = False

  factorFromScalar v = Scalar v

  -- Product of the variable dimensions; 1 for a scalar.
  factorDimension f@(CPT _ _ _) = product . map dimension . factorVariables $ f
  factorDimension _ = 1

  -- Membership is decided on the vertex number only.
  containsVariable (CPT _ m _) (DV (Vertex i) _) = IM.member i m
  containsVariable (Scalar _) _ = False

  factorWithVariables = createCPTWithDims

  factorVariables (CPT v _ _) = v
  factorVariables (Scalar _) = []

  -- Sum of the values reached by every index of the factor domain.
  factorNorm (CPT d _ vals) =
    let vars = DVSet d
        strides = indexStrides vars
    in sum [vals ! indexPosition strides x | x <- indicesForDomain vars]
  factorNorm (Scalar v) = v

  variablePosition (CPT _ m _) (DV (Vertex i) _) = IM.lookup i m
  variablePosition (Scalar _) _ = Nothing

  factorScale s (Scalar v) = Scalar (s * v)
  factorScale s f@(CPT d _ vals) =
    let vars = DVSet d
        strides = indexStrides vars
        newValues = map (s *) [vals ! indexPosition strides x | x <- indicesForDomain vars]
    in fromJust $ factorWithVariables (factorVariables f) newValues

  -- The instantiation is converted to an index, which is read with strides
  -- computed for the variable order of the instantiation.
  factorValue (Scalar v) _ = v
  factorValue (CPT d _ v) i =
    let vars = DVSet d
        (dv, pos) = instantiationDetails i
        strides = indexStridesFor vars dv
    in v ! indexPosition strides pos

  -- 1.0 at the index of the instantiation, 0.0 at every other index.
  evidenceFrom [] = Nothing
  evidenceFrom l =
    let (variables, index) = instantiationDetails l
        DVSet nakedVars = variables
        setValueForIndex i = if i == index then 1.0 else 0.0
    in factorWithVariables nakedVars . map setValueForIndex $ indicesForDomain variables

  -- The result ranges over the union of the variables of all the factors.
  -- Scalars contribute a global multiplier; each table is read through
  -- strides expressed in the order of the union.
  factorProduct [] = factorFromScalar 1.0
  factorProduct l =
    let allVars = DVSet $ L.foldl1' union . map factorVariables $ l
        DVSet nakedVars = allVars
        scalarPart = product [v | Scalar v <- l]
        tables = [(indexStridesFor (DVSet d) allVars, vals) | CPT d _ vals <- l]
        valueAt i (strides, vals) = vals ! indexPosition strides i
        productAt i = product (map (valueAt i) tables)
        newValues = [scalarPart * productAt x | x <- indicesForDomain allVars]
    in if L.null tables
         then factorFromScalar scalarPart
         else newValues `seq` fromJust (factorWithVariables nakedVars newValues)

  -- Sum out the given variables. Only variables really present in the factor
  -- are iterated on: an absent variable gets stride 0 from 'indexStridesFor',
  -- so iterating over it would add the same value once per level and scale
  -- the result by its dimension. Duplicates are dropped for the same reason.
  factorProjectOut _ f@(Scalar _) = f
  factorProjectOut s (CPT d _ v) =
    let variablesToSum = L.nub (filter (`elem` d) s)
        variablesToKeep = d `difference` s
        keepSet = DVSet variablesToKeep
        sumSet = DVSet variablesToSum
        -- Strides to read the factor with an index made of the kept
        -- variables followed by the summed variables.
        strides = indexStridesFor (DVSet d) (DVSet $ variablesToKeep ++ variablesToSum)
        newValues = do
          keepIndex <- indicesForDomain keepSet
          let l = do
                sumIndex <- indicesForDomain sumSet
                return $ v ! indexPosition strides (combineIndex strides keepIndex sumIndex)
          return (sum l)
    in newValues `seq` fromJust $ factorWithVariables variablesToKeep newValues

-- | Used to combine the keep and sum indices in the factor projection.
-- Both lists are re-tagged for the set of the given strides (which are
-- only used for their type) and concatenated, keep indices first.
combineIndex :: Strides s'' -> [Index s] -> [Index s'] -> [Index s'']
combineIndex _ la lb =
  [Index (fromIndex a) | a <- la] ++ [Index (fromIndex b) | b <- lb]

-- | An empty CPT: no variable, no mapping and no value.
emptyCPT :: CPT
emptyCPT =
  CPT
    { dimensions = []
    , mapping = IM.empty
    , values = V.empty
    }

-- | One stride per variable, tagged with a phantom type for the variable set
-- whose indices they are meant to be combined with (see 'indexPosition').
newtype Strides s = Strides [Int]
  deriving (Eq, Show)

-- | Generate strides to read the first CPT using an index having meaning in
-- the second CPT. A variable of the second set which is missing from the
-- first one gets stride 0.
indexStridesFor
  :: DVSet s   -- ^ DVSet to be read
  -> DVSet s'  -- ^ DVSet to interpret the index
  -> Strides s'
indexStridesFor dr@(DVSet drvars) (DVSet divars) =
  Strides [strideOf dv | dv <- divars]
  where
    Strides originStrides = indexStrides dr
    strideTable = zip drvars originStrides
    strideOf dv = case lookup dv strideTable of
      Just stride -> stride
      Nothing -> 0

-- | Generate the strides to read a given factor using a multiindex
-- using the same order as the factor variables.
--
-- The stride of a variable is the product of the dimensions of the
-- variables after it, so the last variable has stride 1.
-- An empty variable set gives the single stride 1 (@drop 1@ is used
-- instead of the partial @tail@).
indexStrides :: DVSet s -> Strides s
indexStrides (DVSet dvars) =
  Strides (scanr (*) (1 :: Int) (drop 1 dims))
  where
    dims = map dimension dvars
-- | Conversion of a multiindex to its
-- position inside of the data vector of a 'CPT'.
--
-- The position is the sum of each index component multiplied by the
-- matching stride. An empty multiindex is at position 0 (the strides are
-- not inspected in that case).
indexPosition :: Strides s -> [Index s] -> Int
{-# INLINE indexPosition #-}
indexPosition _ [] = 0
indexPosition (Strides pos') pos =
  sum (zipWith (\stride ix -> stride * fromIndex ix) pos' pos)

-- | Create a CPT given some dimensions and a list of Doubles.
-- Returns 'Nothing' if the number of values is not the product of the
-- variable dimensions.
createCPTWithDims :: [DV] -> [Double] -> Maybe CPT
createCPTWithDims dims vals
  | length vals == expected = Just (CPT dims positions (V.fromList vals))
  | otherwise = Nothing
  where
    expected = product (map dimension dims)
    -- Vertex number of each variable mapped to its rank in the list.
    positions =
      IM.fromList [(v, i) | (i, DV (Vertex v) _) <- zip [0 :: Int ..] dims]

```