{-# LANGUAGE MagicHash #-}
-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Bayes.Factor
   ( -- * Factor type and dimensions
     Factor
   , Dimensions, dimensions, fromDimensions, hasVarD, mergesD
   , Dimensional(..)
     -- * Construction
   , makeFactor
     -- * Simple queries
   , values
     -- * Standard operations
   , multiply, sumout, sumouts, condition, conditions, normalize
     -- * Variable elimination on factors
   , eliminate, eliminateSplit, eliminateList
   ) where

import Data.List
import Data.Maybe
import Data.Semigroup
import GHC.Exts
import qualified Data.Set as S

-- | A factor: the dimensions (variables with their cardinalities) together
-- with a tree that stores the values, one tree level per variable.
data Factor = F
   { dimensions :: Dimensions -- ^ variables of the factor and their cardinalities
   , getTree    :: Tree       -- ^ the values of the factor, nested per variable
   }

-- | Shows every dimension as @name[cardinality]@ (comma separated), followed
-- by a colon and the tree of values.
instance Show Factor where
   show (F d t) = header ++ ": " ++ show t
    where
      header = intercalate "," [ s ++ "[" ++ show n ++ "]" | (s, n) <- fromDimensions d ]

-- | Combining two factors means multiplying them (see 'multiply').
instance Semigroup Factor where
   f <> g = multiply f g

-- | The neutral element is the factor without variables holding the single
-- value 1; a whole list of factors is combined with 'multiplyList'.
instance Monoid Factor where
   mempty     = makeFactor [] [1.0]
   mappend    = (<>)
   mconcat fs = multiplyList fs

-- | Builds a factor from a list of variables (name and cardinality) and the
-- list of values. The values are laid out in the variable order supplied by
-- the caller; internally the dimensions are kept sorted and the tree is
-- rearranged accordingly.
makeFactor :: [(String, Int)] -> [Double] -> Factor
makeFactor d bs = F (mkD sorted) tree
 where
   sorted = sort d
   -- first build the tree in the caller's order, then bring it into the
   -- sorted order of the dimensions
   tree   = reorder sorted d (makeTree [ n | (_, n) <- d ] bs)

-- | Turns a flat list of values into a nested tree. The cardinalities are
-- given outermost first; grouping therefore starts with the last
-- cardinality and works its way outwards. Fails with @makeFactor@ when the
-- values do not collapse into exactly one tree.
makeTree :: [Int] -> [Double] -> Tree
makeTree is bs = build (reverse is) (map leaf bs)
 where
   build [] [t]    = t
   build [] _      = error "makeFactor"
   build (n:ns) ts = build ns (chunk n ts)

   -- dedicated node types exist for groups of two and three
   chunk 2 = groups2
   chunk 3 = groups3
   chunk n = groupsN n

-- | Pairs up consecutive trees into 'Two' nodes; fails with @groups2@ when
-- a single tree is left over.
{-# INLINE groups2 #-}
groups2 :: [Tree] -> [Tree]
groups2 ts =
   case ts of
      []           -> []
      a : b : more -> Two a b : groups2 more
      _            -> error "groups2"

-- | Combines consecutive trees three at a time into 'Three' nodes; fails
-- with @groups3@ when one or two trees are left over.
{-# INLINE groups3 #-}
groups3 :: [Tree] -> [Tree]
groups3 ts =
   case ts of
      []               -> []
      a : b : c : more -> Three a b c : groups3 more
      _                -> error "groups3"

-- | Combines consecutive trees into 'Bin' nodes of (at most) @n@ children
-- each. Unlike 'groups2' and 'groups3', a final group that is shorter than
-- @n@ is accepted as is.
-- NOTE(review): presumably only called with a positive @n@ - confirm.
{-# INLINE groupsN #-}
groupsN :: Int -> [Tree] -> [Tree]
groupsN n = unfoldr next
 where
   next [] = Nothing
   next ts = Just (Bin chunk, remainder)
    where
      (chunk, remainder) = splitAt n ts

-- | Fixes one variable to the value with the given (zero-based) index: the
-- variable is removed from the dimensions and only the matching part of the
-- tree is kept.
condition :: String -> Int -> Factor -> Factor
condition s i (F d t) = F remaining selected
 where
   remaining = deleteD s d
   selected  = set s i d t

-- | Fixes several variables at once, each to the value with the given
-- (zero-based) index. All listed variables are removed from the dimensions.
conditions :: [(String, Int)] -> Factor -> Factor
conditions xs (F d t) = F remaining selected
 where
   remaining = deletesD [ s | (s, _) <- xs ] d
   selected  = choose xs d t

-- | Multiplies a list of factors. For three or more factors a greedy
-- strategy is used: the smallest factor is multiplied with the partner
-- that gives the smallest product, and the procedure is repeated on that
-- product together with the remaining factors.
multiplyList :: [Factor] -> Factor
multiplyList factors =
   case factors of
      []     -> mempty
      [a]    -> a
      [a, b] -> multiply a b
      _      -> multiplyList (best : rest)
 where
   -- only demanded for three or more factors, hence never empty
   smallest : others = sortOn size factors

   -- every way of picking one partner for the smallest factor, paired
   -- with the factors that are left over
   candidates =
      [ (multiply smallest y, before ++ after)
      | (before, y : after) <- zip (inits others) (tails others)
      ]

   (best, rest) = minimumBy (\p q -> compare (size (fst p)) (size (fst q))) candidates

-- | Multiplies two factors: the result ranges over the merged dimensions
-- and each value is the product of the corresponding values of the inputs.
multiply :: Factor -> Factor -> Factor
multiply (F d1 t1) (F d2 t2) = F merged product2
 where
   merged   = mergeD d1 d2
   product2 = mergeWith (*##) d1 d2 t1 t2

-- | Sums out (marginalizes) the given variables: they are removed from the
-- dimensions and the values of their subtrees are added together.
-- Variables that do not occur in the factor are ignored.
sumouts :: [String] -> Factor -> Factor
sumouts vs0 (F dt t0) = F (deletesD vs0 dt) (collapse (dropWhileEnd not flags) t0)
 where
   names = [ s | (s, _) <- fromDimensions dt ]

   -- one flag per tree level: True when that level has to be summed out
   flags = mark [ v | v <- vs0, v `elem` names ] names

   mark [] _ = []
   mark _ [] = []
   mark pending (x:xs)
      | x `elem` pending = True  : mark (delete x pending) xs
      | otherwise        = False : mark pending xs

   -- trailing False flags were dropped, so levels below the last summed
   -- variable are never visited
   collapse [] t = t
   collapse (b:bs) t =
      if b then foldr1 (zipTree (+##)) (subtrees below) else below
    where
      below = mapSubtrees (collapse bs) t

-- | Sums out a single variable (see 'sumouts').
sumout :: String -> Factor -> Factor
sumout v = sumouts [v]

-- | All values of the factor as a flat list, in the order in which they
-- are stored in the tree.
values :: Factor -> [Double]
values factor = treeToList (getTree factor)

-- | Divides every value of the factor by the sum of all its values. No
-- special treatment is given to a sum of zero.
normalize :: Factor -> Factor
normalize (F d t) =
   case sum (treeToList t) of
      D# total -> F d (mapTree (\v -> v /## total) t)

-----------------------------------------------------------------------------

-- | Eliminates a variable from a list of factors (see 'eliminateSplit').
-- The newly created factor is dropped from the result when its size is at
-- most one.
eliminate :: [Factor] -> String -> [Factor]
eliminate fs s
   | size new <= 1 = others
   | otherwise     = new : others
 where
   (new, others) = eliminateSplit fs s

-- | Eliminates the given variables one after another, from left to right;
-- the factors remaining after each step are the input of the next step.
--
-- The strict left fold is used so that no chain of suspended 'eliminate'
-- applications is built up before the result is demanded.
eliminateList :: [Factor] -> [String] -> [Factor]
eliminateList = foldl' eliminate

-- | Eliminates a variable: all factors that mention the variable are
-- multiplied and the variable is summed out of that product. Returns the
-- new factor together with the remaining (unchanged) factors.
eliminateSplit :: [Factor] -> String -> (Factor, [Factor])
eliminateSplit fs s = (sumout s (mconcat involved), untouched)
 where
   (involved, untouched) = partition (\f -> s `S.member` varSet f) fs

-----------------------------------------------------------------------------

-- | Things that range over a set of named variables.
class Dimensional a where
   -- | The size of the value (for dimensions: the number of table entries).
   size   :: a -> Int
   -- | The variables involved, as a set.
   varSet :: a -> S.Set String
   -- | The variables involved, as a list; by default the elements of
   -- 'varSet' in ascending order.
   vars   :: a -> [String]
   vars x = S.toList (varSet x)

-- | Size and variables of a factor are those of its dimensions.
instance Dimensional Factor where
   size   (F d _) = size d
   varSet (F d _) = varSet d

-- | The size is the stored total; the variables are the names of all
-- dimensions.
instance Dimensional Dimensions where
   size   (D n _)  = n
   varSet (D _ xs) = S.fromList [ s | (s, _) <- xs ]

-- | For a list, the sizes are added up and the variable sets are united.
instance Dimensional a => Dimensional [a] where
   size   xs = sum [ size x | x <- xs ]
   varSet xs = S.unions [ varSet x | x <- xs ]

---------------------------------------------------------

-- | The variables of a factor with their cardinalities, together with the
-- cached total size (see 'mkD').
-- NOTE(review): 'hasVarD' and 'mergeD' appear to rely on the list being
-- sorted - confirm that every construction site guarantees this.
data Dimensions = D
   Int             -- ^ product of all cardinalities
   [(String, Int)] -- ^ variable names paired with their cardinality

-- | The variables with their cardinalities, in stored order.
fromDimensions :: Dimensions -> [(String, Int)]
fromDimensions d =
   case d of
      D _ pairs -> pairs

-- | Wraps a list of variables, caching the product of their cardinalities
-- as the total size. The list is stored as given.
mkD :: [(String, Int)] -> Dimensions
mkD xs = D total xs
 where
   total = product [ n | (_, n) <- xs ]

-- | Removes one variable from the dimensions (no effect when absent).
deleteD :: String -> Dimensions -> Dimensions
deleteD s d = filterD (\v -> v /= s) d

-- | Removes all listed variables from the dimensions.
deletesD :: [String] -> Dimensions -> Dimensions
deletesD xs d = filterD keep d
 where
   keep v = not (v `elem` xs)

-- | Keeps only the variables whose name satisfies the predicate; the total
-- size is recomputed for the remaining variables.
filterD :: (String -> Bool) -> Dimensions -> Dimensions
filterD p (D _ m) = mkD [ e | e@(s, _) <- m, p s ]

-- | Tests whether a variable occurs in the dimensions. The search stops at
-- the first name that is not smaller than the requested one, so it only
-- gives the right answer when the names are stored in ascending order.
hasVarD :: String -> Dimensions -> Bool
hasVarD s (D _ xs) =
   case dropWhile (\(x, _) -> x < s) xs of
      (x, _) : _ -> x == s
      []         -> False

-- | Merges two dimensions: an ordered merge of both lists in which entries
-- that occur on both sides are kept once. Entries are compared as
-- (name, cardinality) pairs.
mergeD :: Dimensions -> Dimensions -> Dimensions
mergeD (D _ m1) (D _ m2) = mkD (go m1 m2)
 where
   go [] ys = ys
   go xs [] = xs
   go lx@(x:xs) ly@(y:ys)
      | x < y     = x : go xs ly
      | x > y     = y : go lx ys
      | otherwise = x : go xs ys

-- | Merges a list of dimensions into one by combining them with 'mergeD'
-- from right to left. For the empty list the empty dimensions are returned
-- (no variables, size 1) rather than failing on @foldr1@'s empty-list
-- error.
mergesD :: [Dimensions] -> Dimensions
mergesD [] = mkD []
mergesD ds = foldr1 mergeD ds

-----------------------------------------------------------------------------

-- | The nested table of values of a factor. Inner nodes with exactly two or
-- three children have dedicated strict constructors (see 'bin').
data Tree
   = Bin [Tree]              -- ^ inner node with a list of children
   | Leaf Double#            -- ^ a single unboxed value
   | Two !Tree !Tree         -- ^ inner node with exactly two children
   | Three !Tree !Tree !Tree -- ^ inner node with exactly three children

-- | Smart constructor for inner nodes: uses 'Two' or 'Three' for lists of
-- two or three children and 'Bin' for everything else.
bin :: [Tree] -> Tree
bin children =
   case children of
      [a, b]    -> Two a b
      [a, b, c] -> Three a b c
      _         -> Bin children

-- | Stores a boxed double as an (unboxed) leaf.
leaf :: Double -> Tree
leaf d =
   case d of
      D# raw -> Leaf raw

-- | A leaf is shown as its number; an inner node as the comma-separated
-- list of its children between parentheses.
instance Show Tree where
   show t =
      case t of
         Leaf a -> show (D# a)
         _      -> concat ["(", intercalate "," [ show c | c <- subtrees t ], ")"]

-- | Applies a function to every leaf value, keeping the shape of the tree.
{-# INLINE mapTree #-}
mapTree :: (Double# -> Double#) -> Tree -> Tree
mapTree f = go
 where
   go t =
      case t of
         Leaf a -> Leaf (f a)
         _      -> mapSubtrees go t

-- | Flattens a tree into the list of its leaf values, from left to right.
treeToList :: Tree -> [Double]
treeToList t = collect t []
 where
   -- prepends the leaves of a tree to the values collected so far
   collect (Leaf a)      acc = D# a : acc
   collect (Two x y)     acc = collect x (collect y acc)
   collect (Three x y z) acc = collect x (collect y (collect z acc))
   collect (Bin ts)      acc = foldr collect acc ts

-- | Combines two trees leaf by leaf with the given function. Anything other
-- than a pair of leaves is delegated to 'zipSubtrees', which fails when the
-- two nodes are not of the same kind.
{-# INLINE zipTree #-}
zipTree :: (Double# -> Double# -> Double#) -> Tree -> Tree -> Tree
zipTree f = go
 where
   go l r =
      case (l, r) of
         (Leaf a, Leaf b) -> Leaf (f a b)
         _                -> zipSubtrees go l r

-- | The direct children of an inner node; fails with @subtrees@ on a leaf.
{-# INLINE subtrees #-}
subtrees :: Tree -> [Tree]
subtrees t =
   case t of
      Bin xs      -> xs
      Two x y     -> [x, y]
      Three x y z -> [x, y, z]
      _           -> error "subtrees"

-- | Selects the child with the given zero-based index. Fails with
-- @subtree@ for a leaf or for an index that does not exist in a 'Two' or
-- 'Three' node; 'Bin' nodes are indexed with '!!'.
{-# INLINE subtree #-}
subtree :: Int -> Tree -> Tree
subtree i t =
   case t of
      Bin xs -> xs !! i
      Two x y
         | i == 0 -> x
         | i == 1 -> y
      Three x y z
         | i == 0 -> x
         | i == 1 -> y
         | i == 2 -> z
      _ -> error "subtree"

-- | Applies a function to every direct child of an inner node, keeping the
-- kind of node; fails with @mapSubtrees@ on a leaf.
{-# INLINE mapSubtrees #-}
mapSubtrees :: (Tree -> Tree) -> Tree -> Tree
mapSubtrees f t =
   case t of
      Bin xs      -> Bin [ f x | x <- xs ]
      Two x y     -> Two (f x) (f y)
      Three x y z -> Three (f x) (f y) (f z)
      _           -> error "mapSubtrees"

-- | Combines the direct children of two inner nodes pairwise. Both nodes
-- have to be of the same kind; any other combination fails with
-- @zipSubtrees@.
{-# INLINE zipSubtrees #-}
zipSubtrees :: (Tree -> Tree -> Tree) -> Tree -> Tree -> Tree
zipSubtrees f l r =
   case (l, r) of
      (Bin xs, Bin ys)                   -> Bin (zipWith f xs ys)
      (Two x1 x2, Two y1 y2)             -> Two (f x1 y1) (f x2 y2)
      (Three x1 x2 x3, Three y1 y2 y3)   -> Three (f x1 y1) (f x2 y2) (f x3 y3)
      _                                  -> error "zipSubtrees"

-- | Combines two trees with different dimensions into a tree over the
-- merged dimensions. A merge plan is computed from the two dimension lists
-- (an ordered merge, so presumably both are sorted - confirm): per level
-- it says whether the level comes from the left tree, from the right tree,
-- or from both. Once the plan is exhausted the remaining trees are combined
-- with 'zipTree'.
{-# INLINE mergeWith #-}
mergeWith :: (Double# -> Double# -> Double#) -> Dimensions -> Dimensions -> Tree -> Tree -> Tree
mergeWith f da db = walk plan
 where
   plan = merges (fromDimensions da) (fromDimensions db)

   walk steps l r =
      case steps of
         []             -> zipTree f l r
         TakeLeft  : ss -> mapSubtrees (\sub -> walk ss sub r) l
         TakeRight : ss -> mapSubtrees (\sub -> walk ss l sub) r
         Merge     : ss -> zipSubtrees (walk ss) l r

   merges :: [(String, Int)] -> [(String, Int)] -> [Merge]
   merges [] ys = map (const TakeRight) ys
   merges xs [] = map (const TakeLeft) xs
   merges l1@(d1:ds1) l2@(d2:ds2)
      | d1 < d2   = TakeLeft  : merges ds1 l2
      | d1 > d2   = TakeRight : merges l1 ds2
      | otherwise = Merge     : merges ds1 ds2

-- | One step of the plan used by 'mergeWith' to combine two trees.
data Merge
   = TakeLeft  -- ^ the level only occurs in the left tree
   | TakeRight -- ^ the level only occurs in the right tree
   | Merge     -- ^ the level occurs in both trees

-- | Fixes a single variable to the value with the given index
-- (see 'choose').
set :: String -> Int -> Dimensions -> Tree -> Tree
set s i d t = choose [(s, i)] d t

-- | Fixes variables to the values with the given (zero-based) indices. The
-- dimensions describe the levels of the tree: at the level of a fixed
-- variable the chosen child replaces the node, all other levels are kept.
choose :: [(String, Int)] -> Dimensions -> Tree -> Tree
choose env d = descend (dropWhileEnd isNothing picks)
 where
   -- per level: the chosen index, or Nothing when the variable stays free;
   -- trailing free levels are dropped (above) so they are never traversed
   picks = [ lookup s env | (s, _) <- fromDimensions d ]

   descend [] t     = t
   descend (c:cs) t =
      maybe (mapSubtrees (descend cs) t) (\i -> descend cs (subtree i t)) c

-- | Replaces every node at the given depth (0 is the root) by its child
-- with index @i@.
setAtLevel :: Int -> Int -> Tree -> Tree
setAtLevel lev i t
   | lev == 0  = subtree i t
   | otherwise = mapSubtrees (setAtLevel (lev - 1) i) t

-- | Rearranges the levels of a tree. The first list is the wanted order of
-- the variables, the second list the order in which the tree is currently
-- laid out. Fails with @invalid reordering@ when a wanted variable cannot
-- be found in the current layout, or when the wanted order runs out before
-- the current one does.
reorder :: [(String, Int)] -> [(String, Int)] -> Tree -> Tree
reorder target current t
   | target == current = t
   | otherwise =
        case (target, current) of
           -- outermost level is already in place: continue one level down
           (p:rest, o:old)
              | p == o -> mapSubtrees (reorder rest old) t
           -- otherwise pull the wanted variable up from the level where it
           -- currently lives, one child per value of that variable
           ((s, n):rest, _) ->
              case findIndex (\(v, _) -> v == s) current of
                 Nothing -> error "invalid reordering"
                 Just l  ->
                    let remaining = [ e | e@(v, _) <- current, v /= s ]
                    in  bin [ reorder rest remaining (setAtLevel l i t) | i <- [0 .. n - 1] ]
           _ -> error "invalid reordering"