{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{- | Conditional probability table

Conditional Probability Tables and Probability tables

-}
module Bayes.Factor(
 -- * Factor
   Factor(..)
 , isomorphicFactor
 , normedFactor
 -- * Set of variables 
 , Set(..)
 , BayesianDiscreteVariable(..)
 -- * Implementation
 , Vertex(..)
 -- ** Discrete variables and instantiations
 , DV(..)
 , DVSet(..)
 , DVI
 , DVISet(..)
 , setDVValue
 , instantiationValue
 , instantiationVariable
 , variableVertex
 , (=:)
 , forAllInstantiations
 , factorFromInstantiation
 , changeVariableOrder
 -- ** Factor
 , CPT
 -- * Tests
 , testProductProject_prop
 , testScale_prop
 , testProjectCommut_prop
 , testScalarProduct_prop
 , testProjectionToScalar_prop
 ) where

import qualified Data.Vector.Unboxed as V
import Data.Vector.Unboxed((!))
import Data.Maybe(fromJust,mapMaybe)
import qualified Data.List as L
import Text.PrettyPrint.Boxes hiding((//))
import Test.QuickCheck
import Test.QuickCheck.Arbitrary
import qualified Data.IntMap as IM
import Control.Monad
import System.Random(Random)

--import Debug.Trace

--debug a = trace ("\nDEBUG\n" ++ show a ++ "\n") a

-- | Identifier of a vertex in a graph. It wraps the integer number of the
-- vertex, which can be read back with 'vertexId'.
newtype Vertex = Vertex
    { vertexId :: Int
    } deriving (Eq, Ord)

-- | A vertex is displayed as the letter @v@ followed by its number
instance Show Vertex where
    show vertex = 'v' : show (vertexId vertex)

-- | Operations on a set of variables as used in a factor.
-- @s@ is the container type and @a@ the type of the elements.
class Set s where
    -- | The set with no element
    emptySet :: s a

    -- | Elements which are in at least one of the two sets
    union :: Eq a => s a -> s a -> s a

    -- | Elements which are in both sets
    intersection :: Eq a => s a -> s a -> s a

    -- | Elements of the first set which are not in the second one
    difference :: Eq a => s a -> s a -> s a

    -- | True when the set has no element
    isEmpty :: s a -> Bool

    -- | Membership test
    isElem :: Eq a => a -> s a -> Bool

    -- | Insert an element in the set
    addElem :: Eq a => a -> s a -> s a

    -- | Cardinal of the set
    nbElements :: s a -> Int

    -- | @subset sa sb@ checks that @sa@ is a subset of @sb@
    subset :: Eq a => s a -> s a -> Bool

    -- | Set equality. By default, two sets are equal when each one is a
    -- subset of the other one.
    equal :: Eq a => s a -> s a -> Bool
    equal sa sb = subset sa sb && subset sb sa

-- | Lists used as sets. The operations are the ones of "Data.List".
instance Set [] where
    emptySet = []
    union xs ys = L.union xs ys
    intersection xs ys = L.intersect xs ys
    difference xs ys = xs L.\\ ys
    isEmpty = null
    isElem x xs = x `L.elem` xs
    -- An element already present is not added a second time
    addElem x xs
        | x `elem` xs = xs
        | otherwise = x : xs
    nbElements = length
    subset xs ys = all (`elem` ys) xs

-- | A discrete variable has a number of levels which is required to size the factors
class BayesianDiscreteVariable v where
    -- | Number of levels of the variable
    dimension :: v -> Int 


-- | A vertex associated to another value (variable dimension, variable value ...)
class LabeledVertex l where
    -- | Vertex carried by the labeled value
    variableVertex :: l -> Vertex

-- | A discrete variable: a vertex and an integer which is returned as the
-- 'dimension' of the variable (its number of levels).
data DV = DV !Vertex !Int deriving(Eq,Ord)

-- | A set of discrete variables, represented as a list.
-- The order of the list is the order of the variables in a factor.
type DVSet = [DV]

-- | Displayed as the vertex followed by the dimension between parentheses
instance Show DV where
    show (DV vertex levels) = concat [show vertex, "(", show levels, ")"]

-- | Discrete Variable instantiation. A variable and its value.
-- The value field is strict.
data DVI a = DVI DV !a deriving(Eq)

-- | Displayed as @vertex=value@. The dimension of the variable is not shown.
instance Show a => Show (DVI a) where
    show (DVI (DV vertex _) value) = concat [show vertex, "=", show value]

-- | Convert a variable instantiation to a factor.
-- The factor has the instantiated variable as only variable. Its value is 1
-- for the instantiated level and 0 for all the other levels.
-- Useful to create evidence factors.
factorFromInstantiation :: Factor f => DVI Int -> f
factorFromInstantiation (DVI dv a) =
    fromJust (factorWithVariables [dv] indicator)
  where
    -- One value per level of the variable
    indicator = [if level == a then 1.0 else 0.0 | level <- [0 .. dimension dv - 1]]

-- | A set of variable instantiations, represented as a list
type DVISet a = [DVI a]

-- | The dimension of a 'DV' is the integer stored with its vertex
instance BayesianDiscreteVariable DV where
    dimension (DV _ levels) = levels

-- | Build the instantiation of a discrete variable with a given value
setDVValue :: DV -> a -> DVI a
setDVValue = DVI

-- | Smallest value of the type of the argument.
-- The argument is only used to select the type, its value is ignored.
getMinBound :: Bounded a => a -> a
getMinBound = const minBound

-- | Create a variable instantiation using a value from an enumeration.
-- The stored integer is the offset of the value from the minimum bound of
-- its type, so the smallest value of the enumeration is stored as 0.
(=:) :: (Bounded b, Enum b) => DV -> b -> DVI Int
variable =: level = setDVValue variable offset
  where
    offset = fromEnum level - fromEnum (getMinBound level)

-- | Extract the value of the instantiation
instantiationValue :: DVI a -> a
instantiationValue (DVI _ v) = v

-- | Extract the discrete variable of the instantiation
instantiationVariable :: DVI a -> DV
instantiationVariable (DVI dv _) = dv

-- | The vertex of an instantiation is the vertex of its variable
instance LabeledVertex (DVI a) where
    variableVertex = variableVertex . instantiationVariable

-- | The vertex of a discrete variable is the vertex stored in it
instance LabeledVertex DV where
    variableVertex (DV vertex _) = vertex

-- | Extend an index list to the full variable set using a list of booleans
-- and a default value. A @True@ flag consumes the next value of the list
-- and a @False@ flag emits the default value. When the flags are exhausted
-- the remaining values are returned unchanged, and when the values are
-- exhausted each remaining flag emits the default value.
-- For instance [True, False, True, False] 5 [2,3] ---> [2,5,3,5]
extend :: [Bool] -> a -> [a] -> [a]
extend flags d vals = case (flags, vals) of
    ([], rest) -> rest
    (_ : more, []) -> d : extend more d []
    (False : more, rest) -> d : extend more d rest
    (True : more, v : rest) -> v : extend more d rest

-- | Inner loop function using full indices for full variables.
-- It receives one index per variable of the full variable set.
type InnerLoop a = [Int] -> a

-- | Outer loop function using result from inner loop
-- and outer vars indices.
-- The first argument is the list of outer indices and the second one
-- the list of results produced by the inner loop.
type OuterLoop a b = [Int] -> [a] -> b

-- | Iterate on the outer variables and, for each of them, on the inner ones.
-- The inner body is called with the full variable set and with indices
-- for the full set (the positions which are not outer variables are set to 0).
-- The outer body is called with the indices for the outer set and with
-- the results of the inner body.
forSubA :: DVSet -- ^ All variables
        -> DVSet -- ^ Outer variables
        -> (DVSet -> [Int] -> [a]) -- ^ Inner loop body
        -> OuterLoop a b -- ^ Outer loop function
        -> [b]
forSubA allvars outervars inner outer =
    [ outer outerIndices (inner allvars (widen outerIndices))
    | outerIndices <- forAllIndices outervars
    ]
  where
    -- True at the positions of the full set which are outer variables
    outerMask = map (`isElem` outervars) allvars
    -- Outer indices placed at their position in the full set
    widen = extend outerMask 0

-- | Iterate on the inner variables using indices for the full variable set.
-- Each inner index list is placed at the positions of the inner variables in
-- the full set (0 elsewhere) and added position by position to the outer
-- indices before calling the inner loop function.
forSubB :: DVSet -- ^ Inner vars 
        -> InnerLoop a -- ^ Inner loop function
        -> DVSet -- ^ All vars
        -> [Int] -- ^ Outer indices
        -> [a]
forSubB innervars f allvars outerIdx =
    [ f (zipWith (+) outerIdx (extend innerMask 0 innerIndices))
    | innerIndices <- forAllIndices innervars
    ]
  where
    -- True at the positions of the full set which are inner variables
    innerMask = map (`isElem` innervars) allvars

-- | Norm the factor: all the values are divided by the norm of the factor
-- (as given by 'factorNorm')
normedFactor :: Factor f => f -> f
normedFactor f = f `factorDivide` factorNorm f

-- | A factor as used in graphical model.
-- It may or may not be a probability distribution. So it has no reason to be
-- normalized to 1
class FactorPrivate f => Factor f where
    -- | When all variables of a factor have been summed out, we have a scalar
    isScalarFactor :: f -> Bool 
    -- | An empty factor with no variable and no values
    emptyFactor :: f
    -- | Check if a given discrete variable is contained in a factor
    containsVariable :: f -> DV  -> Bool
    -- | Give the set of discrete variables used by the factor
    factorVariables :: f -> DVSet    
    -- | Return A in P(A | C D ...). It is making sense only if the factor is a conditional probability
    -- table. It must always be in the vertex corresponding to A in the bayesian graph.
    -- The default implementation takes the first variable of the factor and
    -- is therefore undefined for a factor without variables.
    factorMainVariable :: f -> DV
    factorMainVariable = head . factorVariables
    -- | Create a new factor with a given set of variables and a list of values
    -- for initialization. The creation may fail if the number of values is not
    -- coherent with the variables and their levels.
    -- For boolean variables ABC, the values must be given in order
    -- FFF, FFT, FTF, FTT ...
    factorWithVariables :: DVSet -> [Double] -> Maybe f
    -- | Value of the factor for a given set of variable instantiations.
    -- The variable instantiation is like a multi-dimensional index.
    factorValue :: f -> DVISet Int -> Double
    -- | Position of a discrete variable in the factor (p(AB) is different from p(BA) since values
    -- are not organized in same order in memory)
    variablePosition :: f -> DV -> Maybe Int
    -- | Dimension of the factor (number of floating point values)
    factorDimension :: f -> Int
    
    -- | Norm of the factor = sum of its values
    factorNorm :: f -> Double 
    

    -- | Scale the factor values by a given scaling factor
    factorScale :: Double -> f -> f

    -- | Create a scalar factor with no variables
    factorFromScalar :: Double -> f

    -- | Create an evidence factor from an instantiation.
    -- If the instantiation is empty then we get nothing
    evidenceFrom :: DVISet Int -> Maybe f
    

    -- | Divide all the factor values by a given number.
    -- The default implementation scales the factor by the inverse of the number.
    factorDivide :: f -> Double -> f
    factorDivide f d = (1.0 / d) `factorScale` f 

    -- | Multiply factors. The variables of the result are the union of the
    -- variables of all the factors. The product of no factor is the scalar 1.
    factorProduct :: [f] -> f
    factorProduct [] = factorFromScalar 1.0
    factorProduct l = 
        let allVars = L.foldl1' union . map factorVariables $ l
        in 
        if L.null allVars 
            then 
                -- No factor has any variable: the norms are multiplied
                factorFromScalar (product . map factorNorm $ l) 
            else
                let getFactorValueAtIndex i factor = factorValuePrivate factor (reorder i factor)
                    instantiationProduct instantiation = product . map (getFactorValueAtIndex instantiation) $ l
                    values = [instantiationProduct x | x <- forAllInstantiations allVars]
                in 
                fromJust $ factorWithVariables allVars values

    -- | Project out a factor. The variables in the DVSet are summed out.
    -- Variables of the DVSet which are not in the factor are ignored and the
    -- order of the variables in the DVSet has no importance.
    -- When no variable remains, the result is a scalar factor containing the norm.
    factorProjectOut :: DVSet -> f -> f
    factorProjectOut s f = 
        let alls = factorVariables f
            -- 'forSubB' pairs the indices it generates with the variables of
            -- the factor by position. So the summed variables must be
            -- restricted to the variables of the factor and listed in the
            -- order of the factor: with the list instance of 'Set', the
            -- intersection keeps the order of its first argument.
            summed = alls `intersection` s
            s' = alls `difference` summed
        in 
        if null s'
            then 
                factorFromScalar (factorNorm f)
            else
                let dstValues = forSubA alls s' 
                                   (forSubB summed $ factorValuePrivate f)
                                   (\_ c -> sum c)
                in 
                fromJust $ factorWithVariables s' dstValues
    -- | Project to. The variables of the DVSet are kept and the other variables
    -- of the factor are summed out
    factorProjectTo :: DVSet -> f -> f 
    factorProjectTo s f = 
        let alls = factorVariables f 
            s' = alls `difference` s 
        in 
        factorProjectOut s' f

-- | Used internally when we know the position of a variable in the factor
-- then we can identify the variable with an int. May be a bit faster for some
-- algorithms
class FactorPrivate f where
    -- | Value of the factor for a list of integer indices
    factorValuePrivate :: f -> [Int] -> Double

-- | All the indices (from 0 to dimension - 1) which can be taken by a DV
allValues :: DV -> [Int]
allValues (DV _ levels) = enumFromTo 0 (levels - 1)

-- | Generate all the index lists for a set of variables.
-- The index of the last variable is the one changing the fastest.
-- There is one (empty) index list for an empty set of variables.
forAllIndices :: DVSet -> [[Int]]
forAllIndices [] = [[]]
forAllIndices (v : vs) = [i : rest | i <- allValues v, rest <- forAllIndices vs]

-- | Generate all the instantiations of a set of variables.
-- Each result contains one instantiation per variable, in the order of the set.
forAllInstantiations :: DVSet -> [DVISet Int]
forAllInstantiations = mapM instantiationsOf
  where
    -- All the instantiations of a single variable
    instantiationsOf v = [setDVValue v level | level <- allValues v]

-- | Change the layout of the values of a factor to correspond to a new
-- variable order. A 'CPT' is built from the old order and the old values,
-- then read back for all the instantiations of the new order.
changeVariableOrder :: DVSet -- ^ Old order
                    -> DVSet -- ^ New order 
                    -> [Double] -- ^ Old values
                    -> [Double] -- ^ New values
changeVariableOrder oldOrder newOrder oldValues =
    map (factorValue oldFactor) (forAllInstantiations newOrder)
  where
    oldFactor :: CPT
    oldFactor = fromJust (factorWithVariables oldOrder oldValues)


-- | Order the instantiations to get a multi-index which is making sense
-- for the factor. Only the instantiations of variables contained in the
-- factor are selected, and each value is placed at the position of its
-- variable in the factor.
-- An error is raised when the number of selected instantiations is different
-- from the number of variables of the factor.
reorder :: Factor f => DVISet Int -> f  -> [Int]
reorder i f
    | complete && inRange = V.toList (V.replicate nbDestVars 0 V.// allPos)
    | otherwise = error ("reorder has not set all destination indexes ! allpos = " ++ show allPos ++ " nbDestVars = " ++ show nbDestVars ++ "\n" )
  where
    nbDestVars = nbElements (factorVariables f)
    -- The lookup is done with a variable of dimension 0 built from the vertex
    positionOf dvi = variablePosition f (DV (variableVertex dvi) 0)
    -- (position in the factor, value) for the instantiations found in the factor
    allPos = [(pos, instantiationValue dvi) | dvi <- i, Just pos <- [positionOf dvi]]
    complete = length allPos == nbDestVars
    inRange = all ((< nbDestVars) . fst) allPos


-- | Mainly used for conditional probability tables like p(A B | C D E) but the normalization to 1
-- is not imposed. And the conditioned variables are not different from the conditioning ones.
-- The dimensions for each variable are listed.
-- The variables on the left or right of the condition bar are not tracked. What matters is that
-- it is encoding a function of several variables.
-- Marginalization of variables will be computed from the bayesian graph where
-- the knowledge of the dependencies is.
-- So, this same structure is used for a probability too (conditioned on nothing).
-- The 'Scalar' constructor is a factor without variables containing a single value.
data CPT = CPT {
           dimensions :: DVSet -- ^ Dimensions for all variables
         , mapping :: IM.IntMap Int -- ^ Mapping from vertex number to position in dimensions
         , values :: V.Vector Double -- ^ Table of values
         }
         | Scalar Double

-- | Print the internal representation of a 'CPT' on the standard output.
-- For a scalar the value is printed. For a table the variables, the
-- mapping from vertex number to position and the values are printed, each
-- of them followed by an empty line.
debugCPT :: CPT -> IO ()
debugCPT (Scalar d) = do 
   putStrLn "SCALAR CPT"
   print d
   putStrLn ""

debugCPT (CPT d m v) = do 
    putStrLn "CPT"
    print d 
    putStrLn ""
    print m 
    putStrLn ""
    print v
    putStrLn ""
{-

A CPT can't contain the same vertex several times with different sizes.
But arbitrary CPT generation will generate several vertices with the same vertex id
and different vertex sizes.

So, we introduce a function mapping a vertex ID to a vertex size: the vertex sizes are hard coded.

-}

-- | Size of the vertex with a given identifier in the generated test data.
-- The table is hard coded, and currently every vertex has size 2.
quickCheckVertexSize :: Int -> Int
quickCheckVertexSize vertexNumber = case vertexNumber of
    0 -> 2
    1 -> 2
    2 -> 2
    _ -> 2

-- | Run the generator again and again until the generated value is not an
-- element of the list
whileIn :: (Arbitrary a, Eq a) => [a] -> Gen a -> Gen a
whileIn l m = m >>= retryIfSeen
  where
    retryIfSeen candidate
        | candidate `elem` l = whileIn l m
        | otherwise = return candidate

-- | Generate a random list of n distinct elements chosen inside the bounds.
-- The list is empty when n is not positive.
-- May loop if the range is too small !
generateWithoutReplacement :: (Random a, Arbitrary a, Eq a)  
                           => Int -- ^ Vector size
                           -> (a,a) -- ^ Bounds
                           -> Gen [a]
generateWithoutReplacement n b
    | n <= 0 = return []
    | n == 1 = choose b >>= \single -> return [single]
    | otherwise = do
        -- The n-1 first elements, then a new one different from all of them
        rest <- generateWithoutReplacement (n - 1) b
        fresh <- whileIn rest (choose b)
        return (fresh : rest)



-- | Random factor with 1 to 4 variables. The vertex numbers are distinct and
-- chosen between 0 and 50, the sizes come from 'quickCheckVertexSize' and the
-- values are chosen between 0 and 1.
instance Arbitrary CPT where
    arbitrary = do
        nbVertex <- choose (1,4) :: Gen Int
        vertexNumbers <- generateWithoutReplacement nbVertex (0,50)
        let variables = [DV (Vertex i) (quickCheckVertexSize i) | i <- vertexNumbers]
            nbValues = product (map dimension variables)
        rndValues <- vectorOf nbValues (choose (0.0,1.0) :: Gen Double)
        return (fromJust (factorWithVariables variables rndValues))

-- Note: 'testProductProject_prop' (defined below) tests a product followed by
-- a projection when the factors have no common variables.

-- | Approximate comparison of floating point numbers.
-- Equal numbers (including equal infinities) are nearly equal. When one of
-- the numbers is zero the absolute difference is compared to a very small
-- threshold. Otherwise the difference relative to the sum of the absolute
-- values is compared to the tolerance.
nearlyEqual :: Double -> Double -> Bool
nearlyEqual a b
    | a == b = True -- handle infinities
    | a * b == 0 = gap < epsilon * epsilon
    | otherwise = gap / (abs a + abs b) < epsilon
  where
    gap = abs (a - b)
    epsilon = 2e-5

-- | The norm of a scaled factor is the scaled norm of the factor
testScale_prop :: Double -> CPT -> Bool
testScale_prop s f = scaledNorm `nearlyEqual` expectedNorm
  where
    scaledNorm = factorNorm (factorScale s f)
    expectedNorm = s * factorNorm f

-- | When two factors have no common variables, projecting out the variables
-- of the second one from their product and dividing by the norm of the
-- second one must give back the first one.
testProductProject_prop :: CPT -> CPT -> Property
testProductProject_prop fa fb = disjoint ==> (recovered `isomorphicFactor` fa)
  where
    varsA = factorVariables fa
    varsB = factorVariables fb
    disjoint = isEmpty (varsA `intersection` varsB)
    joint = factorProduct [fa, fb]
    recovered = factorProjectOut varsB joint `factorDivide` factorNorm fb

-- | The product of a factor with a scalar factor is the scaled factor
testScalarProduct_prop :: Double -> CPT -> Bool
testScalarProduct_prop v f = multiplied `isomorphicFactor` scaled
  where
    multiplied = factorProduct [Scalar v, f]
    scaled = factorScale v f

-- | Projecting out all the variables of a factor gives the scalar factor
-- containing the norm of the factor
testProjectionToScalar_prop :: CPT -> Bool
testProjectionToScalar_prop f = projected `isomorphicFactor` expected
  where
    projected = factorProjectOut (factorVariables f) f
    expected = factorFromScalar (factorNorm f)

-- | For a factor with at least 3 variables, projecting out the first
-- variable then the second one gives the same factor as projecting them
-- out in the opposite order
testProjectCommut_prop :: CPT -> Property
testProjectCommut_prop f = length vars >= 3 ==> (bThenA `isomorphicFactor` aThenB)
  where
    vars = factorVariables f
    a = take 1 vars
    b = take 1 (drop 1 vars)
    bThenA = factorProjectOut a (factorProjectOut b f)
    aThenB = factorProjectOut b (factorProjectOut a f)

-- | Test the equality of two factors taking into account the fact
-- that the variables may be in a different order.
-- Two factors are isomorphic when they have the same set of variables, the
-- same dimension and nearly equal values for all the instantiations of
-- the variables.
-- In case there is a distinction between conditioned variables and
-- conditioning variables (imposed from the exterior) then this
-- comparison may not make sense. It is a comparison of
-- functions of several variables with no special interpretation of the
-- meaning of the variables according to their position.
isomorphicFactor :: Factor f => f -> f -> Bool
isomorphicFactor fa fb = sameVariables && sameDimension && sameValues
  where
    va = factorVariables fa
    vb = factorVariables fb
    sameVariables = va `equal` vb
    sameDimension = factorDimension fa == factorDimension fb
    sameValues = all agreesOn (forAllInstantiations va)
    agreesOn ia = factorValue fa ia `nearlyEqual` factorValue fb ia

{-

The following functions are used to typeset a factor when displaying it.

-}
-- | Box containing the text @vN=I@ for the vertex number N and the value I
vname :: Int -> Int -> Box
vname vc i = text (concat ["v", show vc, "=", show i])

-- | Typeset the values of a factor.
-- The second argument is the variable displayed vertically, the third one
-- the indices already chosen (most recent first) and the last one the
-- variables which have still to be displayed horizontally.
dispFactor :: FactorPrivate f => f -> DV -> [Int] -> DVSet -> Box
-- No more horizontal variable: a column with one value per level of the
-- vertical variable
dispFactor cpt h c [] =
    vsep 0 center1 [cell (p : dependentIndexes) | p <- allValues h]
  where
    dependentIndexes = reverse c
    cell = text . show . factorValuePrivate cpt

-- One sub-table per level of the next horizontal variable, each of them
-- under a header naming the level
dispFactor cpt dst c (h@(DV (Vertex vc) _) : l) =
    hsep 1 top
        [ vcat center1 [vname vc level, dispFactor cpt dst (level : c) l]
        | level <- allValues h
        ]

-- | A scalar factor shows its value. A table shows its variables followed
-- by the table of values: the first variable is displayed vertically and
-- the other ones horizontally.
instance Show CPT where
    show (Scalar v) = "\nScalar Factor:\n" ++ show v
    show (CPT [] _ _) = "\nEmpty CPT:\n"
    show c@(CPT d@(h@(DV (Vertex vc) _) : rest) _ _) =
        concat ["\n", show d, "\n", render (hsep 1 top [dstColumn, table])]
      where
        table = dispFactor c h [] rest
        -- One empty line for each header line of the table
        padding = replicate (length d - 1) (text "")
        dstColumn = vcat center1 (padding ++ map (vname vc) (allValues h))

-- | Factor implementation based on a table of values ('CPT') or on a
-- single value ('Scalar')
instance Factor CPT where
    emptyFactor = emptyCPT
    isScalarFactor (Scalar _) = True
    isScalarFactor _ = False
    factorFromScalar v = Scalar v
    -- Number of values of the table: product of the dimensions of the variables
    factorDimension f@(CPT _ _ _) = product . map dimension . factorVariables$ f
    factorDimension _ = 1
    -- Only the vertex number is used to look for the variable
    containsVariable (CPT _ m _) (DV (Vertex i) _)   = IM.member i m
    containsVariable (Scalar _) _ = False
    factorWithVariables = createCPTWithDims
    factorVariables (CPT v _ _) = v
    factorVariables (Scalar _) = []
    -- The values are summed directly from the table, in the order of the
    -- table. It also gives 0 for the empty CPT instead of indexing an
    -- empty table.
    factorNorm (CPT _ _ v) = V.sum v
    factorNorm (Scalar v) = v
    variablePosition (CPT _ m _) (DV (Vertex i) _) = IM.lookup i m
    variablePosition (Scalar _) _ = Nothing
    factorScale s (Scalar v) = Scalar (s*v)
    -- Variables and mapping are unchanged, only the table is scaled
    factorScale s (CPT d m v) = CPT d m (V.map (s *) v)
    factorValue (Scalar v) _ = v 
    factorValue f i = 
        let multiIndex = reorder i f
        in 
        factorValuePrivate f multiIndex
    evidenceFrom [] = Nothing 
    -- Factor on the instantiated variables with the value 1 for the index
    -- given by the instantiations and 0 for all the other indices
    evidenceFrom l = 
        let index = map instantiationValue l 
            variables = map instantiationVariable l
            setValueForIndex i = if i == index then 1.0 else 0.0 
        in
        factorWithVariables variables . map setValueForIndex $ forAllIndices variables

-- | The value for a list of indices is read with 'getCPTValue'
instance FactorPrivate CPT where
    factorValuePrivate cpt multiIndex = getCPTValue cpt multiIndex


-- | A CPT with no variable, an empty mapping and no value
emptyCPT :: CPT
emptyCPT = CPT
    { dimensions = []
    , mapping = IM.empty
    , values = V.empty
    }

-- | Conversion of a multi-index to its position inside the data vector
-- of a 'CPT'. The index of the last variable is the one changing the
-- fastest. The position is 0 when there is no variable.
indexPosition :: DVSet -> [Int] -> Int
indexPosition [] _ = 0
indexPosition (_ : rest) pos = sum (zipWith (*) strides pos)
  where
    -- Stride of a variable: product of the dimensions of the following variables
    strides = scanr (*) (1 :: Int) (map dimension rest)

-- | Get the value at a given multi-index. Indices are starting at zero.
-- A scalar factor returns its value whatever the indices are.
getCPTValue :: CPT -> [Int] -> Double
getCPTValue cpt pos = case cpt of
    Scalar v -> v
    CPT d _ v -> v ! indexPosition d pos

-- | Create a CPT given some variables and a list of Doubles.
-- Returns Nothing if the number of values is different from the product
-- of the dimensions of the variables.
createCPTWithDims :: DVSet -> [Double] -> Maybe CPT
createCPTWithDims dims values
    | length values == expected = Just (CPT dims positions (V.fromList values))
    | otherwise = Nothing
  where
    expected = product (map dimension dims)
    -- Vertex number of each variable associated to its position in the list
    positions = IM.fromList [(v, i) | (i, DV (Vertex v) _) <- zip [0 :: Int ..] dims]