{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -Wno-unused-imports -Wno-unused-top-binds #-}
module Numeric.AD.DelCont.Internal
  (rad1, rad2, grad,
   auto,
   rad1g, rad2g, radNg,
   op1Num, op2Num,
   op1, op2,
   AD0, AD, AD')
  where

import Control.Monad.ST (ST, runST)
import Data.Bifunctor (Bifunctor(..))
import Data.STRef (STRef, newSTRef, readSTRef, modifySTRef')

-- transformers
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Cont (ContT, shiftT, resetT, evalContT)

import Prelude hiding (read)

-- | Dual numbers: a primal value paired with its adjoint (\"dual\") part
--
-- Both fields are lazy on purpose: 'auto' stores a placeholder adjoint that must never be forced.
data D a da = D
  { primal :: a   -- ^ primal value
  , dual   :: da  -- ^ adjoint (sensitivity) value
  }
  deriving (Show, Functor)
-- | Two dual numbers are equal when their primal parts are; adjoints are ignored.
instance Eq a => Eq (D a da) where
  (==) (D lhs _) (D rhs _) = lhs == rhs
-- | Dual numbers are ordered by their primal parts only; adjoints are ignored.
instance Ord a => Ord (D a db) where
  compare (D lhs _) (D rhs _) = lhs `compare` rhs
-- | Map the primal part and the adjoint part independently.
instance Bifunctor D where
  bimap onPrimal onDual (D x dx) = D (onPrimal x) (onDual dx)

-- | Modify the adjoint part of a 'D', leaving its primal part untouched
withD :: (da -> db) -> D a da -> D a db
withD onDual (D x dx) = D x (onDual dx)

-- | Differentiable variable
--
-- A (safely) mutable reference to a dual number, scoped to an 'ST' thread @s@.
type DVar s a da = STRef s (D a da)

-- | Introduce a fresh 'DVar' holding the given primal value and adjoint
var :: a -> da -> ST s (DVar s a da)
var x = newSTRef . D x

-- | Lift a constant value into 'AD'
--
-- As one expects from a constant, its value is used when computing the result but discarded
-- when computing the sensitivities.
--
-- The adjoint slot of a constant is filled with a lazy 'error' placeholder. Pullbacks only
-- accumulate into it lazily, and the @rad*@ drivers in this module never read it back.
-- Forcing it is a bug, and the message below says so.
-- NB: this blows up if the module is compiled with -XStrict.
auto :: a -> AD s a da
auto x = AD0 (lift (var x (error "Numeric.AD.DelCont.auto: the adjoint of a constant was forced")))
-- | Lift a value into 'AD0' together with an explicitly supplied adjoint
--
-- Unlike 'auto', no placeholder adjoint is planted; the given @dx@ is stored as-is.
autoStrict :: a -> da -> AD0 s (DVar s a da)
autoStrict x dx = AD0 (lift (var x dx))

-- | Mutable references in the continuation monad
--
-- Wraps a 'ContT' computation over 'ST' that is polymorphic in its answer type @x@. Callers
-- (e.g. the @rad*@ drivers) instantiate @x@ at whatever their 'evalContT' / 'resetT' boundary needs.
newtype AD0 s a = AD0
  { unAD0 :: forall x . ContT x (ST s) a
  }
  deriving (Functor)
-- | Sequencing of 'AD0' computations, delegated to the wrapped 'ContT'
instance Applicative (AD0 s) where
  pure x = AD0 (pure x)
  AD0 mf <*> AD0 mx = AD0 (mf <*> mx)
-- instance Monad (AD s) where -- TODO

-- | A synonym of 'AD0' for the common case of returning a 'DVar' (a mutable reference to a dual number)
--
-- Here the @a@ and @da@ type parameters are respectively the /primal/ and /dual/ quantities tracked by the AD computation.
type AD s a da = AD0 s (DVar s a da)

-- | Like 'AD' but the types of primal and dual coincide
type AD' s a = AD s a a



-- | Lift a unary function (the 'ContT'-level worker behind 'op1')
--
-- A polymorphic combinator that tracks how primal and adjoint values are transformed by a function.
--
-- The steps, numbered as in the body:
--
-- 1) Read the input's primal value and apply the function, obtaining the result and the
--    adjoint-updating function (the \"pullback\")
--
-- 2) Allocate a fresh 'STRef' holding the result with a @zero@ adjoint part
--
-- 3) Pass that reference downstream to the continuation @k@ captured by 'shiftT'; the rest of
--    the computation is expected to mutate its adjoint
--
-- 4) When @k@ returns (bouncing back from the enclosing 'resetT' boundary), read the result's adjoint
--
-- 5) Accumulate the pullback of that adjoint into the input's adjoint with @plus@, then return
--    the continuation's answer
op1_ :: db -- ^ zero (initial adjoint of the result)
     -> (da -> da -> da) -- ^ plus (accumulates into the input's adjoint)
     -> (a -> (b, db -> da)) -- ^ returns : (function result, pullback)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
op1_ zeroB plusA fn input = do
  refA <- input
  D primalA _ <- lift (readSTRef refA)
  let (primalB, pullback) = fn primalA                 -- 1)
  shiftT $ \k -> lift $ do
    refB <- var primalB zeroB                          -- 2)
    answer <- k refB                                   -- 3)
    D _ adjB <- readSTRef refB                         -- 4)
    modifySTRef' refA (withD (`plusA` pullback adjB))  -- 5)
    pure answer


