{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
-- {-# LANGUAGE MultiParamTypeClasses #-}
-- {-# language TypeFamilies #-}
{-# OPTIONS_GHC -Wno-unused-imports -Wno-unused-top-binds #-}
module Numeric.AD.DelCont.Internal
  (rad1, rad2, grad,
   auto,
   rad1g, rad2g, radNg,
   op1Num, op2Num,
   op1, op2,
   AD0, AD, AD')
  where

import Control.Monad.ST (ST, runST)
import Data.Bifunctor (Bifunctor(..))
import Data.STRef (STRef, newSTRef, readSTRef, modifySTRef')

-- transformers
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Cont (ContT, shiftT, resetT, evalContT)

import Prelude hiding (read)


-- | Dual numbers
--
-- A value of type @D a da@ pairs a /primal/ value of type @a@ with a
-- /dual/ (adjoint) value of type @da@. The derived 'Functor' instance maps
-- over the dual component (the last type parameter).
data D a da = D
  { primal :: a -- ^ primal part
  , dual :: da -- ^ dual (adjoint) part
  } deriving (Show, Functor)
-- | Equality compares the primal parts only; the dual parts are ignored
-- (which is why no 'Eq' constraint on @da@ is needed).
instance Eq a => Eq (D a da) where
  D x _ == D y _ = x == y
-- | Ordering compares the primal parts only, consistently with the 'Eq'
-- instance; the dual parts are ignored.
instance Ord a => Ord (D a db) where
  compare (D x _) (D y _) = compare x y
-- | Map over the primal part (first function) and the dual part (second
-- function) independently.
instance Bifunctor D where
  bimap f g (D a b) = D (f a) (g b)

-- | Modify the adjoint part of a 'D'
--
-- The primal part is left untouched; this is 'second' from the 'Bifunctor'
-- instance of 'D'.
withD :: (da -> db) -> D a da -> D a db
withD = second

-- | Differentiable variable
--
-- A (safely) mutable reference to a dual number: an 'STRef' living in the
-- state thread @s@ that holds a 'D' with primal type @a@ and dual type @da@.
type DVar s a da = STRef s (D a da)
-- | Introduce a fresh 'DVar'
--
-- Allocates a new reference holding the given primal and dual values.
var :: a -- ^ primal value
    -> da -- ^ initial dual (adjoint) value
    -> ST s (DVar s a da)
var x dx = newSTRef (D x dx)

-- | Lift a constant value into 'AD'
--
-- As one expects from a constant, its value will be used for computing the result, but it will be discarded when computing the sensitivities.
--
-- The adjoint slot of the allocated variable is a bottom placeholder: the
-- operators in this module only ever update it lazily, so it is not
-- evaluated unless a caller explicitly forces it, in which case a
-- descriptive error is raised.
auto :: a -> AD s a da
auto x = AD0 (lift (var x constantAdjoint))
  where
    -- Placeholder for the (meaningless) adjoint of a constant.
    constantAdjoint =
      error "Numeric.AD.DelCont.Internal.auto: the adjoint of a constant was forced"


-- | Mutable references in the continuation monad
--
-- A computation in 'ContT' over 'ST' that is polymorphic in the final answer
-- type @x@ of the continuation.
newtype AD0 s a = AD0
  { unAD0 :: forall x . ContT x (ST s) a -- ^ unwrap the underlying 'ContT' computation
  } deriving (Functor)
-- | Delegates to the 'Applicative' instance of the wrapped 'ContT'.
instance Applicative (AD0 s) where
  AD0 f <*> AD0 x = AD0 (f <*> x)
  -- Plain application (rather than '$') keeps the argument of the rank-2
  -- constructor polymorphic without relying on special typing rules for '$'.
  pure x = AD0 (pure x)
-- instance Monad (AD s) where -- TODO

-- | A synonym of 'AD0' for the common case of returning a 'DVar' (which is a 'ST' computation that returns a dual number)
--
-- Here the @a@ and @da@ type parameters are respectively the /primal/ and /dual/ quantities tracked by the AD computation.
type AD s a da = AD0 s (DVar s a da)
-- | Like 'AD' but the types of primal and dual coincide, i.e. @AD' s a@
-- abbreviates @AD s a a@.
type AD' s a = AD s a a



-- | Lift a unary function
--
-- This is a polymorphic combinator for tracking how primal and adjoint values are transformed by a function.
--
-- How does this work :
--
-- 1) Compute the function result and bind the function inputs to the adjoint updating function (the "pullback")
--
-- 2) Allocate a fresh STRef @rb@ with the function result and @zero@ adjoint part
--
-- 3) @rb@ is passed downstream as an argument to the continuation @k@, with the expectation that the STRef will be mutated
--
-- 4) Upon returning from the @k@ (bouncing from the boundary of @resetT@), the mutated STRef is read back in
--
-- 5) The adjoint part of the input variable is updated using @rb@ and the result of the continuation is returned.
op1_ :: db -- ^ zero
     -> (da -> da -> da) -- ^ plus
     -> (a -> (b, db -> da)) -- ^ returns : (function result, pullback)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
op1_ zerob plusa f ioa = do
  ra <- ioa
  D xa _ <- lift $ readSTRef ra
  let (xb, g) = f xa -- 1)
  shiftT $ \k -> lift $ do
    rb <- var xb zerob -- 2)
    ry <- k rb -- 3)
    D _ yd <- readSTRef rb -- 4)
    -- accumulate the pulled-back adjoint onto whatever the input already holds
    modifySTRef' ra (withD (\rda0 -> rda0 `plusa` g yd)) -- 5)
    pure ry


