{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Downhill.Grad
  ( Dual (..),
    MetricTensor (..),
    HasGrad (..),
    GradBuilder,
    HasFullGrad,
    HasGradAffine,
  )
where

import Data.AffineSpace (AffineSpace (Diff))
import Data.Kind (Type)
import Data.VectorSpace (AdditiveGroup ((^+^)), VectorSpace (Scalar, (*^)))
import qualified Data.VectorSpace as VectorSpace
import Downhill.Linear.Expr (BasicVector (VecBuilder), FullVector)
import GHC.Generics (Generic)

-- | @Dual s v dv@ states that @dv@ is dual to @v@: every value of @dv@
-- determines a linear map @v -> s@, where @s@ is the scalar type shared by
-- @v@ and @dv@.
class
  ( VectorSpace.Scalar v ~ s,
    VectorSpace.Scalar dv ~ s,
    VectorSpace v,
    VectorSpace dv,
    AdditiveGroup s
  ) =>
  Dual s v dv
  where
  -- | Apply the covector (first argument) to the vector (second argument).
  --
  -- This method lives here rather than in 'HasGrad': there, the class
  -- parameter @p@ would be ambiguous.
  evalGrad :: dv -> v -> s

-- | @MetricTensor@ converts gradients to vectors.
--
-- Strictly speaking it is the /inverse/ of a metric tensor, because it maps
-- the cotangent space into the tangent space. Gradient descent does not need
-- the metric tensor itself, only its inverse.
class
  ( Dual (Scalar g) (MtVector g) (MtCovector g),
    VectorSpace g
  ) =>
  MetricTensor g
  where
  -- | Vector (tangent-space) type produced by 'evalMetric'.
  type MtVector g :: Type

  -- | Covector (cotangent-space) type consumed by 'evalMetric'.
  type MtCovector g :: Type

  -- | Map a covector to a vector. @m@ must be symmetric:
  --
  -- @evalGrad x (evalMetric m y) = evalGrad y (evalMetric m x)@
  evalMetric :: g -> MtCovector g -> MtVector g

  -- | @innerProduct m x y = evalGrad x (evalMetric m y)@
  innerProduct :: g -> MtCovector g -> MtCovector g -> Scalar g
  innerProduct g x y = evalGrad x (evalMetric g y)

  -- | @sqrNorm m x = innerProduct m x x@
  sqrNorm :: g -> MtCovector g -> Scalar g
  sqrNorm g x = innerProduct g x x

-- TODO: FullVector or not?
-- TODO: Metric or not?

-- | @HasGrad p@ bundles the types associated with @p@ (its scalar, tangent
-- vector, gradient and metric) together with the constraints relating them.
-- It exists to keep type signatures elsewhere short.
class
  ( BasicVector (Tang p),
    BasicVector (Grad p),
    Dual (MScalar p) (Tang p) (Grad p),
    MetricTensor (Metric p),
    MtVector (Metric p) ~ Tang p,
    MtCovector (Metric p) ~ Grad p
  ) =>
  HasGrad p
  where
  -- | Tangent vector of manifold @p@. When @p@ is an 'AffineSpace', this
  -- should be @'Diff' p@; when @p@ is a 'VectorSpace', it may be @p@ itself.
  type Tang p :: Type

  -- | Dual of the tangent space of @p@.
  type Grad p :: Type

  -- | Scalar type shared by @Tang p@ and @Grad p@.
  type MScalar p :: Type

  -- | A 'MetricTensor' mapping @Grad p@ to @Tang p@.
  type Metric p :: Type

-- | The 'VecBuilder' associated with the gradient type of @p@.
type GradBuilder p = VecBuilder (Grad p)

-- | 'HasGrad' where, additionally, the gradient type is a 'FullVector'.
type HasFullGrad p =
  ( FullVector (Grad p),
    HasGrad p
  )

-- | 'HasGrad' for an affine space @p@. The tangent type of @p@ is @'Diff' p@,
-- and that tangent type itself has a 'HasGrad' instance whose tangent and
-- gradient types coincide with those of @p@.
type HasGradAffine p =
  ( HasGrad p,
    AffineSpace p,
    Tang p ~ Diff p,
    HasGrad (Tang p),
    Tang (Tang p) ~ Tang p,
    Grad (Tang p) ~ Grad p
  )

-- | On 'Integer', pairing a covector with a vector is plain multiplication.
instance Dual Integer Integer Integer where
  evalGrad = (*)

-- | A metric on 'Integer' is a single integer factor applied by
-- multiplication; vectors and covectors are both 'Integer'.
instance MetricTensor Integer where
  type MtVector Integer = Integer
  type MtCovector Integer = Integer
  evalMetric m x = m * x

-- | 'Integer' serves as its own scalar, tangent, gradient and metric type.
instance HasGrad Integer where
  type Metric Integer = Integer
  type Grad Integer = Integer
  type Tang Integer = Integer
  type MScalar Integer = Integer

-- | Pairing on a product: pair each covector component with the matching
-- vector component and add the scalar results.
instance (Dual s a da, Dual s b db) => Dual s (a, b) (da, db) where
  evalGrad (da, db) (x, y) = evalGrad da x ^+^ evalGrad db y

-- | Pairing on a triple: pair each covector component with the matching
-- vector component and add the three scalar results.
instance (Dual s a da, Dual s b db, Dual s c dc) => Dual s (a, b, c) (da, db, dc) where
  evalGrad (da, db, dc) (x, y, z) =
    evalGrad da x ^+^ evalGrad db y ^+^ evalGrad dc z

-- | Block-diagonal metric on a pair. Each component is mapped by its own
-- metric, and scalar results from the components are summed.
instance
  (MetricTensor ma, MetricTensor mb, Scalar ma ~ Scalar mb) =>
  MetricTensor (ma, mb)
  where
  type MtVector (ma, mb) = (MtVector ma, MtVector mb)
  type MtCovector (ma, mb) = (MtCovector ma, MtCovector mb)

  evalMetric (ma, mb) (a, b) = (evalMetric ma a, evalMetric mb b)

  -- Computed componentwise, like 'sqrNorm', so that each component's own
  -- 'innerProduct' (e.g. an overridden one) is used rather than the class
  -- default, keeping @sqrNorm m x@ consistent with @innerProduct m x x@.
  innerProduct (ma, mb) (a1, b1) (a2, b2) =
    innerProduct ma a1 a2 ^+^ innerProduct mb b1 b2

  sqrNorm (ma, mb) (a, b) = sqrNorm ma a ^+^ sqrNorm mb b

-- | A pair of manifolds: each associated type is the pair of the
-- components' types. Both components must share one scalar type.
instance
  ( MScalar b ~ MScalar a,
    HasGrad a,
    HasGrad b
  ) =>
  HasGrad (a, b)
  where
  type Tang (a, b) = (Tang a, Tang b)
  type Grad (a, b) = (Grad a, Grad b)
  type Metric (a, b) = (Metric a, Metric b)
  type MScalar (a, b) = MScalar a

-- | Block-diagonal metric on a triple. Each component is mapped by its own
-- metric, and scalar results from the components are summed.
instance
  ( MetricTensor ma,
    MetricTensor mb,
    MetricTensor mc,
    Scalar ma ~ Scalar mb,
    Scalar ma ~ Scalar mc
  ) =>
  MetricTensor (ma, mb, mc)
  where
  type MtVector (ma, mb, mc) = (MtVector ma, MtVector mb, MtVector mc)
  type MtCovector (ma, mb, mc) = (MtCovector ma, MtCovector mb, MtCovector mc)

  evalMetric (ma, mb, mc) (a, b, c) =
    (evalMetric ma a, evalMetric mb b, evalMetric mc c)

  -- Computed componentwise, like 'sqrNorm', so that each component's own
  -- 'innerProduct' (e.g. an overridden one) is used rather than the class
  -- default, keeping @sqrNorm m x@ consistent with @innerProduct m x x@.
  innerProduct (ma, mb, mc) (a1, b1, c1) (a2, b2, c2) =
    innerProduct ma a1 a2 ^+^ innerProduct mb b1 b2 ^+^ innerProduct mc c1 c2

  sqrNorm (ma, mb, mc) (a, b, c) =
    sqrNorm ma a ^+^ sqrNorm mb b ^+^ sqrNorm mc c

-- | A triple of manifolds: each associated type is the triple of the
-- components' types. All components must share one scalar type.
instance
  ( MScalar b ~ MScalar a,
    MScalar c ~ MScalar a,
    HasGrad a,
    HasGrad b,
    HasGrad c
  ) =>
  HasGrad (a, b, c)
  where
  type Tang (a, b, c) = (Tang a, Tang b, Tang c)
  type Grad (a, b, c) = (Grad a, Grad b, Grad c)
  type Metric (a, b, c) = (Metric a, Metric b, Metric c)
  type MScalar (a, b, c) = MScalar a

-- | On 'Float', pairing a covector with a vector is plain multiplication.
instance Dual Float Float Float where
  evalGrad = (*)

-- | A metric on 'Float' is a single factor applied by multiplication;
-- vectors and covectors are both 'Float'.
instance MetricTensor Float where
  type MtVector Float = Float
  type MtCovector Float = Float
  evalMetric m x = m * x

-- | 'Float' serves as its own scalar, tangent, gradient and metric type.
instance HasGrad Float where
  type Metric Float = Float
  type Tang Float = Float
  type Grad Float = Float
  type MScalar Float = Float

-- | On 'Double', pairing a covector with a vector is plain multiplication.
instance Dual Double Double Double where
  evalGrad = (*)

-- | A metric on 'Double' is a single factor applied by multiplication;
-- vectors and covectors are both 'Double'.
instance MetricTensor Double where
  type MtVector Double = Double
  type MtCovector Double = Double
  evalMetric m x = m * x

-- | 'Double' serves as its own scalar, tangent, gradient and metric type.
instance HasGrad Double where
  type Metric Double = Double
  type Tang Double = Double
  type Grad Double = Double
  type MScalar Double = Double

-- | A metric described by a single scalar factor of @v@'s scalar type. The
-- 'MetricTensor' instance uses this factor to scale covectors into vectors.
newtype L2 v = L2 (Scalar v)
  deriving (Generic)

-- | The instance body is empty, so every method comes from the class's
-- default implementations. These are presumably the 'Generic'-based
-- defaults, which is why 'L2' derives 'Generic'.
instance (AdditiveGroup (Scalar v)) => AdditiveGroup (L2 v)

-- | Scaling an 'L2' multiplies its wrapped factor.
instance (AdditiveGroup (Scalar v), Num (Scalar v)) => VectorSpace (L2 v) where
  type Scalar (L2 v) = Scalar v
  s *^ L2 factor = L2 (s * factor)

-- | An 'L2' metric with factor @a@ maps a covector @u@ to the vector
-- @a *^ u@. Vectors and covectors share the same type @v@.
instance (AdditiveGroup a, Num a, a ~ Scalar v, Dual a v v) => MetricTensor (L2 v) where
  type MtVector (L2 v) = v
  type MtCovector (L2 v) = v

  evalMetric (L2 a) u = a *^ u

  -- Multiplies the scalar pairing by @a@, instead of scaling the whole
  -- vector first as the class default @evalGrad x (evalMetric m y)@ would.
  innerProduct (L2 a) x y = a * evalGrad x y

  sqrNorm g x = innerProduct g x x