{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Downhill.BVar
  ( BVar (..),
    var,
    constant,
    backprop,
  )
where

import Data.AdditiveGroup (AdditiveGroup)
import Data.AffineSpace (AffineSpace ((.+^), (.-.)))
import qualified Data.AffineSpace as AffineSpace
import Data.VectorSpace
  ( AdditiveGroup (..),
    VectorSpace ((*^)),
  )
import qualified Data.VectorSpace as VectorSpace
import Downhill.Grad
  ( Dual (evalGrad),
    HasFullGrad,
    HasGrad (Grad, MScalar, Tang),
    HasGradAffine,
  )
import Downhill.Linear.BackGrad
  ( BackGrad (..),
    realNode,
  )
import qualified Downhill.Linear.Backprop as BP
import Downhill.Linear.Expr (BasicVector, Expr (ExprVar), FullVector)
import Downhill.Linear.Lift (lift2_dense)
import Prelude hiding (id, (.))

-- | A variable: a value paired with its derivative.
--
-- The derivative is kept in backward (reverse-mode) form as a 'BackGrad'
-- node. Given a gradient of type @'Grad' a@, 'backprop' turns that node
-- into a result of type @r@.
data BVar r a = BVar
  { -- | The primal value of the variable.
    bvarValue :: a,
    -- | Backward-mode derivative of 'bvarValue'.
    bvarGrad :: BackGrad r (Grad a)
  }

-- | Vector addition of variables. Each operation is applied to the primal
-- values and, in parallel, to the gradient nodes.
instance (AdditiveGroup b, HasFullGrad b) => AdditiveGroup (BVar r b) where
  zeroV = BVar zeroV zeroV
  negateV (BVar u du) = BVar (negateV u) (negateV du)
  (^-^) (BVar u du) (BVar w dw) = BVar (u ^-^ w) (du ^-^ dw)
  (^+^) (BVar u du) (BVar w dw) = BVar (u ^+^ w) (du ^+^ dw)

-- | Scalar arithmetic on variables. Each method computes the primal result
-- and applies the matching differentiation rule to the gradient nodes.
instance (Num b, HasFullGrad b, MScalar b ~ b) => Num (BVar r b) where
  (+) (BVar u du) (BVar w dw) = BVar (u + w) (du ^+^ dw)
  (-) (BVar u du) (BVar w dw) = BVar (u - w) (du ^-^ dw)
  (*) (BVar u du) (BVar w dw) =
    -- product rule: d(u w) = u dw + w du
    BVar (u * w) (u *^ dw ^+^ w *^ du)
  negate (BVar u du) = BVar (negate u) (negateV du)
  abs (BVar u du) =
    -- d|u| = signum u * du
    -- TODO: inefficiency: for ordinary numeric types signum u is -1, 0 or 1,
    -- so this scaling is effectively a multiplication by (plus/minus) 1.
    BVar (abs u) (signum u *^ du)
  -- signum is piecewise constant, so its derivative is taken to be zero.
  signum (BVar u _) = BVar (signum u) zeroV
  -- Literals are constants: they carry a zero gradient.
  fromInteger n = BVar (fromInteger n) zeroV

-- | Square of a number, @x * x@.
sqr :: Num a => a -> a
sqr = \v -> v * v

-- | Reciprocal square root, @recip (sqrt x)@.
rsqrt :: Floating a => a -> a
rsqrt x = recip root
  where
    root = sqrt x

-- | Division on variables, differentiated with the reciprocal and quotient
-- rules.
instance (Fractional b, HasFullGrad b, MScalar b ~ b) => Fractional (BVar r b) where
  -- Literals are constants: they carry a zero gradient.
  fromRational q = BVar (fromRational q) zeroV
  recip (BVar u du) = BVar (recip u) (slope *^ du)
    where
      -- d(1/u) = -(1/u^2) du
      slope = negate (recip (sqr u))
  (/) (BVar u du) (BVar w dw) =
    let -- quotient rule: d(u/w) = du/w - (u/w^2) dw
        fromNumerator = recip w *^ du
        fromDenominator = (u / sqr w) *^ dw
     in BVar (u / w) (fromNumerator ^-^ fromDenominator)

-- | Elementary functions on variables, each paired with its analytic
-- derivative.
--
-- Besides the methods 'Floating' requires, 'sqrt', 'tan', 'tanh' and '(**)'
-- are defined explicitly. Otherwise the class defaults would build them
-- from 'exp', 'log', 'sin', 'cos', 'sinh' and 'cosh'
-- (@x ** y = exp (log x * y)@, @sqrt x = x ** 0.5@, ...). That has two
-- problems:
--
-- * the primal value can disagree with the underlying type's own
--   implementation; e.g. @(-2) ** 2@ becomes NaN because @log (-2)@ is NaN;
--
-- * derivatives pick up spurious NaNs, e.g. at @sqrt 0@ (0 * infinity).
instance (Floating b, HasFullGrad b, MScalar b ~ b) => Floating (BVar r b) where
  pi = BVar pi zeroV
  -- d(e^x) = e^x dx; the value is shared between primal and derivative.
  exp (BVar x dx) = BVar y (y *^ dx)
    where
      y = exp x
  log (BVar x dx) = BVar (log x) (recip x *^ dx)
  -- d(sqrt x) = dx / (2 sqrt x)
  sqrt (BVar x dx) = BVar y (recip (2 * y) *^ dx)
    where
      y = sqrt x
  -- d(x ** y) = y x^(y-1) dx + log x x^y dy
  -- The primal uses the underlying (**), so negative bases with integral
  -- exponents evaluate like they do for the underlying type.
  BVar x dx ** BVar y dy =
    BVar z ((y * x ** (y - 1)) *^ dx ^+^ (log x * z) *^ dy)
    where
      z = x ** y
  sin (BVar x dx) = BVar (sin x) (cos x *^ dx)
  cos (BVar x dx) = BVar (cos x) (negate (sin x) *^ dx)
  -- d(tan x) = dx / cos^2 x
  tan (BVar x dx) = BVar (tan x) (recip (sqr (cos x)) *^ dx)
  asin (BVar x dx) = BVar (asin x) (rsqrt (1 - sqr x) *^ dx)
  acos (BVar x dx) = BVar (acos x) (negate (rsqrt (1 - sqr x)) *^ dx)
  atan (BVar x dx) = BVar (atan x) (recip (1 + sqr x) *^ dx)
  sinh (BVar x dx) = BVar (sinh x) (cosh x *^ dx)
  cosh (BVar x dx) = BVar (cosh x) (sinh x *^ dx)
  -- d(tanh x) = dx / cosh^2 x; written this way rather than 1 - tanh^2 x
  -- to avoid cancellation when tanh x is close to +/-1.
  tanh (BVar x dx) = BVar (tanh x) (recip (sqr (cosh x)) *^ dx)
  asinh (BVar x dx) = BVar (asinh x) (rsqrt (1 + sqr x) *^ dx)
  acosh (BVar x dx) = BVar (acosh x) (rsqrt (sqr x - 1) *^ dx)
  atanh (BVar x dx) = BVar (atanh x) (recip (1 - sqr x) *^ dx)

-- | Scaling a vector variable by a scalar variable.
--
-- Both factors carry gradient nodes, so the result's node is built by
-- 'lift2_dense' from two backward maps, one for each factor.
instance
  ( VectorSpace v,
    HasFullGrad v,
    Tang v ~ v,
    FullVector (MScalar v),
    Grad (MScalar v) ~ MScalar v
  ) =>
  VectorSpace (BVar r v)
  where
  type Scalar (BVar r v) = BVar r (MScalar v)
  (*^) (BVar s ds) (BVar x dx) =
    BVar (s *^ x) (lift2_dense towardScalar towardVector ds dx)
    where
      -- Backward map to the scalar factor: evaluate the incoming gradient
      -- on the vector value.
      towardScalar :: Grad v -> MScalar v
      towardScalar g = evalGrad g x
      -- Backward map to the vector factor: scale the incoming gradient by
      -- the scalar value.
      towardVector :: Grad v -> Grad v
      towardVector g = s *^ g

-- | Affine operations on variables. Each acts on the primal points and, in
-- parallel, on their gradient nodes. The difference of two variables is a
-- variable over the tangent type.
instance (HasFullGrad p, HasGradAffine p) => AffineSpace (BVar r p) where
  type Diff (BVar r p) = BVar r (Tang p)
  (.+^) (BVar u du) (BVar step dstep) = BVar (u .+^ step) (du ^+^ dstep)
  (.-.) (BVar u du) (BVar w dw) = BVar (u .-. w) (du ^-^ dw)

-- | Lift a plain value into a variable whose derivative is zero, i.e. a
-- value that does not depend on the differentiation root.
constant :: forall r a. FullVector (Grad a) => a -> BVar r a
constant x = BVar {bvarValue = x, bvarGrad = zeroV}

-- | Make a differentiation root: a variable whose derivative is the
-- identity. Its gradient node is a 'realNode' wrapping 'ExprVar'.
var :: a -> BVar (Grad a) a
var x = BVar {bvarValue = x, bvarGrad = realNode ExprVar}

--backprop :: forall a p. (HasGrad p, BasicVector a) => BVar a p -> GradBuilder p -> a
--backprop (BVar _y0 x) = BP.backprop x

-- | Reverse mode differentiation.
--
-- Runs the backward pass ('BP.backprop') on the variable's gradient node,
-- seeded with the given gradient of the output (of type @'Grad' a@). The
-- primal value is ignored. For a variable derived from @'var' x@, the
-- result type @r@ is @'Grad'@ of @x@'s type.
backprop ::
  forall r a.
  (HasGrad a, FullVector (Grad a), BasicVector r) =>
  BVar r a ->
  Grad a ->
  r
backprop bv = case bv of
  BVar _ node -> BP.backprop node