-- | Lift a unary function
--
-- The first two arguments constrain the types of the adjoint values of the output and input variable respectively, see 'op1Num' for an example.
--
-- The third argument is the most interesting: it specifies at once how to compute the function value and how to compute the sensitivity with respect to the function parameter.
--
-- Note : the type parameters are completely unconstrained.
op1 :: db -- ^ zero
    -> (da -> da -> da) -- ^ plus
    -> (a -> (b, db -> da)) -- ^ returns : (function result, pullback)
    -> AD s a da
    -> AD s b db
op1 z plusa f (AD0 ioa) = AD0 (op1_ z plusa f ioa)

-- | Helper for constructing unary functions that operate on Num instances (i.e. 'op1' specialized to Num)
--
-- The zero adjoint of the output is @0@ and input adjoints are accumulated with '+'.
op1Num :: (Num da, Num db) =>
          (a -> (b, db -> da)) -- ^ returns : (function result, pullback)
       -> AD s a da
       -> AD s b db
op1Num = op1 0 (+)

-- | Lift a binary function
--
-- Same scheme as 'op1_', with two inputs: both input variables are read, the
-- result variable is allocated with a zero adjoint and handed to the
-- continuation, and once the continuation returns the adjoint of the result
-- is pulled back onto each input with its own pullback and its own "plus".
op2_ :: dc -- ^ zero
     -> (da -> da -> da) -- ^ plus
     -> (db -> db -> db) -- ^ plus
     -> (a -> b -> (c, dc -> da, dc -> db)) -- ^ returns : (function result, pullbacks)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
     -> ContT x (ST s) (DVar s c dc)
op2_ zeroc plusa plusb f ioa iob = do
  ra <- ioa
  rb <- iob
  D xa _ <- lift $ readSTRef ra
  D xb _ <- lift $ readSTRef rb
  let (xc, g, h) = f xa xb
  shiftT $ \k -> lift $ do
    rc <- var xc zeroc
    ry <- k rc
    D _ yd <- readSTRef rc
    -- first input is updated before the second, each with its own pullback
    modifySTRef' ra (withD (\rda0 -> rda0 `plusa` g yd))
    modifySTRef' rb (withD (\rdb0 -> rdb0 `plusb` h yd))
    pure ry

-- | Lift a binary function
--
-- See 'op1' for more information.
op2 :: dc -- ^ zero
    -> (da -> da -> da) -- ^ plus
    -> (db -> db -> db) -- ^ plus
    -> (a -> b -> (c, dc -> da, dc -> db)) -- ^ returns : (function result, pullbacks)
    -> (AD s a da -> AD s b db -> AD s c dc)
op2 z plusa plusb f (AD0 ioa) (AD0 iob) = AD0 (op2_ z plusa plusb f ioa iob)

-- | Helper for constructing binary functions that operate on Num instances (i.e. 'op2' specialized to Num)
--
-- The zero adjoint of the output is @0@ and both input adjoints are accumulated with '+'.
op2Num :: (Num da, Num db, Num dc) =>
          (a -> b -> (c, dc -> da, dc -> db)) -- ^ returns : (function result, pullbacks)
       -> AD s a da
       -> AD s b db
       -> AD s c dc
op2Num = op2 0 (+) (+)

