-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Bayes.NodeTypes
   ( virtualFactor, mkNoisyOr, mkNoisyAdder, mkNoisyAdder2
   ) where

import Bayes.Factor
import Bayes.Probability
import Data.List
import Data.Semigroup

-- | Build a factor over the variable @s@ (given one state per entry of @ps@)
-- and a companion variable named @'#':s@ with two states. For every
-- probability @p@ in @ps@ the table receives the values @p@ and @1 - p@, in
-- that order, converted with 'toDouble'.
virtualFactor :: String -> [Probability] -> Factor
virtualFactor s ps = makeFactor dims table
 where
   -- variable names paired with their number of states
   dims  = [(s, length ps), ('#':s, 2)]
   -- two table entries per probability: the value and its complement
   table = [ toDouble q | p <- ps, q <- [p, 1 - p] ]

-- | Build the factor for a noisy-or node.
--
-- The factors produced by 'mkVars' and the gate produced by 'orGate' are
-- combined with 'mconcat'; afterwards the auxiliary variables are removed with
-- 'sumouts': the @'X'@-prefixed name of every parent (parents taken in
-- reverse order), followed by @"L"@.
mkNoisyOr :: (String, Int) -> [(String, Int)] -> [Int] -> [Probability] -> Factor
mkNoisyOr target vs str cs = sumouts hidden (mconcat parts)
 where
   -- names of the auxiliary variables that are summed out at the end
   hidden = [ 'X':s | (s, _) <- reverse vs ] ++ ["L"]
   -- per-parent factors, the "L" factor, and finally the gate itself
   parts  = mkVars target vs str cs ++ [orGate target vs]

