-- |
-- Module      : Data.Manifold.Riemannian
-- Copyright   : (c) Justus Sagemüller 2015
-- License     : GPL v3
-- 
-- Maintainer  : (@) sagemueller $ geo.uni-koeln.de
-- Stability   : experimental
-- Portability : portable
-- 
-- Riemannian manifolds are manifolds equipped with a 'Metric' at each point.
-- That means, these manifolds aren't merely topological objects anymore, but
-- have a geometry as well. This gives, in particular, a notion of distance
-- and shortest paths (geodesics) along which you can interpolate.
-- 
-- Keep in mind that the types in this library are
-- generally defined in an abstract-mathematical spirit, which may not always
-- match the intuition if you think about manifolds as embedded in ℝ³.
-- (For instance, the torus inherits its geometry from the decomposition as
-- @'S¹' × 'S¹'@, not from the “doughnut” embedding; the cone over @S¹@ is
-- simply treated as the unit disk, etc.)

{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE StandaloneDeriving         #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE DeriveFoldable             #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE FunctionalDependencies     #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE TupleSections              #-}
{-# LANGUAGE ParallelListComp           #-}
{-# LANGUAGE UnicodeSyntax              #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE PatternGuards              #-}
{-# LANGUAGE TypeOperators              #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE LiberalTypeSynonyms        #-}
{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DataKinds                  #-}


module Data.Manifold.Riemannian  where


import Data.Maybe
import qualified Data.Vector as Arr
import Data.Semigroup

import Data.VectorSpace
import Data.LinearMap.HerMetric
import Data.AffineSpace

import Data.Manifold.Types
import Data.Manifold.Types.Primitive ((^), empty, embed, coEmbed)
import Data.Manifold.PseudoAffine
import Data.VectorSpace.FiniteDimensional
    
import Data.CoNat

import qualified Prelude as Hask hiding(foldl, sum, sequence)
import qualified Control.Applicative as Hask
import qualified Control.Monad       as Hask hiding(forM_, sequence)
import Data.Functor.Identity
import qualified Data.Foldable       as Hask
import qualified Data.Traversable as Hask

import qualified Numeric.LinearAlgebra.HMatrix as HMat

import Control.Category.Constrained.Prelude hiding
     ((^), all, elem, sum, forM, Foldable(..), Traversable)
import Control.Arrow.Constrained
import Control.Monad.Constrained hiding (forM)
import Data.Foldable.Constrained



-- | Manifolds on which two points can (sometimes) be joined by an
--   interpolating path. The path is parameterised such that it yields the
--   starting point at -1 and the end point at +1.
class Semimanifold x => Geodesic x where
  -- | Try to construct the interpolation path between two points. The result
  --   is wrapped in 'Option', so an instance can report that no path exists.
  geodesicBetween ::
          x -- ^ Starting point; the interpolation will yield this at -1.
       -> x -- ^ End point, for +1.
            -- 
            --   If the two points are actually connected by a path...
       -> Option ( -> x) -- ^ ...then this is the interpolation function. Attention: 
                           --   the type will change to 'Differentiable' in the future.

-- | Variant of 'geodesicBetween' whose resulting path accepts any
--   'IntervalLike' parameter: the argument is first converted with
--   'toClosedInterval' and then fed to the geodesic. Yields an empty
--   'Option' exactly when 'geodesicBetween' does.
interpolate :: (Geodesic x, IntervalLike i) => x -> x -> Option (i -> x)
interpolate start end = reparametrise <$> geodesicBetween start end
 where -- Pre-compose the path with the conversion to the closed interval.
       reparametrise path = path . toClosedInterval




-- Helper macro: declares a 'Geodesic' instance for the type passed as argument.
-- The path between a and b is built with alerp (presumably the affine linear
-- interpolation from "Data.AffineSpace" - to be confirmed). The curve
-- parameter is read through the xParamD¹ accessor, shifted by +1 and halved,
-- so that -1 becomes 0 and +1 becomes 1 before it reaches alerp.
-- The instance always succeeds (the result is wrapped with return).
#define deriveAffineGD(x)                                         \
instance Geodesic x where {                                        \
  geodesicBetween a b = return $ alerp a b . (/2) . (+1) . xParamD¹ \
 }

-- Instantiate the affine geodesic defined by the macro above.
deriveAffineGD ()

-- | The only value matched here is 'Origin', hence the geodesic is the
--   constant path that stays at 'Origin' for every parameter.
instance Geodesic (ZeroDim ) where
  geodesicBetween Origin Origin = return $ \_ -> Origin

-- | Pairs are interpolated component-wise: both component geodesics are
--   evaluated at the same parameter. A path exists only if it exists in
--   both components.
instance (Geodesic a, Geodesic b) => Geodesic (a,b) where
  geodesicBetween (a0,b0) (a1,b1) = liftA2 pairUp pathA pathB
   where pathA = geodesicBetween a0 a1
         pathB = geodesicBetween b0 b1
         -- Run both component paths with one shared parameter.
         pairUp f g = \t -> (f t, g t)

-- | Triples are interpolated component-wise, analogous to pairs: the three
--   component geodesics share one parameter, and a path exists only if
--   all three components are connected.
instance (Geodesic a, Geodesic b, Geodesic c) => Geodesic (a,b,c) where
  geodesicBetween (a0,b0,c0) (a1,b1,c1) = liftA3 bundle pathA pathB pathC
   where pathA = geodesicBetween a0 a1
         pathB = geodesicBetween b0 b1
         pathC = geodesicBetween c0 c1
         -- Evaluate all three component paths at the same parameter.
         bundle f g h = \t -> (f t, g t, h t)

-- | Component-wise linear blend of the two coefficient arrays. The first
--   vector is weighted with @(1-t)/2@ and the second with @(t+1)/2@, so the
--   parameter -1 reproduces @v@ and +1 reproduces @w@. Always succeeds.
instance (KnownNat n) => Geodesic (FreeVect n ) where
  geodesicBetween (FreeVect v) (FreeVect w)
      = return $ \( t) -> let μv = (1-t)/2; μw = (t+1)/2
                            in FreeVect $ Arr.zipWith (\vi wi -> μv*vi + μw*wi) v w

-- | Linear blend of the two underlying arrays, with the same weights as the
--   'FreeVect' instance: @(1-t)/2@ for the first and @(t+1)/2@ for the second.
--
--   NOTE(review): the only equation is guarded by both arrays having a size
--   greater than zero, and there is no fallback equation. If either array is
--   empty the pattern match fails at run time - confirm whether returning
--   an empty 'Option' (or another result) is intended for that case.
instance (PseudoAffine v) => Geodesic (FinVecArrRep t v ) where
  geodesicBetween (FinVecArrRep v) (FinVecArrRep w)
   | HMat.size v>0 && HMat.size w>0
      = return $ \( t) -> let μv = (1-t)/2; μw = (t+1)/2
                            in FinVecArrRep $ HMat.scale μv v + HMat.scale μw w

-- | Geodesics between 'Stiefel1' points, built on top of the geodesic of the
--   underlying vector space. Both representatives are normalised first; the
--   vector-space geodesic between them is then evaluated at the
--   reparametrised position @g * tan (ϑ*t)@ and the result is wrapped in
--   'Stiefel1' again.
--
--   NOTE(review): when both normalised vectors coincide, @l@ is zero and
--   @4/l^2@ divides by zero - confirm how that case is meant to behave.
instance (Geodesic v, WithField  HilbertSpace v)
             => Geodesic (Stiefel1 v) where
  geodesicBetween (Stiefel1 p') (Stiefel1 q')
      = (\f -> \( t) -> Stiefel1 . f .  $ g * tan (ϑ*t))
            <$> geodesicBetween p q
   where p = normalized p'; q = normalized q'
         -- Length of the chord between the two normalised vectors.
         l = magnitude $ p^-^q
         -- Angle derived from half the chord length.
         ϑ = asin $ l/2
         -- Scale factor applied to the tangent of the scaled angle.
         g = sqrt $ 4/l^2 - 1


-- | The zero-sphere consists of two disconnected points. A (constant) path
--   exists only between a point and itself; for two different points the
--   result is 'empty'.
instance Geodesic S⁰ where
  geodesicBetween from to = case (from, to) of
     (PositiveHalfSphere, PositiveHalfSphere) -> return $ \_ -> PositiveHalfSphere
     (NegativeHalfSphere, NegativeHalfSphere) -> return $ \_ -> NegativeHalfSphere
     _                                        -> empty

-- | Geodesics given by an angle. Three cases are distinguished:
--
--   * If the two angles differ by less than @pi@, the angles themselves are
--     interpolated directly.
--   * Otherwise, if the first angle is positive, the interpolation runs
--     between the transformed angles @pi-φ@ and @-ϕ-pi@.
--   * Otherwise it runs between @-pi-φ@ and @pi-ϕ@.
--
--   In the latter two cases each interpolated value @ψ@ is mapped back with
--   @signum ψ*pi - ψ@ before being wrapped again.
instance Geodesic  where
  geodesicBetween ( φ) ( ϕ)
    | abs (φ-ϕ) < pi  = (>>> ) <$> geodesicBetween φ ϕ
    | φ > 0           = (>>>  . \ψ -> signum ψ*pi - ψ)
                        <$> geodesicBetween (pi-φ) (-ϕ-pi)
    | otherwise       = (>>>  . \ψ -> signum ψ*pi - ψ)
                        <$> geodesicBetween (-pi-φ) (pi-ϕ)


-- | The open cone over the zero-sphere is handled as a signed height: points
--   on the positive half-sphere keep their height, points on the negative one
--   get it negated. The geodesic is computed on those signed values and each
--   result is converted back (strictly positive values go to the positive
--   half-sphere, everything else to the negative one).
instance Geodesic (Cℝay S⁰) where
  geodesicBetween p q = (>>> unsigned) <$> geodesicBetween (signed p) (signed q)
   where signed (Cℝay h hemisphere) = case hemisphere of
           PositiveHalfSphere -> h
           NegativeHalfSphere -> -h
         unsigned x = if x>0 then Cℝay x PositiveHalfSphere
                             else Cℝay (-x) NegativeHalfSphere

-- | The closed cone over the zero-sphere, treated like the open one: the
--   height is given a sign according to the half-sphere, the geodesic is
--   taken between the signed heights, and every interpolated value is turned
--   back into a cone point (strictly positive values on the positive
--   half-sphere, all others on the negative one).
instance Geodesic (CD¹ S⁰) where
  geodesicBetween p q = (\path -> path >>> unsigned)
                          <$> geodesicBetween (signed p) (signed q)
   where signed (CD¹ h hemisphere) = case hemisphere of
           PositiveHalfSphere -> h
           NegativeHalfSphere -> -h
         unsigned x = if x>0 then CD¹ x PositiveHalfSphere
                             else CD¹ (-x) NegativeHalfSphere

-- | Geodesics on an open cone, computed on the values obtained through
--   'toInterior'; every interpolated value is mapped back with 'fromInterior'.
--
--   NOTE(review): the @case@ in @toP@ has a single alternative,
--   @Option (Just i)@. Any other result of 'toInterior' is a pattern-match
--   failure at run time - confirm that this cannot occur for this type.
instance Geodesic (Cℝay ) where
  geodesicBetween p q = (>>> fromP) <$> geodesicBetween (toP p) (toP q)
   where fromP = fromInterior
         toP w = case toInterior w of {Option (Just i) -> i}

-- | Geodesics on a closed cone whose base point is given by an angle @φ@.
--   A point with height @h@ is converted to the pair @(h*cos φ, h*sin φ)@,
--   the geodesic is taken between those pairs, and each interpolated pair
--   @(x,y)@ is converted back: the height is @sqrt (x^2+y^2)@ and the angle
--   is @atan2 y x@.
instance Geodesic (CD¹ ) where
  geodesicBetween p q = (>>> fromI) <$> geodesicBetween (toI p) (toI q)
   where toI (CD¹ h ( φ)) = (h*cos φ, h*sin φ)
         fromI (x,y) = CD¹ (sqrt $ x^2+y^2) ( $ atan2 y x)

-- | Geodesics on an open cone, computed on the values obtained through
--   'toInterior'; every interpolated value is mapped back with 'fromInterior'.
--
--   NOTE(review): the @case@ in @toP@ has a single alternative,
--   @Option (Just i)@. Any other result of 'toInterior' is a pattern-match
--   failure at run time - confirm that this cannot occur for this type.
instance Geodesic (Cℝay ) where
  geodesicBetween p q = (>>> fromP) <$> geodesicBetween (toP p) (toP q)
   where fromP = fromInterior
         toP w = case toInterior w of {Option (Just i) -> i}

-- | Geodesics on a closed cone, computed in @ℝ³@: a point with height @h@ and
--   base point @sph@ is mapped to @h *^ embed sph@, the geodesic is taken
--   between those vectors, and each interpolated vector @v@ is converted back
--   with height @magnitude v@ and base point @coEmbed v@.
instance Geodesic (CD¹ ) where
  geodesicBetween p q = (>>> fromI) <$> geodesicBetween (toI p) (toI q :: ℝ³)
   where toI (CD¹ h sph) = h *^ embed sph
         fromI v = CD¹ (magnitude v) (coEmbed v)

-- Helper macro: for a constraint and a vector-space type, declares two
-- 'Geodesic' instances, one for the open cone (Cℝay) and one for the closed
-- cone (CD¹) over that type. Both work the same way:
--   * toP turns a cone point with height h and base vector w into the
--     pair ( h*^w, h );
--   * the geodesic is taken between such pairs;
--   * fromP turns an interpolated pair (x,h) back into a cone point with
--     height h and base vector x^/h. A pair whose second component is 0 is
--     mapped to height 0 with x kept unscaled, which avoids dividing by zero.
#define geoVSpCone(c,t)                                               \
instance (c) => Geodesic (Cℝay (t)) where {                            \
  geodesicBetween p q = (>>> fromP) <$> geodesicBetween (toP p) (toP q) \
   where { fromP (x,0) = Cℝay 0 x                                        \
         ; fromP (x,h) = Cℝay h (x^/h)                                    \
         ; toP (Cℝay h w) = ( h*^w, h ) } } ;                              \
instance (c) => Geodesic (CD¹ (t)) where {                                  \
  geodesicBetween p q = (>>> fromP) <$> geodesicBetween (toP p) (toP q)      \
   where { fromP (x,0) = CD¹ 0 x                                              \
         ; fromP (x,h) = CD¹ h (x^/h)                                          \
         ; toP (CD¹ h w) = ( h*^w, h ) } }

-- Cone instances for the individual base spaces; the first macro argument is
-- the instance context, the second the base type.
geoVSpCone ((), )
geoVSpCone ((), ℝ⁰)
geoVSpCone ((WithField  HilbertSpace a, WithField  HilbertSpace b, Geodesic (a,b)), (a,b))
geoVSpCone (KnownNat n, FreeVect n )
geoVSpCone ((Geodesic v, WithField  HilbertSpace v), FinVecArrRep t v )




-- | One-dimensional manifolds, whose closure is homeomorphic to the unit interval.
--   'toClosedInterval' is the conversion that 'interpolate' applies to its
--   path parameter before evaluating a geodesic.
class WithField  PseudoAffine i => IntervalLike i where
  -- | Map a point of the manifold to the corresponding interval position.
  toClosedInterval :: i ->  -- Differentiable ℝ i D¹

-- | Identity conversion: the value is used as the interval position as-is.
instance IntervalLike  where
  toClosedInterval = id
-- | The height is used directly, negated on the negative half-sphere.
instance IntervalLike (CD¹ S⁰) where
  toClosedInterval (CD¹ h PositiveHalfSphere) =  h
  toClosedInterval (CD¹ h NegativeHalfSphere) =  (-h)
-- | The height is squashed with @tanh@; the result is negated on the
--   negative half-sphere.
instance IntervalLike (Cℝay S⁰) where
  toClosedInterval (Cℝay h PositiveHalfSphere) =  $ tanh h
  toClosedInterval (Cℝay h NegativeHalfSphere) =  $ -tanh h
-- | The height @h@ is rescaled to @h*2 - 1@ (so 0 becomes -1 and 1 becomes +1).
instance IntervalLike (CD¹ ℝ⁰) where
  toClosedInterval (CD¹ h Origin) =  $ h*2 - 1
-- | The height @h@ is mapped to @1 - 2/(h+1)@ (so 0 becomes -1, and the
--   value approaches +1 as @h@ grows).
instance IntervalLike (Cℝay ℝ⁰) where
  toClosedInterval (Cℝay h Origin) =  $ 1 - 2/(h+1)
-- | The value is squashed with @tanh@.
instance IntervalLike  where
  toClosedInterval x =  $ tanh x





-- | Manifolds with geodesics that additionally provide a 'RieMetric'.
class Geodesic m => Riemannian m where
  -- | The metric associated with the manifold @m@.
  rieMetric :: RieMetric m

-- | The metric does not depend on the point: 'const' returns the same value
--   @projector 1@ everywhere.
instance Riemannian  where
  rieMetric = const m where m = projector 1