{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Downhill.BVar.Num
  ( -- | Automatic differentiation for the @Num@ hierarchy.
    --
    -- Polymorphic functions, such as those of type @Num a => a -> a@,
    -- can't be differentiated directly, because 'backprop' needs some additional instances.
    -- The 'AsNum' wrapper provides those instances.
    --
    -- @
    -- derivative :: (forall b. Floating b => b -> b) -> (forall a. Floating a => a -> a)
    -- derivative fun x0 = backpropNum (fun (var x0))
    -- @

    AsNum (..),
    NumBVar,
    numbvarValue,
    var,
    constant,
    backpropNum
  )
where

import Data.AffineSpace (AffineSpace (..))
import Data.Semigroup (Sum (Sum, getSum))
import Data.VectorSpace (AdditiveGroup (..), VectorSpace (..), zeroV)
import Downhill.BVar (BVar (bvarValue), backprop)
import qualified Downhill.BVar as BVar
import Downhill.Grad
  ( Dual (evalGrad),
    HasGrad (Grad, Tang)
  )
import Downhill.Linear.Expr (BasicVector (..))
import Downhill.Metric (MetricTensor (evalMetric))

-- | @AsNum a@ implements many instances in terms of the @Num a@ instance.
--
-- The numeric classes ('Num', 'Fractional', 'Floating') are borrowed
-- unchanged from @a@ via @DerivingVia@. Arithmetic on @AsNum a@ is therefore
-- exactly arithmetic on the wrapped value. 'Show' is the ordinary stock
-- (record-syntax) instance.
newtype AsNum a = AsNum {unAsNum :: a}
  deriving stock (Show)
  deriving (Num, Fractional, Floating) via a

-- | Pairing a gradient with a tangent of @AsNum a@ is plain multiplication
-- of the two wrapped numbers.
instance Num a => Dual (AsNum a) (AsNum a) where
  evalGrad = (*)

-- | @AsNum a@ serves as both its own tangent type and its own gradient type.
instance Num a => HasGrad (AsNum a) where
  type Tang (AsNum a) = AsNum a
  type Grad (AsNum a) = AsNum a

-- | Applying the metric @m@ to a gradient @x@ gives the tangent @m * x@.
-- Everything stays wrapped in 'AsNum'.
instance Num a => MetricTensor (AsNum a) (AsNum a) where
  evalMetric (AsNum m) (AsNum x) = AsNum (m * x)

-- | Vector-space addition on @AsNum a@ reuses the derived 'Num' operations.
instance Num a => AdditiveGroup (AsNum a) where
  zeroV = 0
  (^+^) = (+)
  (^-^) = (-)
  negateV = negate

-- | @AsNum a@ is a one-dimensional vector space over itself. Scaling is
-- ordinary multiplication.
instance Num a => VectorSpace (AsNum a) where
  type Scalar (AsNum a) = AsNum a
  (*^) = (*)

-- | Builders for @AsNum a@ are 'Sum's of the raw wrapped value.
-- 'sumBuilder' unwraps the 'Sum' back into 'AsNum', and 'identityBuilder'
-- goes the other way.
instance Num a => BasicVector (AsNum a) where
  type VecBuilder (AsNum a) = Sum a
  sumBuilder = AsNum . getSum
  identityBuilder = Sum . unAsNum

-- | Points and their differences are both @AsNum a@. '.-.' is subtraction
-- and '.+^' is addition of the wrapped values.
instance Num a => AffineSpace (AsNum a) where
  type Diff (AsNum a) = AsNum a
  AsNum x .-. AsNum y = AsNum (x - y)
  AsNum x .+^ AsNum y = AsNum (x + y)

-- | A 'BVar' whose carried value and whose 'backprop' result type are both
-- @AsNum a@.
type NumBVar a = BVar (AsNum a) (AsNum a)

-- | Wrap a plain number as a constant 'NumBVar', via 'BVar.constant'.
--
-- The explicit type applications pin both 'BVar' parameters to @AsNum a@.
constant :: forall a. Num a => a -> NumBVar a
constant = BVar.constant @(AsNum a) @(AsNum a) . AsNum

-- | Wrap a plain number in 'AsNum' and turn it into a 'NumBVar' variable,
-- via 'BVar.var'.
var :: Num a => a -> NumBVar a
var = BVar.var . AsNum

-- | Run 'backprop' on a 'NumBVar', seeding it with the gradient @1@, and
-- unwrap the result from 'AsNum'.
backpropNum :: forall a. Num a => NumBVar a -> a
backpropNum x = unAsNum $ backprop @(AsNum a) @(AsNum a) x (AsNum 1)

-- | The value carried by a 'NumBVar', unwrapped from 'AsNum'.
numbvarValue :: NumBVar a -> a
numbvarValue = unAsNum . bvarValue