-- | Lift a unary function
--
-- The first two arguments fix the adjoint types of the output and the input respectively
-- (see 'op1Num' for an example).
--
-- The third argument is the most interesting one. It specifies at once how to compute the
-- function value and how to compute the sensitivity with respect to the function parameter.
--
-- Note : the type parameters are completely unconstrained.
op1 :: db -- ^ zero
    -> (da -> da -> da) -- ^ plus
    -> (a -> (b, db -> da)) -- ^ returns : (function result, pullback)
    -> AD s a da
    -> AD s b db
op1 zeroB plusA fn arg = AD0 (op1_ zeroB plusA fn (unAD0 arg))

-- | Helper for constructing unary functions on 'Num' adjoints: 'op1' with @0@ as zero and '(+)' as plus
op1Num :: (Num da, Num db) =>
          (a -> (b, db -> da)) -- ^ returns : (function result, pullback)
       -> AD s a da
       -> AD s b db
op1Num fn = op1 0 (+) fn

-- | Lift a binary function (the 'ContT'-level worker behind 'op2')
--
-- Follows the same recipe as 'op1_'. The difference is that the result's adjoint is pulled back
-- through two pullbacks and accumulated into /both/ input variables. This is done one input after
-- the other, so it also works when both inputs are the same reference.
op2_ :: dc -- ^ zero (initial adjoint of the result)
     -> (da -> da -> da) -- ^ plus for the first input's adjoint
     -> (db -> db -> db) -- ^ plus for the second input's adjoint
     -> (a -> b -> (c, dc -> da, dc -> db)) -- ^ returns : (function result, pullbacks)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
     -> ContT x (ST s) (DVar s c dc)
op2_ zeroC plusA plusB fn inputA inputB = do
  refA <- inputA
  refB <- inputB
  D primalA _ <- lift (readSTRef refA)
  D primalB _ <- lift (readSTRef refB)
  let (primalC, pullA, pullB) = fn primalA primalB
  shiftT $ \k -> lift $ do
    refC <- var primalC zeroC                        -- fresh result with zero adjoint
    answer <- k refC                                 -- run the rest of the computation
    D _ adjC <- readSTRef refC                       -- adjoint accumulated downstream
    modifySTRef' refA (withD (`plusA` pullA adjC))
    modifySTRef' refB (withD (`plusB` pullB adjC))
    pure answer

