{-# LANGUAGE ConstraintKinds  #-}
{-# LANGUAGE FlexibleContexts #-}

module Data.Array.Accelerate.KullbackLiebler ( kullbackLiebler
                                             , entropy
                                             , dropZeroes
                                             , scale
                                             , hellinger
                                             , fDivergence
                                             , alphaDivergence
                                             ) where

import qualified Data.Array.Accelerate as A

-- | \( D_f(p \| q) = \displaystyle\int p(x) f\left(\frac{p(x)}{q(x)}\right) dx \),
-- evaluated for discrete distributions as
-- \( \displaystyle\sum_i p_i f\left(\frac{p_i}{q_i}\right) \).
--
-- Note the weight is \(p\), not \(q\) as in the usual (Csiszár) definition
-- \( \int q(x) f\left(\frac{p(x)}{q(x)}\right) dx \). This function with
-- generator \(f\) equals the Csiszár divergence with generator
-- \( t \mapsto t f(t) \). For example, @fDivergence log@ is the
-- Kullback–Leibler divergence.
--
-- No zero checks are performed: a zero @q_i@ yields an infinite or NaN ratio.
--
-- @since 0.1.2.0
fDivergence :: (A.Floating e) => (A.Exp e -> A.Exp e) -- ^ \(f\)
            -> A.Acc (A.Vector e)
            -> A.Acc (A.Vector e)
            -> A.Acc (A.Scalar e)
fDivergence f ps qs = A.sum pointwise
  where
    -- one summand per element pair
    pointwise = A.zipWith term ps qs
    term p q = p * f (p / q)

-- | \( D^{(\alpha)}(p\| q) = \frac{4}{1 - \alpha^2}\left(1 - \displaystyle\int p(x)^{\frac{1-\alpha}{2}} q(x)^{\frac{1+\alpha}{2}} dx\right)\)
-- for \( \alpha \neq \pm 1\)
--
-- The integral is evaluated as a sum over corresponding elements of the two
-- vectors. No check is made for \( \alpha = \pm 1 \), where the leading
-- factor divides by zero.
--
-- @since 0.1.2.0
alphaDivergence :: A.Floating e => A.Exp e -> A.Acc (A.Vector e) -> A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
alphaDivergence alpha ps qs = A.map (\s -> coeff * (1 - s)) overlap
  where
    -- leading factor 4 / (1 - alpha^2)
    coeff = 4 / (1 - alpha ** 2)
    -- the integral term, as a one-element array
    overlap = A.sum (A.zipWith weighted ps qs)
    weighted p q = p ** ((1 - alpha) / 2) * q ** ((1 + alpha) / 2)

-- | Hellinger distance, computed as
--
-- \( \sqrt{2 \displaystyle\sum_i \left(\sqrt{p_i} - \sqrt{q_i}\right)^2} \)
--
-- Note the factor \(2\) inside the root. References that define
-- \( H^2 = \frac{1}{2}\sum_i (\sqrt{p_i} - \sqrt{q_i})^2 \) obtain a value
-- half as large. For normalised inputs this equals the square root of
-- 'alphaDivergence' at \( \alpha = 0 \).
--
-- @since 0.1.2.0
hellinger :: (A.Floating e) => A.Acc (A.Vector e) -> A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
hellinger ps qs = A.map finish (A.sum (A.zipWith sqDiff ps qs))
  where
    -- squared difference of square roots, one per element pair
    sqDiff p q = (A.sqrt p - A.sqrt q) ** 2
    finish s = A.sqrt (2 * s)

-- | Kullback–Leibler divergence
-- \( D_{KL}(p \| q) = \displaystyle\sum_i p_i \log\frac{p_i}{q_i} \),
-- using the natural logarithm (result in nats).
--
-- Assumes input is nonzero: a zero @p_i@ gives @0 * log 0@ (NaN), and a zero
-- @q_i@ gives an infinite or NaN term.
kullbackLiebler :: (A.Floating e) => A.Acc (A.Vector e) -> A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
kullbackLiebler ps qs = A.sum (A.zipWith klTerm ps qs)
  where
    klTerm p q = p * log (p / q)

-- | Shannon entropy \( H(p) = -\displaystyle\sum_i p_i \log p_i \), using the
-- natural logarithm (result in nats). For a probability vector the result is
-- non-negative.
--
-- Note: earlier versions returned \( \sum_i p_i \log p_i \), i.e. the
-- /negative/ of the entropy. The sign is now corrected.
--
-- Assumes input is nonzero: a zero @p_i@ gives @0 * log 0@, which is NaN.
-- Since \( 0 \log 0 = 0 \) by convention, zero entries can be removed first
-- with 'dropZeroes' without changing the entropy.
entropy :: (A.Floating e) => A.Acc (A.Vector e) -> A.Acc (A.Scalar e)
entropy = A.sum . A.map (\p -> negate (p * log p))

-- | Remove every element equal to @0@ from the vector, returning only the
-- surviving elements. The element count that 'A.filter' also returns is
-- discarded.
--
-- This is useful, for example, before 'entropy', which assumes nonzero input.
dropZeroes :: (A.Eq e, Num (A.Exp e)) => A.Acc (A.Vector e) -> A.Acc (A.Vector e)
dropZeroes xs = A.afst (A.filter (\x -> x A./= 0) xs)

-- | Normalise a vector by dividing every element by the sum of all
-- elements, so that the result sums to one.
--
-- Doesn't check for negative values. A vector whose elements sum to zero
-- produces infinite or NaN elements.
--
-- @since 0.1.1.0
scale :: A.Floating e => A.Acc (A.Vector e) -> A.Acc (A.Vector e)
scale xs = A.map (\x -> x / total) xs
  where
    total = A.the (A.sum xs)