{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE FlexibleContexts #-}

module Data.Array.Accelerate.KullbackLiebler ( kullbackLiebler
                                             , entropy
                                             , dropZeroes
                                             ) where

import qualified Data.Array.Accelerate as A

-- | Kullback-Leibler divergence @D(P || Q) = sum_i p_i * log (p_i / q_i)@,
-- using the natural logarithm.
--
-- The two vectors are combined pointwise with 'A.zipWith' (whose result has
-- the extent of the intersection of its two inputs) and then summed.
--
-- Assumes input is nonzero. The code does not special-case zeros:
--
-- * a zero @p_i@ gives @0 * log (0 / q_i)@, which is NaN under IEEE-754;
-- * a zero @q_i@ with nonzero @p_i@ gives an infinite term.
kullbackLiebler :: (A.Floating e) => A.Acc (A.Vector e) -> A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
kullbackLiebler ps qs = A.sum pointwise
  where
    pointwise = A.zipWith term ps qs
    -- Contribution of one pair of probabilities to the divergence.
    term p q = p * log (p / q)

-- | Shannon entropy @H(p) = - sum_i p_i * log p_i@, using the natural
-- logarithm (result in nats).
--
-- The input is used as-is; it is /not/ normalised to sum to one. Pass a
-- probability vector if you want the usual entropy.
--
-- Assumes input is nonzero. Under IEEE-754 arithmetic a zero element
-- contributes @0 * log 0 = 0 * (-Infinity) = NaN@, which poisons the sum.
-- Apply 'dropZeroes' first to get the @0 log 0 = 0@ convention used by
-- scipy's @entropy@.
entropy :: (A.Floating e) => A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
entropy = A.sum . A.map term
  where
    -- Each element contributes @-p * log p@. The minus sign makes the result
    -- non-negative for probability vectors. Negating each term before
    -- summing is exactly equivalent to negating the sum, and it keeps
    -- everything in a single map.
    term p = negate (p * log p)

-- | Removes every element that compares equal to zero (for parity with
-- scipy's @entropy@ function, which treats @0 log 0@ as @0@).
--
-- 'A.filter' returns a pair. Only its first component, the vector of
-- surviving elements, is kept via 'A.afst'.
dropZeroes :: (A.Eq e, Num (A.Exp e)) => A.Acc (A.Vector e) -> A.Acc (A.Vector e)
dropZeroes xs = A.afst (A.filter nonZero xs)
  where
    -- Keep an element exactly when it is not equal to zero.
    nonZero x = x A./= 0