-- | Produce one factor per parent (via 'mkVar') followed by a final factor
-- over the variable @"L"@.
--
-- For each parent, @snd target * snd parent@ parameters are taken from the
-- front of @cs@ and @snd parent@ indices from the front of @str@; the
-- parameters are reordered with 'shuffleGroups' before being handed to
-- 'mkVar'. Whatever parameters remain once all parents are processed become
-- the table of the @"L"@ factor.
mkVars :: (a, Int) -> [(String, Int)] -> [Int] -> [Probability] -> [Factor]
mkVars target parents str cs =
   case parents of
      [] -> [makeFactor [("L", tn)] (map toDouble cs)]
      v : rest ->
         let vn            = snd v
             (here, later) = splitAt (tn * vn) cs
             (perm, str')  = splitAt vn str
         in mkVar v tn (shuffleGroups tn perm here) : mkVars target rest str' later
 where
   -- number of states of the target variable
   tn = snd target

-- | Build a factor over a variable @s@ with @n@ states and its auxiliary
-- companion @'X':s@ with @nx@ states; the given probabilities are converted
-- with 'toDouble' and used as the table.
mkVar :: (String, Int) -> Int -> [Probability] -> Factor
mkVar (s, n) nx = makeFactor dims . fmap toDouble
 where
   dims = [(s, n), ('X':s, nx)]

-- | Build the deterministic gate factor that ties the auxiliary variables to
-- the target.
--
-- The factor ranges over one @'X'@-prefixed variable per parent, the variable
-- @"L"@ and the target itself; all auxiliary variables get the same number of
-- states as the target. For every joint assignment of the auxiliary
-- variables the minimum state index is computed; the table entry is 1 where
-- the target's state index equals that minimum and 0 otherwise.
orGate :: (String, Int) -> [(String, Int)] -> Factor
orGate target vs = makeFactor dims table
 where
   tn    = snd target
   -- auxiliary variable for every parent, sized like the target
   xvars = [ ('X':s, tn) | (s, _) <- vs ]
   dims  = xvars ++ [("L", tn), target]

   -- state indices of a variable, counting from zero
   states :: (String, Int) -> [Int]
   states (_, n) = take n [0 ..]

   -- minimum index per joint assignment; the target entry stands in for "L",
   -- which has the same number of states
   lows  = map minimum (combine (map states (xvars ++ [target])))
   table = [ if lo == t then 1 else 0 | lo <- lows, t <- states target ]

-- | All ways of picking one element from each of the given lists, in order:
-- the choice from the first list varies slowest, the choice from the last
-- list fastest. An empty outer list yields the single empty selection; any
-- empty inner list yields no selections at all.
--
-- This is exactly 'sequence' in the list monad, so the standard function is
-- used instead of a hand-written fold.
combine :: [[a]] -> [[a]]
combine = sequence

-- | Apply a binary function to every pair drawn from two lists. Elements of
-- the first list vary slowest: all results for the first element of @xs@
-- come before any result for its second element.
cartesian :: (a -> b -> c) -> [a] -> [b] -> [c]
cartesian f xs ys = concatMap (\x -> map (f x) ys) xs

-- | Cut a list into chunks of @n@ elements (see 'groups'), pick the chunks
-- named by the zero-based indices @is@ in that order, and flatten the result.
-- An index that does not refer to an existing chunk is an error (from '!!').
shuffleGroups :: Int -> [Int] -> [a] -> [a]
shuffleGroups n is xs = concatMap (chunks !!) is
 where
   chunks = groups n xs

-- | Select elements of a list by zero-based index, in the order the indices
-- are given. Indices may repeat; an out-of-range index is an error (from
-- '!!').
shuffle :: [Int] -> [a] -> [a]
shuffle is as = [ as !! i | i <- is ]

-- | Split a list into consecutive chunks of @n@ elements; the last chunk may
-- be shorter. An empty list yields no chunks, whatever the value of @n@.
--
-- A non-positive @n@ on a non-empty list is rejected with an error: such a
-- chunk size never consumes any input, so the recursion would otherwise
-- produce an endless list of empty chunks.
groups :: Int -> [a] -> [[a]] -- copied from factor
groups n xs
   | null xs   = []
   | n <= 0    = error "Bayes.NodeTypes.groups: chunk size must be positive"
   | otherwise = chunk : groups n rest
 where
   (chunk, rest) = splitAt n xs

--------------------------------------------------------------------------------

-- | Build a noisy-adder factor from flat parameter lists by unpacking them
-- into per-parent tuples and delegating to 'mkNoisyAdder2'.
--
-- * @vs@: the parents as (name, number of states).
-- * @dst@: distinguished-node indices; the first entry is dropped and the
--   remaining ones are matched with the parents in order.
-- * @ws@: weights; all but the last belong to the parents, the last one is
--   the leak weight.
-- * @params@: for every parent, @states * snd target@ parameters in order;
--   whatever remains afterwards are the leak parameters.
--
-- The parent tuples are sorted before they are passed on.
mkNoisyAdder :: (String, Int) -> [(String, Int)] -> [Int] -> [Double] -> [Probability] -> Factor
mkNoisyAdder target vs dst ws params =
   mkNoisyAdder2 target parents leak
 where
   (names, dims) = unzip vs
   -- one parameter chunk per parent, plus a final chunk for the leak
   chunks  = chunkParams dims params
   -- (var, dimension, distinguished node, weight, parameters) per parent
   parents = sort (zip5 names dims (tail dst) (init ws) (init chunks))
   leak    = (last ws, last chunks)

   -- consume parameters parent by parent; the remainder is the last chunk
   chunkParams :: [Int] -> [Probability] -> [[Probability]]
   chunkParams dimsLeft rest =
      case dimsLeft of
         []     -> [rest]
         d : ds ->
            let (here, later) = splitAt (d * snd target) rest
            in here : chunkParams ds later

-- | Build a noisy-adder factor from per-parent tuples
-- (var, dimension, distinguished node, weight, parameters) and the leak's
-- (weight, parameters).
--
-- The factor ranges over the parents in the given order followed by the
-- target. Every table entry is a weighted average ('Avg') of the leak
-- parameter and one parameter per parent; a parent contributes with weight 0
-- when it is in its distinguished state and with its own weight otherwise.
mkNoisyAdder2 :: (String, Int) -> [(String, Int, Int, Double, [Probability])] -> (Double, [Probability]) -> Factor
mkNoisyAdder2 target@(_, tn) tups (lw, lps) =
   makeFactor dims (map fromAvg table)
 where
   dims  = [ (s, n) | (s, n, _, _, _) <- tups ] ++ [target]
   -- the fold starts from the leak alone: one entry per target state
   leak  = (1, [ avg p lw | p <- take tn lps ])
   table = snd (foldr step leak tups)

   -- The first component tracks the combined size of the parents handled so
   -- far; it is needed to line up this parent's entries with the table
   -- accumulated for the parents to its right.
   step :: (String, Int, Int, Double, [Probability]) -> (Int, [Avg]) -> (Int, [Avg])
   step (_, n, d, w, fc) (cum, acc) = (cum * n, zipWith (<>) own rest)
    where
      -- weight per parameter: zero for the distinguished state, w otherwise
      weights = [ if x == d then 0 else w | x <- [0 .. n-1], _ <- [1 .. tn] ]
      -- this parent's entries, each block of tn repeated cum times
      own     = alignTable cum tn (zipWith avg fc weights)
      -- the accumulated table, repeated once per state of this parent
      rest    = concat (replicate n acc)

-- | Cut @xs@ into chunks of @b@ elements (see 'groups') and emit every chunk
-- @a@ times in a row before moving on to the next chunk.
alignTable :: Int -> Int -> [a] -> [a]
alignTable a b xs = concat [ chunk | chunk <- groups b xs, _ <- replicate a () ]

-- | Accumulator for a weighted average; both fields are strict.
data Avg = Avg
   !Double -- sum of weighted values
   !Double -- sum of weights

-- Shown as the resulting average rather than as its two components.
instance Show Avg where
   show a = show (fromAvg a)

-- Combining two accumulators adds both sums component-wise.
instance Semigroup Avg where
   Avg s1 w1 <> Avg s2 w2 = Avg (s1 + s2) (w1 + w2)

-- The neutral element has no values and no weight.
instance Monoid Avg where
   mappend = (<>)
   mempty  = Avg 0 0

-- | Start a weighted average from a single probability @x@ carrying weight
-- @y@: the value sum is @toDouble x * y@ and the weight sum is @y@.
avg :: Probability -> Double -> Avg
avg x y = Avg weighted y
 where
   weighted = toDouble x * y

-- | Extract the weighted average: the sum of weighted values divided by the
-- sum of weights. There is no guard for a zero weight sum; the result is
-- then whatever 'Double' division produces (NaN or an infinity).
fromAvg :: Avg -> Double
fromAvg (Avg total weight) = total / weight