-- | Lift a binary function
--
-- See 'op1' for more information.
op2 :: dc -- ^ zero
    -> (da -> da -> da) -- ^ plus (first argument's adjoint)
    -> (db -> db -> db) -- ^ plus (second argument's adjoint)
    -> (a -> b -> (c, dc -> da, dc -> db)) -- ^ returns : (function result, pullbacks)
    -> (AD s a da -> AD s b db -> AD s c dc)
op2 zeroC plusA plusB fn argA argB =
  AD0 (op2_ zeroC plusA plusB fn (unAD0 argA) (unAD0 argB))

-- | Helper for constructing binary functions on 'Num' adjoints: 'op2' with @0@ as zero and '(+)' as both plus operations
op2Num :: (Num da, Num db, Num dc) =>
          (a -> b -> (c, dc -> da, dc -> db)) -- ^ returns : (function result, pullbacks)
       -> AD s a da
       -> AD s b db
       -> AD s c dc
op2Num fn = op2 0 (+) (+) fn

-- | The numerical methods of (Num, Fractional, Floating etc.) can be read off their @backprop@ counterparts : https://hackage.haskell.org/package/backprop-0.2.6.4/docs/src/Numeric.Backprop.Op.html#%2A.
--
-- Each method computes the primal result and supplies one pullback per argument.
-- Integer literals are lifted as constants via 'auto'.
instance (Num a) => Num (AD s a a) where
  (+) = op2Num (\x y -> (x + y, id, id))
  (-) = op2Num (\x y -> (x - y, id, negate))
  (*) = op2Num (\x y -> (x * y, \g -> y * g, \g -> x * g))
  abs = op1Num (\x -> (abs x, \g -> g * signum x))
  signum = op1Num (\x -> (signum x, const 0))
  fromInteger n = auto (fromInteger n)

-- | Division on 'AD' values; rational literals are lifted as constants via 'auto'.
instance (Fractional a) => Fractional (AD s a a) where
  (/) = op2Num (\x y -> (x / y, \g -> g / y, \g -> -g * x / (y * y)))
  recip = op1Num (\x -> (recip x, \g -> negate g / (x * x)))
  fromRational q = auto (fromRational q)

-- | Elementary functions on 'AD' values: each method pairs the primal operation with the
-- pullback of its derivative.
instance Floating a => Floating (AD s a a) where
  pi = auto pi

  -- exponentials, logarithms and powers
  exp = op1Num (\x -> (exp x, \g -> exp x * g))
  log = op1Num (\x -> (log x, \g -> g / x))
  sqrt = op1Num (\x -> (sqrt x, \g -> g / (2 * sqrt x)))
  (**) = op2Num (\x y ->
    ( x ** y
    , \g -> g * (y * x ** (y - 1))
    , \g -> g * (x ** y * log x)
    ))
  logBase = op2Num (\x y ->
    let dBase = - logBase x y / (log x * x)
    in ( logBase x y
       , \g -> g * dBase
       , \g -> g / (y * log x)
       ))

  -- trigonometric
  sin = op1Num (\x -> (sin x, \g -> g * cos x))
  cos = op1Num (\x -> (cos x, \g -> g * negate (sin x)))
  tan = op1Num (\x -> (tan x, \g -> g / (cos x ^ (2 :: Int))))
  asin = op1Num (\x -> (asin x, \g -> g / sqrt (1 - x * x)))
  acos = op1Num (\x -> (acos x, \g -> negate g / sqrt (1 - x * x)))
  atan = op1Num (\x -> (atan x, \g -> g / (x * x + 1)))

  -- hyperbolic
  sinh = op1Num (\x -> (sinh x, \g -> g * cosh x))
  cosh = op1Num (\x -> (cosh x, \g -> g * sinh x))
  tanh = op1Num (\x -> (tanh x, \g -> g / (cosh x ^ (2 :: Int))))
  asinh = op1Num (\x -> (asinh x, \g -> g / sqrt (x * x + 1)))
  acosh = op1Num (\x -> (acosh x, \g -> g / sqrt (x * x - 1)))
  atanh = op1Num (\x -> (atanh x, \g -> g / (1 - x * x)))

-- -- instance Eq a => Eq (AD s a da) where -- ??? likely impossible
-- -- instance Ord (AD s a da) where -- ??? see above



-- | Evaluate (forward mode) and differentiate (reverse mode) a unary function, without committing to a specific numeric typeclass
--
-- The argument is allocated with adjoint @zero@. The function is run inside a 'resetT' boundary,
-- and its output adjoint is seeded with @one@ before the continuation unwinds. The input's
-- accumulated adjoint is read back afterwards.
rad1g :: da -- ^ zero
      -> db -- ^ one
      -> (forall s . AD s a da -> AD s b db)
      -> a -- ^ function argument
      -> (b, da) -- ^ (result, adjoint)
rad1g zeroA oneB f x = runST $ do
  refX <- var x zeroA
  refZ <- evalContT $ resetT $ do
    refOut <- unAD0 (f (AD0 (pure refX)))
    lift (modifySTRef' refOut (withD (const oneB)))  -- seed the output adjoint
    pure refOut
  D z _ <- readSTRef refZ
  D _ xBar <- readSTRef refX
  pure (z, xBar)



-- | Evaluate (forward mode) and differentiate (reverse mode) a binary function, without committing to a specific numeric typeclass
--
-- Same scheme as 'rad1g', with two independently-allocated input variables.
rad2g :: da -- ^ zero (first argument)
      -> db -- ^ zero (second argument)
      -> dc -- ^ one
      -> (forall s . AD s a da -> AD s b db -> AD s c dc)
      -> a -> b
      -> (c, (da, db)) -- ^ (result, adjoints)
rad2g zeroA zeroB oneC f x y = runST $ do
  refX <- var x zeroA
  refY <- var y zeroB
  refZ <- evalContT $ resetT $ do
    refOut <- unAD0 (f (AD0 (pure refX)) (AD0 (pure refY)))
    lift (modifySTRef' refOut (withD (const oneC)))  -- seed the output adjoint
    pure refOut
  D z _ <- readSTRef refZ
  D _ xBar <- readSTRef refX
  D _ yBar <- readSTRef refY
  pure (z, (xBar, yBar))

-- | Evaluate (forward mode) and differentiate (reverse mode) a function of a 'Traversable'
--
-- In linear algebra terms, this computes the gradient of a scalar function of vector argument.
-- Each element of the container becomes its own variable (adjoint @zero@). The output adjoint
-- is seeded with @one@, and the accumulated adjoints are returned in the same shape as the input.
radNg :: Traversable t =>
         da -- ^ zero
      -> db -- ^ one
      -> (forall s . t (AD s a da) -> AD s b db)
      -> t a -- ^ argument vector
      -> (b, t da) -- ^ (result, gradient vector)
radNg zeroA oneB f xs = runST $ do
  refXs <- traverse (\x0 -> var x0 zeroA) xs
  refZ <- evalContT $ resetT $ do
    refOut <- unAD0 (f (fmap pure refXs))
    lift (modifySTRef' refOut (withD (const oneB)))  -- seed the output adjoint
    pure refOut
  D z _ <- readSTRef refZ
  adjoints <- fmap dual <$> traverse readSTRef refXs
  pure (z, adjoints)


-- jacg zeroa onea f xs = runST $ do -- -- Jacobian TODO
--   xrs <- traverse (`var` zeroa) xs
--   zr' <- evalContT $
--     resetT $ do
--       let
--         zads = f (fmap pure xrs) -- traversable of AD results
--       for zads $ \zad -> do
--         zr <- zad
--         lift $ modifySTRef' zr (withD (const onea))
--         pure zr
--   undefined

-- | 'traverse' with its arguments flipped: the container first, then the effectful action
for :: (Applicative f, Traversable t) => t a -> (a -> f b) -> f (t b)
for xs action = traverse action xs


-- | Evaluate (forward mode) and differentiate (reverse mode) a unary function
--
-- The argument's adjoint starts at @0@ and the output is seeded with @1@.
--
-- >>> rad1 (\x -> x * x) 1
-- (1,2)
rad1 :: (Num a, Num b) =>
        (forall s . AD' s a -> AD' s b) -- ^ function to be differentiated
     -> a -- ^ function argument
     -> (b, a) -- ^ (result, adjoint)
rad1 = rad1g 0 1

-- | Evaluate (forward mode) and differentiate (reverse mode) a binary function
--
-- Both arguments' adjoints start at @0@ and the output is seeded with @1@; the result is
-- @(value, (adjoint of first argument, adjoint of second argument))@.
--
-- >>> rad2 (\x y -> x + y + y) 1 1
-- (3,(1,2))
--
-- >>> rad2 (\x y -> (x + y) * x) 3 2
-- (15,(8,3))
rad2 :: (Num a, Num b, Num c) =>
        (forall s . AD' s a -> AD' s b -> AD' s c) -- ^ function to be differentiated
     -> a
     -> b
     -> (c, (a, b)) -- ^ (result, adjoints)
rad2 = rad2g 0 0 1

-- | Evaluate (forward mode) and differentiate (reverse mode) a function of a 'Traversable'
--
-- In linear algebra terms, this computes the gradient of a scalar function of vector argument.
-- Every input adjoint starts at @0@ and the output is seeded with @1@.
--
-- For example, with the squared Euclidean norm:
--
-- >>> let sqNorm xs = sum (zipWith (*) xs xs)
-- >>> grad sqNorm [4.1, 2 :: Double]
-- (20.81,[8.2,4.0])
grad :: (Traversable t, Num a, Num b) =>
        (forall s . t (AD' s a) -> AD' s b) -- ^ function to be differentiated
     -> t a -- ^ argument vector
     -> (b, t a) -- ^ (result, gradient vector)
grad f xs = radNg 0 1 f xs


-- ======================== EXPERIMENTAL ==========================

-- | Explicit dictionary of the adjoint operations used for backpropagation
-- (EXPERIMENTAL: a record-of-functions counterpart to a @Backprop@ type class)
data Backprop a da = Backprop
  { zero :: a -> da        -- ^ zero adjoint for a given primal value
  , one  :: da -> da       -- ^ produce the \"one\" seed from an adjoint ('oneNum' ignores its input)
  , add  :: da -> da -> da -- ^ accumulate two adjoints
  }

-- | 'Backprop' dictionary for 'Num' adjoints: zero is @0@, one is @1@, add is '(+)'
bpNum :: (Num a, Num da) => Backprop a da
bpNum = Backprop { zero = zeroNum, one = oneNum, add = addNum }

-- Backprop typeclass (kept below as a commented-out sketch), adapted from https://hackage.haskell.org/package/backprop-0.2.6.4/docs/src/Numeric.Backprop.Class.html
--
-- We use two type parameters to keep the distinction between primal and dual variables.

-- class Backprop a da where
--   zero :: a -> da
--   one :: proxy a -> da -> da
--   add :: proxy a -> da -> da -> da

-- | 'zero' for instances of 'Num': always @0@, never inspecting (i.e. lazy in) its argument.
zeroNum :: Num da => a -> da
zeroNum = const 0
{-# INLINE zeroNum #-}

-- | 'add' for instances of 'Num': plain addition of two adjoints.
addNum :: Num da => da -> da -> da
addNum x y = x + y
{-# INLINE addNum #-}

-- | 'one' for instances of 'Num': always @1@, never inspecting (i.e. lazy in) its argument.
oneNum :: Num da => a -> da
oneNum = const 1
{-# INLINE oneNum #-}


-- rad1BP :: Backprop a da
--        -> Backprop b db
--        -> (forall s . AD s a da -> AD s b db)
--        -> a -- ^ function argument
--        -> (b, da) -- ^ (result, adjoint)
-- rad1BP bpa bpb f x = runST $ do
--   xr <- var x (zero bpa x)
--   zr' <- evalContT $
--     resetT $ do
--       let
--         z = f (AD (pure xr))
--       zr <- unAD z
--       lift $ modifySTRef' zr (withD $ one bpb)
--       pure zr
--   (D z _) <- readSTRef zr'
--   (D _ x_bar) <- readSTRef xr
--   pure (z, x_bar)

-- -- rad1BP :: (Backprop a da, Backprop b db)
-- --        => (forall s . AD s a da -> AD s b db)
-- --        -> a -- ^ function argument
-- --        -> (b, da) -- ^ (result, adjoint)
-- -- rad1BP f x = runST $ do
-- --   xr <- var x (zero x)
-- --   zr' <- evalContT $
-- --     resetT $ do
-- --       let
-- --         z = f (AD (pure xr))
-- --       zr <- unAD z
-- --       let
-- --         oneB :: forall b db . Backprop b db => db -> db -> db
-- --         oneB = one (Proxy :: Proxy db)
-- --       lift $ modifySTRef' zr (withD oneB)
-- --       pure zr
-- --   (D z _) <- readSTRef zr'
-- --   (D _ x_bar) <- readSTRef xr
-- --   pure (z, x_bar)





-- -- playground



-- -- product type (simplified version of vinyl's Rec)
-- data Rec :: [*] -> * where
--   RNil :: Rec '[]
--   (:*) :: !a -> !(Rec as) -> Rec (a ': as)

-- -- dual pairing 
-- class Dual a where
--   dual :: Num r => a -> (a -> r)

-- -- | Dual numbers DD (alternative take, using a type family for the first variation)
-- data DD a = Dd a (Adj a)
-- class Diff a where type Adj a :: *
-- instance Diff Double where type Adj Double = Double




-- data SDRec s as where
--   SDNil :: SDRec s '[]
--   (:&) :: DVar s a a -> !(SDRec s as) -> SDRec s (a ': as)