{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE FlexibleContexts #-}

module Data.Array.Accelerate.KullbackLiebler ( kullbackLiebler
                                             , entropy
                                             , dropZeroes
                                             , scale
                                             ) where

import qualified Data.Array.Accelerate as A

-- | Kullback-Leibler divergence of @qs@ from @ps@,
-- \(D_{KL}(P \parallel Q) = \sum_i p_i \log (p_i / q_i)\), using the natural
-- logarithm.
--
-- The two vectors are combined elementwise with 'A.zipWith', so if their
-- lengths differ the extra elements of the longer one are ignored.
--
-- Assumes input is nonzero: under IEEE floating point a zero @p_i@ gives
-- @0 * log 0@ (NaN), and a zero @q_i@ with nonzero @p_i@ makes the quotient
-- infinite.
kullbackLiebler :: (A.Floating e) => A.Acc (A.Vector e) -> A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
kullbackLiebler ps qs = A.sum $ A.zipWith term ps qs
  where
    -- contribution of a single index i: p_i * log (p_i / q_i)
    term p q = p * log (p / q)

-- | Sum of @p * log p@ over the vector, using the natural logarithm.
--
-- Mind the sign: this computes \(\sum_i p_i \log p_i\), the /negation/ of the
-- Shannon entropy \(H(P) = -\sum_i p_i \log p_i\). For a probability vector
-- (entries in [0, 1]) the result is therefore @<= 0@.
--
-- NOTE(review): callers that want Shannon entropy itself must negate the
-- result. The sign is kept as-is for backward compatibility.
--
-- Assumes input is nonzero: a zero entry gives @0 * log 0@, which is NaN.
-- 'dropZeroes' can be applied first if the input may contain zeros.
entropy :: (A.Floating e) => A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
entropy ps = A.sum (A.map plogp ps)
  where
    plogp p = p * log p

-- | Remove every element that compares equal to zero. For floating-point
-- element types this also removes negative zero, since @-0 == 0@.
--
-- 'A.filter' returns the surviving elements paired with a count of how many
-- were kept. Only the filtered vector is returned here.
dropZeroes :: (A.Eq e, Num (A.Exp e)) => A.Acc (A.Vector e) -> A.Acc (A.Vector e)
dropZeroes xs = A.afst (A.filter isNonZero xs)
  where
    isNonZero x = x A./= 0

-- | Normalise a vector by dividing every element by the sum of all elements,
-- so that (up to rounding) the result sums to one.
--
-- Doesn't check for negative values. If the elements sum to zero, every
-- output element is a division by zero (infinite or NaN).
--
-- @since 0.1.1.0
scale :: A.Floating e => A.Acc (A.Vector e) -> A.Acc (A.Vector e)
scale xs = A.map (\x -> x / total) xs
  where
    -- the total as a scalar expression, shared by every element
    total = A.the (A.sum xs)