-- | The numerical methods of (Num, Fractional, Floating etc.) can be read off their @backprop@ counterparts : https://hackage.haskell.org/package/backprop-0.2.6.4/docs/src/Numeric.Backprop.Op.html#%2A.
--
-- Each method pairs the primal result with the pullback(s) that map the
-- adjoint of the result to the adjoint contribution of each argument.
instance (Num a) => Num (AD s a a) where
  (+) = op2Num $ \x y -> (x + y, id, id)
  (-) = op2Num $ \x y -> (x - y, id, negate)
  (*) = op2Num $ \x y -> (x * y, (y *), (x *))
  -- literals are constants: they carry no sensitivity
  fromInteger x = auto (fromInteger x)
  abs = op1Num $ \x -> (abs x, (* signum x))
  -- the pullback of 'signum' is identically zero
  signum = op1Num $ \x -> (signum x, const 0)

-- | Division, rational literals and reciprocal, each with its pullback.
instance (Fractional a) => Fractional (AD s a a) where
  -- pullbacks: g / y for the numerator, -(g * x) / y^2 for the denominator
  (/) = op2Num $ \x y -> (x / y, (/ y), \g -> negate (g * x / (y * y)))
  -- literals are constants: they carry no sensitivity
  fromRational x = auto (fromRational x)
  -- pullback: -g / x^2
  recip = op1Num $ \x -> (recip x, (/ (x * x)) . negate)

-- | Elementary functions. Every unary method returns the primal result
-- together with a pullback that scales the incoming adjoint by the
-- derivative of the function at the primal input.
instance Floating a => Floating (AD s a a) where
  -- a constant: carries no sensitivity
  pi = auto pi
  exp = op1Num $ \x -> (exp x, (exp x *))
  log = op1Num $ \x -> (log x, (/ x))
  sqrt = op1Num $ \x -> (sqrt x, (/ (2 * sqrt x)))
  logBase = op2Num $ \x y ->
    let
      -- derivative with respect to the base @x@
      dx = negate (logBase x y / (log x * x))
    in ( logBase x y
       , (* dx)
       , (/ (y * log x))
       )
  sin = op1Num $ \x -> (sin x, (* cos x))
  cos = op1Num $ \x -> (cos x, (* negate (sin x)))
  tan = op1Num $ \x -> (tan x, (/ cos x ^ (2 :: Int)))
  asin = op1Num $ \x -> (asin x, (/ sqrt (1 - x * x)))
  acos = op1Num $ \x -> (acos x, (/ sqrt (1 - x * x)) . negate)
  atan = op1Num $ \x -> (atan x, (/ (x * x + 1)))
  sinh = op1Num $ \x -> (sinh x, (* cosh x))
  cosh = op1Num $ \x -> (cosh x, (* sinh x))
  tanh = op1Num $ \x -> (tanh x, (/ cosh x ^ (2 :: Int)))
  asinh = op1Num $ \x -> (asinh x, (/ sqrt (x * x + 1)))
  acosh = op1Num $ \x -> (acosh x, (/ sqrt (x * x - 1)))
  atanh = op1Num $ \x -> (atanh x, (/ (1 - x * x)))

-- -- instance Eq a => Eq (AD s a da) where -- ??? likely impossible
-- -- instance Ord (AD s a da) where -- ??? see above



-- | Evaluate (forward mode) and differentiate (reverse mode) a unary function, without committing to a specific numeric typeclass
--
-- The input is stored in a fresh variable with a zero adjoint; the function
-- is run inside 'resetT', the adjoint of its result is overwritten with
-- @one@, and when the captured continuations unwind the adjoint accumulated
-- on the input variable is read back.
rad1g :: da -- ^ zero
      -> db -- ^ one
      -> (forall s . AD s a da -> AD s b db)
      -> a -- ^ function argument
      -> (b, da) -- ^ (result, adjoint)
rad1g zeroa oneb f x = runST $ do
  xr <- var x zeroa
  zr' <- evalContT $
    resetT $ do
      let z = f (AD0 (pure xr))
      zr <- unAD0 z
      -- seed the reverse pass: the adjoint of the output is "one"
      lift $ modifySTRef' zr (withD (const oneb))
      pure zr
  D z _ <- readSTRef zr'
  D _ x_bar <- readSTRef xr
  pure (z, x_bar)



