-- | This module provides tools for working with differential and Riemannian
-- geometry.
module Goal.Geometry.Differential (
    -- * Tangent Spaces
    -- ** Types
      Tangent (Tangent, removeTangent)
    , Bundle (Bundle, removeBundle)
    , Partials (Partials)
    , Differentials (Differentials)
    -- ** Functions
    , gradientStep
    , projectTangent
    , tangentToBundle
    , bundleToTangent
    -- * Riemannian Manifolds
    , Riemannian (metric, flat, sharp)
    -- ** Gradient Pursuit
    , gradientAscent
    , vanillaGradientAscent
    , gradientDescent
    , vanillaGradientDescent
    ) where

--- Imports ---

import Prelude hiding (map,minimum,maximum)

-- Package --

import Goal.Core

import Goal.Geometry.Set
import Goal.Geometry.Manifold
import Goal.Geometry.Linear
import Goal.Geometry.Map
import Goal.Geometry.Map.Multilinear

-- Qualified --

import qualified Data.Vector.Storable as C
import qualified Numeric.LinearAlgebra.HMatrix as H

--import Data.Vector.Storable.UnsafeSerialize

--- Differentiable Manifolds ---

-- | The tangent space attached to a single point of a differentiable
-- 'Manifold'. 'Tangent' spaces are the basis for differential geometry; each
-- one wraps the point (in chart @c@ on manifold @m@) at which it is defined,
-- and 'removeTangent' recovers that point.
newtype Tangent c m = Tangent
    { removeTangent :: c :#: m
    } deriving (Show, Read, Eq)

-- | The 'Tangent' 'Bundle' of a 'Manifold': the original 'Manifold' taken
-- together with all of its 'Tangent' spaces. Only the underlying manifold @m@
-- is stored; the chart parameter @c@ is phantom.
newtype Bundle c m = Bundle
    { removeBundle :: m
    } deriving (Show, Read, Eq)

-- | Coordinate system on a 'Tangent' space given by the partial derivatives
-- of the coordinate functions at a particular point.
data Partials
    = Partials
    deriving (Show, Read, Eq)

-- | Coordinate system for the set of linear functionals on a 'Tangent'
-- space.
data Differentials
    = Differentials
    deriving (Show, Read, Eq)

-- | Takes a single gradient step. The given tangent vector is scaled by the
-- step size (via @.>@), and the scaled coordinates are added to the
-- coordinates of the point underlying its tangent space, yielding a new point
-- on the same 'Manifold' shifted in the direction of the gradient.
gradientStep
    :: Manifold m
    => Double -- ^ Step size
    -> Partials :#: Tangent c m -- ^ Gradient in the tangent space of the current point
    -> c :#: m -- ^ The shifted point
