module Goal.Geometry.Differential (
Tangent (Tangent, removeTangent)
, Bundle (Bundle, removeBundle)
, Partials (Partials)
, Differentials (Differentials)
, gradientStep
, projectTangent
, tangentToBundle
, bundleToTangent
, Riemannian (metric, flat, sharp)
, gradientAscent
, vanillaGradientAscent
, gradientDescent
, vanillaGradientDescent
) where
import Prelude hiding (map,minimum,maximum)
import Goal.Core
import Goal.Geometry.Set
import Goal.Geometry.Manifold
import Goal.Geometry.Linear
import Goal.Geometry.Map
import Goal.Geometry.Map.Multilinear
import qualified Data.Vector.Storable as C
import qualified Numeric.LinearAlgebra.HMatrix as H
-- | The tangent space of the manifold @m@, identified by the base point (in
-- chart @c@) at which it is attached. Tangent vectors themselves have type
-- @d :#: Tangent c m@ for some chart @d@, e.g. 'Partials' or 'Differentials'.
newtype Tangent c m = Tangent
    { removeTangent :: c :#: m -- ^ the base point of the tangent space
    } deriving (Eq, Read, Show)
-- | The tangent bundle of the manifold @m@. Only the base manifold is stored;
-- @c@ is a phantom chart parameter. Coordinates of a point on the bundle
-- hold the base-point coordinates followed by the tangent coordinates.
newtype Bundle c m = Bundle
    { removeBundle :: m -- ^ the base manifold
    } deriving (Eq, Read, Show)
-- | Chart tag for tangent-vector coordinates in the partial-derivative basis.
-- 'flat' maps these coordinates to 'Differentials' through the metric.
data Partials = Partials
    deriving (Eq, Read, Show)

-- | Chart tag for tangent coordinates in the differential basis, i.e. the
-- dual of 'Partials'. 'sharp' maps these back to 'Partials'.
data Differentials = Differentials
    deriving (Eq, Read, Show)