-- | Evaluate (forward mode) and differentiate (reverse mode) a binary function, without committing to a specific numeric typeclass
--
-- Each argument is stored in a fresh mutable variable whose adjoint starts at
-- the supplied zero. The function is run inside a delimited continuation, the
-- adjoint of its output is overwritten with the supplied one, and the
-- adjoints of both inputs are read back once that computation has finished.
rad2g :: da -- ^ zero (initial adjoint of the first argument)
      -> db -- ^ zero (initial adjoint of the second argument)
      -> dc -- ^ one (adjoint written into the output variable)
      -> (forall s . AD s a da -> AD s b db -> AD s c dc) -- ^ function to be differentiated
      -> a -- ^ first argument
      -> b -- ^ second argument
      -> (c, (da, db)) -- ^ (result, adjoints)
rad2g zeroa zerob onec f x y = runST $ do
  -- input variables: primal value paired with a zero adjoint
  xr <- var x zeroa
  yr <- var y zerob
  zr' <- evalContT $
    resetT $ do
      let
        z = f (AD0 (pure xr)) (AD0 (pure yr))
      zr <- unAD0 z
      -- replace the output adjoint with the seed value, ignoring what was there
      lift $ modifySTRef' zr (withD (const onec))
      pure zr
  -- primal part of the output, dual parts of the inputs
  (D z _) <- readSTRef zr'
  (D _ x_bar) <- readSTRef xr
  (D _ y_bar) <- readSTRef yr
  pure (z, (x_bar, y_bar))

-- | Evaluate (forward mode) and differentiate (reverse mode) a function of a 'Traversable'
--
-- In linear algebra terms, this computes the gradient of a scalar function of vector argument
--
-- Every element of the argument container is stored in its own mutable
-- variable with the supplied zero as initial adjoint. After the function has
-- been run and its output adjoint overwritten with the supplied one, the
-- adjoints of all the input variables are collected in a container of the
-- same shape.
radNg :: Traversable t =>
         da -- ^ zero (initial adjoint of every input element)
      -> db -- ^ one (adjoint written into the output variable)
      -> (forall s . t (AD s a da) -> AD s b db) -- ^ function to be differentiated
      -> t a -- ^ argument vector
      -> (b, t da) -- ^ (result, gradient vector)
radNg zeroa oneb f xs = runST $ do
  -- one variable per input element, each with a zero adjoint
  xrs <- traverse (`var` zeroa) xs
  zr' <- evalContT $
    resetT $ do
      zr <- unAD0 (f (pure <$> xrs))
      -- replace the output adjoint with the seed value, ignoring what was there
      lift $ modifySTRef' zr (withD (const oneb))
      pure zr
  (D z _) <- readSTRef zr'
  xs_bar <- traverse readSTRef xrs
  -- keep only the dual (adjoint) component of each input
  pure (z, dual <$> xs_bar)


-- | Evaluate (forward mode) and differentiate (reverse mode) a unary function
--
-- This is 'rad1g' with the zero and one adjoints taken from the 'Num'
-- instances (@0@ and @1@ respectively).
--
-- >>> rad1 (\x -> x * x) 1
-- (1,2)
rad1 :: (Num a, Num b) =>
        (forall s . AD' s a -> AD' s b) -- ^ function to be differentiated
     -> a -- ^ function argument
     -> (b, a) -- ^ (result, adjoint)
rad1 = rad1g 0 1

