{-# LANGUAGE UndecidableInstances,UndecidableSuperClasses #-}

-- | Tools for modelling the differential and Riemannian geometry of a
-- 'Manifold'.
module Goal.Geometry.Differential
    ( -- * Riemannian Manifolds
      Riemannian (metric, flat, sharp)
    , euclideanDistance
    -- * Backpropagation
    , Propagate (propagate)
    , backpropagation
    -- * Legendre Manifolds
    , PotentialCoordinates
    , Legendre (potential)
    , DuallyFlat (dualPotential)
    , canonicalDivergence
    -- * Automatic Differentiation
    , differential
    , hessian
    ) where


--- Imports ---


-- Goal --

import Goal.Core

import qualified Goal.Core.Vector.Storable as S
import qualified Goal.Core.Vector.Boxed as B
import qualified Goal.Core.Vector.Generic as G

import Goal.Geometry.Manifold
import Goal.Geometry.Vector
import Goal.Geometry.Map
import Goal.Geometry.Map.Linear

-- Qualified --

import qualified Numeric.AD as D


-- | Computes the differential of a function of the coordinates at a point using
-- automatic differentiation.
--
-- The coordinates of the point are boxed with 'boxCoordinates', the gradient of
-- the function is evaluated at them with 'D.grad', and the resulting vector is
-- converted with 'G.convert' and wrapped in a 'Point' of the dual coordinate
-- system.
differential
    :: Manifold x
    => (forall a. RealFloat a => B.Vector (Dimension x) a -> a)
    -> c # x
    -> c #* x
{-# INLINE differential #-}
differential f = Point . G.convert . D.grad f . boxCoordinates

-- | Computes the Hessian of a function at a point with automatic differentiation.
--
-- 'D.hessian' is evaluated at the boxed coordinates of the point and yields a
-- vector of row vectors. Each row is converted with 'G.convert', the outer
-- vector is converted likewise, and the rows are assembled into a matrix with
-- 'S.fromRows' before being wrapped as a 'Tensor' point by 'fromMatrix'.
hessian
    :: Manifold x
    => (forall a. RealFloat a => B.Vector (Dimension x) a -> a)
    -> c # x
    -> c #* Tensor x x -- ^ The Hessian
{-# INLINE hessian #-}
hessian f p =
    fromMatrix . S.fromRows . G.convert $ G.convert <$> D.hessian f (boxCoordinates p)

-- | A class of 'Map's which can 'propagate' errors. That is, given a list of
-- error differentials on the outputs, the list of inputs which caused those
-- outputs, and a 'Map' to differentiate, return the derivative of the error
-- with respect to the parameters of the 'Map', together with the outputs of
-- the 'Map'.
class Map c f y x => Propagate c f y x where
    propagate :: [c #* y] -- ^ The error differentials
              -> [c #* x] -- ^ The list of inputs
              -> c # f y x -- ^ The function to differentiate
              -> (c #* f y x, [c # y]) -- ^ The derivative, and function outputs

-- | Distance between two 'Point's based on the 'Euclidean' metric (l2 distance).
--
-- Computed as the l2 norm ('S.l2Norm') of the coordinates of the difference of
-- the two points.
euclideanDistance
    :: Manifold x => c # x -> c # x -> Double
{-# INLINE euclideanDistance #-}
euclideanDistance xs ys = S.l2Norm (coordinates $ xs - ys)

-- | An implementation of backpropagation using the 'Propagate' class. The first
-- argument is a function which takes a generalized target output and function
-- output and returns an error. The second argument is a list of target outputs
-- and function inputs. The third argument is the parametric function to be
-- optimized, and its differential is what is returned.
backpropagation
    :: Propagate c f y x
    => (a -> c # y -> c #* y)
    -> [(a, c #* x)]
    -> c # f y x
    -> c #* f y x
{-# INLINE backpropagation #-}
backpropagation grd ysxs f =
    -- NOTE: 'dys' is computed from 'yhts', which is itself produced by the
    -- 'propagate' call that consumes 'dys'. The bindings are mutually
    -- dependent and rely on lazy evaluation; do not make them strict or
    -- reorder them into sequential steps.
    let (yss,xs) = unzip ysxs
        (df,yhts) = propagate dys xs f
        dys = zipWith grd yss yhts
     in df


--- Riemannian Manifolds ---


-- | 'Riemannian' 'Manifold's are differentiable 'Manifold's associated with a
-- smoothly varying 'Tensor' known as the Riemannian 'metric'. 'flat' and
-- 'sharp' correspond to applying this 'metric' to elements of the 'Primal' and
-- 'Dual' spaces, respectively.
--
-- Only 'metric' has to be defined by an instance. By default, 'flat' applies
-- the 'metric' at the given point to its second argument, and 'sharp' applies
-- the 'inverse' of the 'metric' at the given point to its second argument.
class (Primal c, Manifold x) => Riemannian c x where
    metric :: c # x -> c #* Tensor x x
    flat :: c # x -> c # x -> c #* x
    {-# INLINE flat #-}
    flat p v = metric p >.> v
    sharp :: c # x -> c #* x -> c # x
    {-# INLINE sharp #-}
    sharp p v = inverse (metric p) >.> v


--- Dually Flat Manifolds ---


-- | Although convex analysis is usually developed separately from differential
-- geometry, it arises naturally out of the theory of dually flat 'Manifold's (<https://books.google.com/books?hl=en&lr=&id=vc2FWSo7wLUC&oi=fnd&pg=PR7&dq=methods+of+information+geometry&ots=4HsxHD_5KY&sig=gURe0tA3IEO-z-Cht_2TNsjjOG8#v=onepage&q=methods%20of%20information%20geometry&f=false Amari and Nagaoka, 2000>).
--
-- A 'Manifold' is 'Legendre' if it is associated with a particular convex
-- function known as a 'potential'. The 'potential' is evaluated on points
-- expressed in the 'PotentialCoordinates' of the 'Manifold'.
class ( Primal (PotentialCoordinates x), Manifold x ) => Legendre x where
    potential :: PotentialCoordinates x # x -> Double

-- | The (natural) coordinates of the given 'Manifold', on which the 'potential'
-- is defined. This is an open type family: each 'Manifold' that is to be
-- 'Legendre' supplies its own @type instance@.
type family PotentialCoordinates x :: Type

-- | A 'Manifold' is 'DuallyFlat' when we can describe the 'dualPotential', which
-- is the convex conjugate of 'potential'. The 'dualPotential' is evaluated on
-- points expressed in the dual of the 'PotentialCoordinates'.
class Legendre x => DuallyFlat x where
    dualPotential :: PotentialCoordinates x #* x -> Double

-- | Computes the 'canonicalDivergence' between two points. Note that relative
-- to the typical definition of the KL-Divergence/relative entropy, the
-- arguments of this function are flipped.
--
-- The result is the 'potential' of the first point, plus the 'dualPotential'
-- of the second point, minus the pairing ('<.>') of the two points.
canonicalDivergence
    :: DuallyFlat x => PotentialCoordinates x # x -> PotentialCoordinates x #* x -> Double
{-# INLINE canonicalDivergence #-}
canonicalDivergence pp dq = potential pp + dualPotential dq - (pp <.> dq)


--- Instances ---


-- Euclidean --

-- The metric of Euclidean space in Cartesian coordinates is the identity
-- matrix at every point, so 'flat' and 'sharp' ignore the base point and
-- merely relabel the coordinate system of their argument with 'breakPoint'.
instance KnownNat k => Riemannian Cartesian (Euclidean k) where
    {-# INLINE metric #-}
    metric _ = fromMatrix S.matrixIdentity
    {-# INLINE flat #-}
    flat _ = breakPoint
    {-# INLINE sharp #-}
    sharp _ = breakPoint

-- Replicated Riemannian Manifolds --

--instance {-# OVERLAPPABLE #-} (Riemannian c x, KnownNat k) => Riemannian c (Replicated k x) where
--    metric = error "Do not call metric on a replicated manifold"
--    {-# INLINE flat #-}
--    flat = S.map flat
--    {-# INLINE sharp #-}
--    sharp = S.map sharp

-- Backprop --

-- Propagation through a 'Tensor': the derivative is obtained by combining the
-- error differentials with the inputs via '>$<', and the outputs by mapping
-- the tensor over the inputs via '>$>'.
instance (Bilinear Tensor y x, Primal c) => Propagate c Tensor y x where
    {-# INLINE propagate #-}
    propagate dps qs pq = (dps >$< qs, pq >$> qs)

--instance (Bilinear Tensor y x, Primal c) => Propagate c Tensor y x where
--    {-# INLINE propagate #-}
--    propagate dps qs pq =
--        let foldfun (dp,q) (k,dpq) = (k+1,(dp >.< q) + dpq)
--         in (uncurry (/>) . foldr foldfun (0,0) $ zip dps qs, pq >$> qs)

-- Propagation through an 'Affine' map. The map is 'split' into its translation
-- component @z@ and its inner map @yx@. The error differentials are restricted
-- to the inner output space with 'anchor' and propagated through @yx@. The
-- derivative 'join's the 'average' of the error differentials (for the
-- translation component) with the derivative of the inner map, and each output
-- is the corresponding inner output translated by @z@ via '>+>'.
instance (Translation z y, Map c (Affine f y) z x, Propagate c f y x)
  => Propagate c (Affine f y) z x where
    {-# INLINE propagate #-}
    propagate dzs xs fzx =
        -- The local signatures mention the type variables of the instance
        -- head (presumably relying on scoped type variables being enabled for
        -- this package) and pin down the component types returned by 'split'.
        let z :: c # z
            yx :: c # f y x
            (z,yx) = split fzx
            dys = anchor <$> dzs
            (dyx,ys) = propagate dys xs yx
         in (join (average dzs) dyx, (z >+>) <$> ys)


-- Sums --

-- A pair of manifolds takes the potential coordinates of its first component;
-- the 'Legendre' instance for pairs requires both components to share them.
type instance PotentialCoordinates (x,y) = PotentialCoordinates x

-- The potential of a pair is the sum of the potentials of its two components,
-- which are obtained by 'split'ting the point on the pair.
instance (Legendre x, Legendre y, PotentialCoordinates x ~ PotentialCoordinates y)
  => Legendre (x,y) where
      {-# INLINE potential #-}
      potential pmn =
          let (pm,pn) = split pmn
           in potential pm + potential pn

-- A replicated manifold uses the potential coordinates of the manifold it
-- replicates.
type instance PotentialCoordinates (Replicated k x) = PotentialCoordinates x

-- The potential of a replicated manifold is the sum of the potentials of its
-- replicates, computed by mapping 'potential' over them with 'mapReplicated'.
instance (Legendre x, KnownNat k) => Legendre (Replicated k x) where
    {-# INLINE potential #-}
    potential ps =
        S.sum $ mapReplicated potential ps

-- The dual potential of a replicated manifold is the sum of the dual
-- potentials of its replicates, computed by mapping 'dualPotential' over them
-- with 'mapReplicated'.
instance (DuallyFlat x, KnownNat k) => DuallyFlat (Replicated k x) where
    {-# INLINE dualPotential #-}
    dualPotential ps =
        S.sum $ mapReplicated dualPotential ps