-- | Take one step of size @eps@ from a tangent vector's base point in the
-- direction of that vector.
--
-- The result has coordinates @p + eps .> v@, where @p@ is the base point
-- stored in the 'Tangent' and @v@ is the tangent vector @f'@.
gradientStep :: Manifold m => Double -> Partials :#: Tangent c m -> c :#: m
gradientStep eps f' = fromCoordinates (manifold base) stepped
  where
    Tangent base = manifold f'
    stepped = coordinates base + coordinates (eps .> f')
-- | The base point at which a tangent vector is attached.
projectTangent :: d :#: Tangent c m -> c :#: m
projectTangent dp = removeTangent (manifold dp)
-- | Unpack a point on the tangent bundle into a tangent vector.
--
-- The first @dimension m@ coordinates are the base point. The remaining
-- coordinates are the tangent vector's coordinates at that base point.
bundleToTangent :: Manifold m => c :#: Bundle d m -> c :#: Tangent d m
bundleToTangent p = fromCoordinates (Tangent basePoint) tangentCoords
  where
    Bundle m = manifold p
    -- dimension (Bundle d m) == 2 * dimension m, so this is the midpoint.
    (baseCoords, tangentCoords) = C.splitAt (dimension m) (coordinates p)
    basePoint = fromCoordinates m baseCoords
-- | Pack a tangent vector into a point on the tangent bundle. The result
-- holds the base-point coordinates followed by the tangent coordinates.
tangentToBundle :: Manifold m => c :#: Tangent d m -> c :#: Bundle d m
tangentToBundle cm = fromCoordinates (Bundle (manifold dm)) packed
  where
    Tangent dm = manifold cm
    packed = coordinates dm C.++ coordinates cm
-- | Split a tangent vector on a 'Replicated' manifold into one tangent vector
-- per replicated copy.
--
-- The coordinates are cut into consecutive chunks whose length is the
-- dimension of the first copy's tangent space. This assumes every copy has
-- the same dimension. With no copies the result is empty.
replicatedTangents :: Manifold m => d :#: Tangent c (Replicated m) -> [d :#: Tangent c m]
replicatedTangents dp =
    case mapReplicated Tangent base of
        [] -> []
        tangents@(firstTangent : _) ->
            zipWith fromList tangents . breakEvery (dimension firstTangent) $ listCoordinates dp
  where
    Tangent base = manifold dp
-- | Riemannian gradient ascent.
--
-- Each step converts the differential returned by @f'@ into 'Partials'
-- coordinates with 'sharp', which uses the inverse metric by default. It
-- then moves by @eps@ along the result with 'gradientStep'. The result is
-- the infinite list of iterates, starting with the initial point.
gradientAscent :: (Riemannian c m, Manifold m)
    => Double
    -> (c :#: m -> Differentials :#: Tangent c m)
    -> (c :#: m)
    -> [c :#: m]
gradientAscent eps f' = iterate ascend
  where
    ascend p = gradientStep eps (sharp (f' p))
-- | Plain gradient ascent, which ignores any metric.
--
-- The differential's coordinates are reinterpreted as 'Partials' with
-- 'breakChart' rather than raised with 'sharp'. This assumes 'breakChart',
-- which is defined elsewhere, only relabels the chart. The result is the
-- infinite list of iterates, starting with the initial point.
vanillaGradientAscent :: Manifold m
    => Double
    -> (c :#: m -> Differentials :#: Tangent c m)
    -> (c :#: m)
    -> [c :#: m]
vanillaGradientAscent eps f' p0 =
    iterate (\p -> gradientStep eps . breakChart $ f' p) p0
-- | Riemannian gradient descent.
--
-- This is 'gradientAscent' with the step size negated, so each iterate moves
-- against the gradient. The result is the infinite list of iterates,
-- starting with the initial point. @eps@ should be given as a positive
-- step size.
gradientDescent :: (Riemannian c m, Manifold m)
    => Double
    -> (c :#: m -> Differentials :#: Tangent c m)
    -> (c :#: m)
    -> [c :#: m]
-- Negating the step is what turns ascent into descent. Passing eps through
-- unchanged would make this identical to gradientAscent.
gradientDescent eps = gradientAscent (negate eps)
-- | Plain (metric-free) gradient descent.
--
-- This is 'vanillaGradientAscent' with the step size negated, so each
-- iterate moves against the gradient. The result is the infinite list of
-- iterates, starting with the initial point. @eps@ should be given as a
-- positive step size.
vanillaGradientDescent :: Manifold m
    => Double
    -> (c :#: m -> Differentials :#: Tangent c m)
    -> (c :#: m)
    -> [c :#: m]
-- Negating the step is what turns ascent into descent. Passing eps through
-- unchanged would make this identical to vanillaGradientAscent.
vanillaGradientDescent eps = vanillaGradientAscent (negate eps)
-- | Manifolds equipped with a Riemannian metric in chart @c@.
--
-- 'flat' and 'sharp' have default definitions that apply the metric, or its
-- inverse, at the tangent vector's base point. Instances may override them
-- with cheaper equivalents.
class Manifold m => Riemannian c m where
    -- | The metric tensor at a point. It is a linear map from 'Partials' to
    -- 'Differentials' coordinates on the tangent space there.
    metric :: c :#: m -> Function Partials Differentials :#: Tensor (Tangent c m) (Tangent c m)

    -- | Convert 'Partials' coordinates to 'Differentials' using the metric
    -- at the vector's base point.
    flat :: Partials :#: Tangent c m -> Differentials :#: Tangent c m
    flat v = matrixApply g v
      where
        g = metric (projectTangent v)

    -- | Convert 'Differentials' coordinates to 'Partials' using the inverse
    -- metric at the vector's base point. The default inverts the metric
    -- matrix on every call.
    sharp :: Differentials :#: Tangent c m -> Partials :#: Tangent c m
    sharp w = matrixApply gInv w
      where
        gInv = matrixInverse (metric (projectTangent w))
-- | The metric on a 'Replicated' manifold is block diagonal. Each copy uses
-- its own metric, with no coupling between copies.
--
-- 'flat' and 'sharp' are computed copy by copy rather than through the full
-- block-diagonal matrix.
instance (Manifold m, Riemannian c m) => Riemannian c (Replicated m) where
    metric p = fromHMatrix (Tensor tp tp) (H.diagBlock blocks)
      where
        tp = Tangent p
        blocks = mapReplicated (toHMatrix . metric) p
    flat dp =
        let lowered = [ coordinates (flat t) | t <- replicatedTangents dp ]
        in fromCoordinates (manifold dp) (C.concat lowered)
    sharp dp =
        let raised = [ coordinates (sharp t) | t <- replicatedTangents dp ]
        in fromCoordinates (manifold dp) (C.concat raised)
-- | The one-dimensional continuum has the unit metric. Raising and lowering
-- therefore leave coordinates unchanged and only switch the chart.
instance Riemannian Cartesian Continuum where
    metric p =
        let tp = Tangent p
        in fromList (Tensor tp tp) [1]
    flat = breakChart
    sharp = breakChart
-- | Euclidean space in Cartesian coordinates has the identity metric.
-- Raising and lowering therefore leave coordinates unchanged and only
-- switch the chart.
instance Riemannian Cartesian Euclidean where
    metric p =
        let n = dimension (manifold p)
            tp = Tangent p
        in fromHMatrix (Tensor tp tp) (H.ident n)
    flat = breakChart
    sharp = breakChart
-- | Metric on a tangent space. It reuses the base manifold's metric at the
-- tangent vector's base point, copying coordinates between the tangent
-- space and its own tangent space.
instance (Manifold m, Riemannian c m) => Riemannian Partials (Tangent c m) where
    metric dp =
        let tdp = Tangent dp
        in fromCoordinates (Tensor tdp tdp) . coordinates . metric $ projectTangent dp
    flat pdd =
        let base = manifold (projectTangent pdd)
            lowered = flat (fromCoordinates base (coordinates pdd))
        in fromCoordinates (manifold pdd) (coordinates lowered)
    sharp ddp =
        let base = manifold (projectTangent ddp)
            raised = sharp (fromCoordinates base (coordinates ddp))
        in fromCoordinates (manifold ddp) (coordinates raised)
-- | A tangent space has the same dimension as its base manifold.
instance Manifold m => Manifold (Tangent c m) where
    dimension tm = dimension (manifold (removeTangent tm))
-- | A bundle point holds base-point coordinates followed by tangent
-- coordinates. The bundle therefore has twice the base manifold's dimension.
instance Manifold m => Manifold (Bundle c m) where
    dimension b = dimension (removeBundle b) * 2
-- | 'Partials' coordinates have 'Differentials' as their dual chart.
instance Primal Partials where
    type Dual Partials = Differentials
-- | Conversely, the dual of 'Differentials' is 'Partials', so 'Dual' swaps
-- the two charts.
instance Primal Differentials where
    type Dual Differentials = Partials