-- | Evaluate (forward mode) and differentiate (reverse mode) a binary function
--
-- This is 'rad2g' with both zero adjoints set to @0@ and the output adjoint
-- set to @1@, taken from the 'Num' instances.
--
-- >>> rad2 (\x y -> x + y + y) 1 1
-- (3,(1,2))
--
-- >>> rad2 (\x y -> (x + y) * x) 3 2
-- (15,(8,3))
rad2 :: (Num a, Num b, Num c) =>
        (forall s . AD' s a -> AD' s b -> AD' s c) -- ^ function to be differentiated
     -> a -- ^ first argument
     -> b -- ^ second argument
     -> (c, (a, b)) -- ^ (result, adjoints)
rad2 = rad2g 0 0 1

-- | Evaluate (forward mode) and differentiate (reverse mode) a function of a 'Traversable'
--
-- In linear algebra terms, this computes the gradient of a scalar function of vector argument
--
-- This is 'radNg' with the zero and one adjoints taken from the 'Num'
-- instances (@0@ and @1@ respectively).
--
-- @
-- sqNorm :: Num a => [a] -> a
-- sqNorm xs = sum $ zipWith (*) xs xs
--
-- p :: [Double]
-- p = [4.1, 2]
-- @
--
-- >>> grad sqNorm p
-- (20.81,[8.2,4.0])
grad :: (Traversable t, Num a, Num b) =>
        (forall s . t (AD' s a) -> AD' s b) -- ^ function to be differentiated
     -> t a -- ^ argument vector
     -> (b, t a) -- ^ (result, gradient vector)
grad = radNg 0 1


-- ======================== EXPERIMENTAL ==========================

-- | Record of the operations on dual values of type @da@ associated with
-- primal values of type @a@ (experimental; see 'bpNum' for the 'Num'-based
-- instance of this record).
data Backprop a da = Backprop
  { zero :: a -> da
    -- ^ dual value produced from a primal value ('zeroNum' returns @0@)
  , one :: da -> da
    -- ^ dual value produced from a dual value ('oneNum' returns @1@)
  , add :: da -> da -> da
    -- ^ combination of two dual values ('addNum' is '+')
  }

-- | 'Backprop' record for dual types with a 'Num' instance, assembled from
-- 'zeroNum', 'oneNum' and 'addNum'.
bpNum :: (Num a, Num da) => Backprop a da
bpNum = Backprop zeroNum oneNum addNum

-- Backprop typeclass (currently disabled, kept below for reference), adapted from https://hackage.haskell.org/package/backprop-0.2.6.4/docs/src/Numeric.Backprop.Class.html
--
-- We use two type parameters to keep the distinction between primal and dual variables.

-- class Backprop a da where
--   zero :: a -> da
--   one :: proxy a -> da -> da
--   add :: proxy a -> da -> da -> da

-- | 'zero' for instances of 'Num'. Lazy in its argument: the argument is
-- never inspected and the result is always @0@.
zeroNum :: Num da => a -> da
zeroNum _ = 0
{-# INLINE zeroNum #-}

-- | 'add' for instances of 'Num': plain addition with '+'.
addNum :: Num da => da -> da -> da
addNum = (+)
{-# INLINE addNum #-}

-- | 'one' for instances of 'Num'. Lazy in its argument: the argument is
-- never inspected and the result is always @1@.
oneNum :: Num da => a -> da
oneNum _ = 1
{-# INLINE oneNum #-}


-- rad1BP :: Backprop a da
--        -> Backprop b db
--        -> (forall s . AD s a da -> AD s b db)
--        -> a -- ^ function argument
--        -> (b, da) -- ^ (result, adjoint)
-- rad1BP bpa bpb f x = runST $ do
--   xr <- var x (zero bpa x)
--   zr' <- evalContT $
--     resetT $ do
--       let
--         z = f (AD (pure xr))
--       zr <- unAD z
--       lift $ modifySTRef' zr (withD $ one bpb)
--       pure zr
--   (D z _) <- readSTRef zr'
--   (D _ x_bar) <- readSTRef xr
--   pure (z, x_bar)

-- -- rad1BP :: (Backprop a da, Backprop b db)
-- --        => (forall s . AD s a da -> AD s b db)
-- --        -> a -- ^ function argument
-- --        -> (b, da) -- ^ (result, adjoint)
-- -- rad1BP f x = runST $ do
-- --   xr <- var x (zero x)
-- --   zr' <- evalContT $
-- --     resetT $ do
-- --       let
-- --         z = f (AD (pure xr))
-- --       zr <- unAD z
-- --       let
-- --         oneB :: forall b db . Backprop b db => db -> db -> db
-- --         oneB = one (Proxy :: Proxy db)
-- --       lift $ modifySTRef' zr (withD oneB)
-- --       pure zr
-- --   (D z _) <- readSTRef zr'
-- --   (D _ x_bar) <- readSTRef xr
-- --   pure (z, x_bar)





-- -- playground



-- -- product type (simplified version of vinyl's Rec)
-- data Rec :: [*] -> * where
--   RNil :: Rec '[]
--   (:*) :: !a -> !(Rec as) -> Rec (a ': as)

-- -- dual pairing 
-- class Dual a where
--   dual :: Num r => a -> (a -> r)

-- -- | Dual numbers DD (alternative take, using a type family for the first variation)
-- data DD a = Dd a (Adj a)
-- class Diff a where type Adj a :: *
-- instance Diff Double where type Adj Double = Double




-- data SDRec s as where
--   SDNil :: SDRec s '[]
--   (:&) :: DVar s a a -> !(SDRec s as) -> SDRec s (a ': as)