gradientStep eps f' = fromCoordinates (manifold basePoint) shifted
  where
    -- The point at which the tangent space of the gradient is defined.
    Tangent basePoint = manifold f'
    shifted = coordinates basePoint + coordinates (eps .> f')

-- | Recovers the 'Point' on the underlying 'Manifold' at which the given
-- 'Tangent' vector is based.
projectTangent :: d :#: Tangent c m -> c :#: m
projectTangent dp = removeTangent (manifold dp)

-- | Converts a 'Point' on a 'Tangent' 'Bundle' into a 'Tangent' vector. The
-- coordinates are split at half the dimension of the bundle: the first part
-- becomes the base point on the underlying 'Manifold', and the remainder
-- becomes the coordinates of the 'Tangent' vector at that point.
bundleToTangent :: Manifold m => c :#: Bundle d m -> c :#: Tangent d m
bundleToTangent p = fromCoordinates (Tangent basePoint) tangentCoordinates
  where
    Bundle m = manifold p
    halfDimension = dimension (manifold p) `div` 2
    (baseCoordinates, tangentCoordinates) = C.splitAt halfDimension (coordinates p)
    basePoint = fromCoordinates m baseCoordinates

-- | Converts a 'Tangent' vector into a 'Point' on a 'Tangent' 'Bundle'. The
-- resulting coordinates are those of the base point followed by those of the
-- 'Tangent' vector.
tangentToBundle :: Manifold m => c :#: Tangent d m -> c :#: Bundle d m
tangentToBundle cm = fromCoordinates (Bundle (manifold basePoint)) bundleCoordinates
  where
    -- The point at which the tangent vector is based.
    Tangent basePoint = manifold cm
    bundleCoordinates = coordinates basePoint C.++ coordinates cm

-- | Converts a 'Tangent' vector on a 'Replicated' 'Manifold' into a list of
-- 'Tangent' vectors, one for each 'Tangent' space produced by
-- 'mapReplicated'. The coordinates of the input are cut into consecutive
-- chunks whose length is the dimension of the first 'Tangent' space, and each
-- chunk is paired with the corresponding 'Tangent' space.
--
-- An empty replication yields the empty list.
replicatedTangents :: Manifold m => d :#: Tangent c (Replicated m) -> [d :#: Tangent c m]
replicatedTangents dp =
    let (Tangent p) = manifold dp
        cs = listCoordinates dp
     in case mapReplicated Tangent p of
            -- Matching on the list shape avoids the partial 'head'; the
            -- result for the empty case is the same as before, since
            -- 'zipWith' over an empty first list never forced the chunks.
            [] -> []
            ts@(t:_) -> zipWith fromList ts $ breakEvery (dimension t) cs

-- Gradient Pursuit --

-- | Gradient ascent with respect to the 'Riemannian' metric. At each step the
-- differential returned by the gradient calculator is converted with 'sharp',
-- and the current point is then moved with 'gradientStep'. The result is the
-- infinite list of iterates, starting with the initial point.
gradientAscent :: (Riemannian c m, Manifold m)
    => Double -- ^ Step size
    -> (c :#: m -> Differentials :#: Tangent c m) -- ^ Gradient calculator
    -> (c :#: m) -- ^ The initial point
    -> [c :#: m] -- ^ The gradient ascent
gradientAscent eps f' p0 = iterate ascend p0
  where
    ascend p = gradientStep eps (sharp (f' p))

-- | Gradient ascent which does not consult the 'Riemannian' metric. At each
-- step the differential returned by the gradient calculator is passed through
-- 'breakChart' instead of 'sharp' before being handed to 'gradientStep'. The
-- result is the infinite list of iterates, starting with the initial point.
vanillaGradientAscent :: Manifold m
    => Double -- ^ Step size
    -> (c :#: m -> Differentials :#: Tangent c m) -- ^ Gradient calculator
    -> (c :#: m) -- ^ The initial point
    -> [c :#: m] -- ^ The gradient ascent
vanillaGradientAscent eps f' p0 = iterate ascend p0
  where
    -- NOTE(review): 'breakChart' is presumably a relabelling of the chart
    -- that leaves the coordinates untouched -- confirm against its definition.
    ascend p = gradientStep eps (breakChart (f' p))

-- | Gradient descent with respect to the 'Riemannian' metric. This is
-- 'gradientAscent' run with the negated step size.
gradientDescent :: (Riemannian c m, Manifold m)
    => Double -- ^ Step size
    -> (c :#: m -> Differentials :#: Tangent c m) -- ^ Gradient calculator
    -> (c :#: m) -- ^ The initial point
    -> [c :#: m] -- ^ The gradient descent
gradientDescent = gradientAscent . negate

-- | Gradient descent which does not consult the 'Riemannian' metric. This is
-- 'vanillaGradientAscent' run with the negated step size.
vanillaGradientDescent :: Manifold m
    => Double -- ^ Step size
    -> (c :#: m -> Differentials :#: Tangent c m) -- ^ Gradient calculator
    -> (c :#: m) -- ^ The initial point
    -> [c :#: m] -- ^ The gradient descent
vanillaGradientDescent = vanillaGradientAscent . negate

--- Riemannian Manifolds ---

-- | A 'Riemannian' 'Manifold' is a differentiable 'Manifold' where each point
-- is associated with a 'Tangent' space carrying a smoothly varying inner
-- product. 'flat' and 'sharp' correspond to lowering and raising the indices
-- via the musical isomorphism determined by the 'metric' tensor.
--
-- A 'Riemannian' 'Manifold' should satisfy the law
--
-- > flat (sharp p) = p
class Manifold m => Riemannian c m where
    -- | The metric tensor at the given point.
    metric :: c :#: m -> Function Partials Differentials :#: Tensor (Tangent c m) (Tangent c m)
    -- | Lowers indices. The default applies the 'metric' at the base point of
    -- the given 'Tangent' vector to that vector.
    flat :: Partials :#: Tangent c m -> Differentials :#: Tangent c m
    flat dp = matrixApply g dp
      where
        g = metric (projectTangent dp)
    -- | Raises indices. The default applies the inverse of the 'metric' at
    -- the base point of the given 'Tangent' vector to that vector.
    sharp :: Differentials :#: Tangent c m -> Partials :#: Tangent c m
    sharp dp = matrixApply gInverse dp
      where
        gInverse = matrixInverse (metric (projectTangent dp))

--- Instances ---

-- Replicated --

-- | A 'Replicated' 'Manifold' is 'Riemannian' whenever the replicated
-- 'Manifold' is. The metric is the block-diagonal matrix assembled from the
-- metrics obtained through 'mapReplicated', while 'flat' and 'sharp' are
-- applied to each of the 'replicatedTangents' and the resulting coordinates
-- are concatenated.
instance (Manifold m, Riemannian c m) => Riemannian c (Replicated m) where
    metric p = fromHMatrix (Tensor tp tp) (H.diagBlock blocks)
      where
        tp = Tangent p
        blocks = mapReplicated (\q -> toHMatrix (metric q)) p
    flat dp = fromCoordinates (manifold dp) $
        C.concat [ coordinates (flat dt) | dt <- replicatedTangents dp ]
    sharp dp = fromCoordinates (manifold dp) $
        C.concat [ coordinates (sharp dt) | dt <- replicatedTangents dp ]

-- Euclidean --

-- | In 'Cartesian' coordinates the metric of the 'Continuum' is the single
-- entry @1@. 'flat' and 'sharp' are both implemented with 'breakChart' rather
-- than the metric-based defaults.
instance Riemannian Cartesian Continuum where
    metric p = fromList (Tensor tp tp) [1]
      where
        tp = Tangent p
    flat dp = breakChart dp
    sharp dp = breakChart dp

-- | In 'Cartesian' coordinates the metric of 'Euclidean' space is the
-- identity matrix whose size is the dimension of the space. 'flat' and
-- 'sharp' are both implemented with 'breakChart' rather than the metric-based
-- defaults.
instance Riemannian Cartesian Euclidean where
    metric p = fromHMatrix (Tensor tp tp) (H.ident n)
      where
        tp = Tangent p
        n = dimension (manifold p)
    flat dp = breakChart dp
    sharp dp = breakChart dp

-- Trivial higher order spaces --

-- | A 'Tangent' space of a 'Riemannian' 'Manifold', in 'Partials'
-- coordinates, reuses the structure of the underlying 'Manifold'. The metric
-- at a 'Tangent' vector has the coordinates of the metric at its base point,
-- and 'flat' / 'sharp' copy the coordinates of their argument into the
-- 'Tangent' space one level down, apply the underlying 'flat' / 'sharp'
-- there, and copy the resulting coordinates back.
instance (Manifold m, Riemannian c m) => Riemannian Partials (Tangent c m) where
    metric dp = fromCoordinates (Tensor tdp tdp) (coordinates baseMetric)
      where
        tdp = Tangent dp
        baseMetric = metric (projectTangent dp)
    sharp ddp = fromCoordinates (manifold ddp) (coordinates raised)
      where
        lowered = fromCoordinates (manifold (projectTangent ddp)) (coordinates ddp)
        raised = sharp lowered
    flat pdd = fromCoordinates (manifold pdd) (coordinates lowered)
      where
        raised = fromCoordinates (manifold (projectTangent pdd)) (coordinates pdd)
        lowered = flat raised

-- Tangent Spaces --

-- | A 'Tangent' space has the same dimension as the 'Manifold' of its base
-- point.
instance Manifold m => Manifold (Tangent c m) where
    dimension (Tangent basePoint) = dimension (manifold basePoint)

-- | A 'Tangent' 'Bundle' has twice the dimension of the underlying
-- 'Manifold'.
instance Manifold m => Manifold (Bundle c m) where
    dimension (Bundle base) = dimension base * 2

-- Tangent Space Coordinates --

-- | The 'Dual' of the 'Partials' coordinate system is 'Differentials'.
instance Primal Partials where
    type Dual Partials = Differentials

-- | The 'Dual' of the 'Differentials' coordinate system is 'Partials'.
instance Primal Differentials where
    type Dual Differentials = Partials

--- Graveyard ---

--- Functions ---

-- | 'pushForward' takes a 'Map' between 'Manifold's and turns it into a map
-- between the 'Tangent' spaces of the 'Manifold's. Although this ought to be a
-- class, right now it is simply the trivial push-forward as applied to linear
-- maps: the coordinates of the map are reused unchanged, and the target
-- 'Tangent' space is based at the image of the given point under the map.
pushForward :: (Manifold m, Manifold n)
    => Function c d :#: Tensor n m -- ^ The linear map
    -> c :#: m -- ^ The point in the domain
    -> Function Partials Partials :#: Tensor (Tangent d n) (Tangent c m) -- ^ The map between 'Tangent' spaces
pushForward pq q = fromCoordinates tangentMap (coordinates pq)
  where
    image = matrixApply pq q
    tangentMap = Tensor (Tangent image) (Tangent q)

-- | 'pushForward0' takes a 'Map' between 'Manifold's and turns it into a map
-- between the 'Tangent' spaces of the 'Manifold's. Unlike 'pushForward', the
-- point at which the target 'Tangent' space is based is supplied directly
-- rather than computed by applying the map.
pushForward0 :: (Manifold m, Manifold n)
    => Function c d :#: Tensor n m -- ^ The linear map
    -> c :#: m -- ^ The point in the domain
    -> d :#: n -- ^ The point in the codomain
    -> Function Partials Partials :#: Tensor (Tangent d n) (Tangent c m) -- ^ The map between 'Tangent' spaces
pushForward0 pq q p = fromCoordinates tangentMap (coordinates pq)
  where
    tangentMap = Tensor (Tangent p) (Tangent q)