{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeOperators #-}

module Downhill.Metric
  ( MetricTensor (..)
  )
where

import Data.VectorSpace ((^+^))
import Downhill.Grad (Dual (evalGrad), Manifold(..), MScalar)

-- | @MetricTensor@ converts gradients to vectors.
--
-- Strictly speaking it is the /inverse/ of a metric tensor, because it maps
-- the cotangent space (gradients, @'Grad' p@) into the tangent space
-- (vectors, @'Tang' p@). Gradient descent doesn't need the metric tensor
-- itself, only its inverse.
--
-- The type parameter @p@ selects the tangent/gradient types through 'Tang'
-- and 'Grad'; @g@ is the type of the metric value. Because @p@ occurs only
-- under 'Tang' / 'Grad', callers supply it with a type application
-- (e.g. @evalMetric \@p@), which is why @AllowAmbiguousTypes@ is enabled.
class Dual (Tang p) (Grad p) => MetricTensor p g where
  -- | Map a gradient to a tangent vector.
  --
  -- @m@ must be symmetric:
  --
  -- @evalGrad x (evalMetric m y) = evalGrad y (evalMetric m x)@
  evalMetric :: g -> Grad p -> Tang p

  -- | Inner product of two gradients under the metric.
  --
  -- @innerProduct m x y = evalGrad x (evalMetric m y)@
  innerProduct :: g -> Grad p -> Grad p -> MScalar p
  innerProduct metric x y =
    -- Pair the gradient @x@ with the vector obtained from @y@.
    evalGrad @(Tang p) @(Grad p) x (evalMetric @p metric y)

  -- | Squared norm of a gradient under the metric.
  --
  -- @sqrNorm m x = innerProduct m x x@
  sqrNorm :: g -> Grad p -> MScalar p
  sqrNorm metric x = innerProduct @p metric x x

-- | Scalar metric on 'Integer': the gradient is multiplied by the metric
-- value @m@.
instance MetricTensor Integer Integer where
  evalMetric m x = m * x

-- | Metric on pairs built from one metric per component. Each gradient
-- component is mapped by its own metric, independently of the other
-- (a block-diagonal metric). Both components must share a scalar type.
instance
  ( MScalar a ~ MScalar b,
    MetricTensor a ma,
    MetricTensor b mb
  ) =>
  MetricTensor (a, b) (ma, mb)
  where
  evalMetric (ma, mb) (da, db) =
    (evalMetric @a ma da, evalMetric @b mb db)

  -- Sum of the components' squared norms, delegating to each component's
  -- own 'sqrNorm' so any specialised component implementation is used.
  sqrNorm (ma, mb) (da, db) =
    sqrNorm @a ma da ^+^ sqrNorm @b mb db

-- | Metric on triples built from one metric per component. Each gradient
-- component is mapped by its own metric, independently of the others
-- (a block-diagonal metric). All components must share a scalar type.
instance
  ( MScalar a ~ MScalar b,
    MScalar a ~ MScalar c,
    MetricTensor a ma,
    MetricTensor b mb,
    MetricTensor c mc
  ) =>
  MetricTensor (a, b, c) (ma, mb, mc)
  where
  evalMetric (ma, mb, mc) (da, db, dc) =
    (evalMetric @a ma da, evalMetric @b mb db, evalMetric @c mc dc)

  -- Sum of the components' squared norms, delegating to each component's
  -- own 'sqrNorm' so any specialised component implementation is used.
  sqrNorm (ma, mb, mc) (da, db, dc) =
    sqrNorm @a ma da ^+^ sqrNorm @b mb db ^+^ sqrNorm @c mc dc

-- | Scalar metric on 'Float': the gradient is multiplied by the metric
-- value @m@.
instance MetricTensor Float Float where
  evalMetric m dv = m * dv

-- | Scalar metric on 'Double': the gradient is multiplied by the metric
-- value @m@.
instance MetricTensor Double Double where
  evalMetric m dv = m * dv

-- | Identity metric: the gradient is returned unchanged as the tangent
-- vector (presumably the Euclidean / L2 metric, as the name suggests).
-- Usable only for types whose gradient and tangent types coincide.
data L2 = L2

-- | Identity mapping; the @'Grad' p ~ 'Tang' p@ constraint is what allows
-- the gradient to be returned as-is.
instance (Dual (Tang p) (Grad p), Grad p ~ Tang p) => MetricTensor p L2 where
  evalMetric L2 v = v