{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -Wno-unused-imports -Wno-unused-top-binds #-}
module Numeric.AD.DelCont.Internal
(rad1, rad2, grad,
auto,
rad1g, rad2g, radNg,
op1Num, op2Num,
op1, op2,
AD0, AD, AD')
where
import Control.Monad.ST (ST, runST)
import Data.Bifunctor (Bifunctor(..))
import Data.STRef (STRef, newSTRef, readSTRef, modifySTRef')
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Cont (ContT, shiftT, resetT, evalContT)
import Prelude hiding (read)
-- | A value paired with its adjoint.
--
-- 'primal' is the value produced on the forward pass and 'dual' is the
-- adjoint that is accumulated into on the backward pass. Both fields are
-- lazy. 'fmap' maps over the 'dual' component (the last type parameter).
data D a da = D
  { primal :: a  -- ^ forward-pass value
  , dual :: da   -- ^ accumulated adjoint
  } deriving (Show, Functor)
-- | Two 'D' values are equal when their 'primal' components are equal;
-- the adjoints are ignored.
instance Eq a => Eq (D a da) where
  D x _ == D y _ = x == y
-- | 'D' values are ordered by their 'primal' components; the adjoints are
-- ignored (consistent with the 'Eq' instance).
instance Ord a => Ord (D a db) where
  compare (D x _) (D y _) = compare x y
-- | Map the primal with the first function and the adjoint with the second.
instance Bifunctor D where
  bimap f g (D a b) = D (f a) (g b)
-- | Apply a function to the adjoint ('dual') of a 'D', leaving the
-- 'primal' unchanged. The result's 'dual' is a lazy application of the
-- function.
withD :: (da -> db) -> D a da -> D a db
withD = second
-- | A mutable 'ST' cell holding a value together with its adjoint.
type DVar s v dv = STRef s (D v dv)
-- | Allocate a fresh 'DVar' holding primal @x@ and initial adjoint @dx@.
var :: a -> da -> ST s (DVar s a da)
var x dx = newSTRef (D x dx)
-- | Lift a constant into 'AD'.
--
-- The constant gets its own 'DVar', but its adjoint slot is a lazy error
-- thunk rather than a real zero, because no zero of type @da@ is available
-- here. 'D' has lazy fields, and the backward updates only force the 'D'
-- constructor (via 'modifySTRef''). So the thunk is only evaluated if the
-- constant's 'dual' is actually demanded. If that ever happens, the error
-- message names the cause instead of a bare 'undefined'.
auto :: a -> AD s a da
auto x = AD0 $ lift $ var x (error "Numeric.AD.DelCont.Internal.auto: the adjoint of a constant was demanded")
-- | A continuation-passing computation over @'ST' s@ that is polymorphic in
-- the answer type @x@ of the enclosing 'resetT'. This polymorphism lets each
-- primitive operation capture the rest of the computation with 'shiftT',
-- whatever answer type the driver chooses.
newtype AD0 s a = AD0 { unAD0 :: forall x . ContT x (ST s) a } deriving (Functor)
-- | Delegates to the 'Applicative' instance of the wrapped 'ContT'. The
-- pattern-bound fields keep their answer-polymorphic types, so the result
-- can be re-wrapped in 'AD0'.
instance Applicative (AD0 s) where
  pure x = AD0 $ pure x
  AD0 f <*> AD0 x = AD0 (f <*> x)
-- | An 'AD0' computation that yields a 'DVar' with primal type @v@ and
-- adjoint type @dv@.
type AD s v dv = AD0 s (DVar s v dv)

-- | 'AD' where the primal and the adjoint share one type.
type AD' s v = AD s v v
-- | Core of a unary primitive, at the 'ContT' level.
--
-- The steps are:
--
-- 1. Run the input computation and read the input's primal.
-- 2. Apply @f@ to get the output primal and a pullback.
-- 3. Capture the remainder of the computation (up to the enclosing
--    'resetT') as @k@.
-- 4. Inside the captured block, allocate the output variable with adjoint
--    @zerob@ and run @k@ on it.
-- 5. When @k@ returns, read the adjoint accumulated into the output, push it
--    through the pullback, and add the result into the input's adjoint with
--    @plusa@.
--
-- Because downstream operations do their own accumulation after their own
-- @k@ returns, these updates happen in reverse order of the forward pass.
op1_ :: db                    -- ^ zero of the output adjoint type
     -> (da -> da -> da)      -- ^ addition of input adjoints
     -> (a -> (b, db -> da))  -- ^ returns (result, pullback)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
op1_ zerob plusa f ioa = do
  ra <- ioa
  D xa _ <- lift $ readSTRef ra
  let (xb, pullback) = f xa
  shiftT $ \k -> lift $ do
    rb <- var xb zerob
    answer <- k rb
    -- adjoint of the output, fully accumulated by the rest of the computation
    D _ yd <- readSTRef rb
    modifySTRef' ra (withD (\rda0 -> rda0 `plusa` pullback yd))
    pure answer
-- | Lift a unary operation into 'AD'.
--
-- The caller supplies:
--
-- * the zero of the output adjoint type;
-- * addition on the input adjoint type;
-- * a function returning the result together with its pullback (a map from
--   the output adjoint to the input adjoint).
--
-- This wrapper only unpacks the 'AD0' newtype and re-wraps the result.
op1 :: db                    -- ^ zero of the output adjoint type
    -> (da -> da -> da)      -- ^ addition of input adjoints
    -> (a -> (b, db -> da))  -- ^ returns (result, pullback)
    -> AD s a da
    -> AD s b db
op1 z plusa f (AD0 ioa) = AD0 $ op1_ z plusa f ioa
-- | 'op1' specialised to 'Num' adjoints. It uses @0@ as the output zero and
-- '(+)' to accumulate input adjoints.
op1Num :: (Num da, Num db)
       => (a -> (b, db -> da))  -- ^ returns (result, pullback)
       -> AD s a da
       -> AD s b db
op1Num = op1 0 (+)
-- | Core of a binary primitive, at the 'ContT' level.
--
-- This is the two-input analogue of the unary case:
--
-- 1. Read both inputs' primals and apply @f@ to get the output primal and
--    one pullback per input.
-- 2. Capture the rest of the computation as @k@.
-- 3. Allocate the output with adjoint @zeroc@ and run @k@ on it.
-- 4. When @k@ returns, read the output's accumulated adjoint and add each
--    pulled-back contribution into the corresponding input's adjoint.
--
-- When both inputs are the same 'DVar' (e.g. @x * x@), the two
-- 'modifySTRef'' calls hit the same cell in turn, so both contributions are
-- accumulated.
op2_ :: dc                                    -- ^ zero of the output adjoint type
     -> (da -> da -> da)                      -- ^ addition of first-input adjoints
     -> (db -> db -> db)                      -- ^ addition of second-input adjoints
     -> (a -> b -> (c, dc -> da, dc -> db))   -- ^ returns (result, pullback_a, pullback_b)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
     -> ContT x (ST s) (DVar s c dc)
op2_ zeroc plusa plusb f ioa iob = do
  ra <- ioa
  rb <- iob
  D xa _ <- lift $ readSTRef ra
  D xb _ <- lift $ readSTRef rb
  let (xc, pullA, pullB) = f xa xb
  shiftT $ \k -> lift $ do
    rc <- var xc zeroc
    answer <- k rc
    -- adjoint of the output, fully accumulated by the rest of the computation
    D _ yd <- readSTRef rc
    modifySTRef' ra (withD (\rda0 -> rda0 `plusa` pullA yd))
    modifySTRef' rb (withD (\rdb0 -> rdb0 `plusb` pullB yd))
    pure answer
-- | Lift a binary operation into 'AD'.
--
-- The caller supplies:
--
-- * the zero of the output adjoint type;
-- * addition on each input's adjoint type;
-- * a function returning the result and one pullback per input.
--
-- This wrapper only unpacks the 'AD0' newtypes and re-wraps the result.
op2 :: dc                                    -- ^ zero of the output adjoint type
    -> (da -> da -> da)                      -- ^ addition of first-input adjoints
    -> (db -> db -> db)                      -- ^ addition of second-input adjoints
    -> (a -> b -> (c, dc -> da, dc -> db))   -- ^ returns (result, pullback_a, pullback_b)
    -> AD s a da
    -> AD s b db
    -> AD s c dc
op2 z plusa plusb f (AD0 ioa) (AD0 iob) = AD0 $ op2_ z plusa plusb f ioa iob
-- | 'op2' specialised to 'Num' adjoints. It uses @0@ as the output zero and
-- '(+)' to accumulate each input's adjoint.
op2Num :: (Num da, Num db, Num dc)
       => (a -> b -> (c, dc -> da, dc -> db))  -- ^ returns (result, pullback_a, pullback_b)
       -> AD s a da
       -> AD s b db
       -> AD s c dc
op2Num = op2 0 (+) (+)
-- | Arithmetic on 'AD' values whose primal and adjoint types coincide.
-- Each method supplies the primal result and the pullback(s) of the
-- output adjoint.
instance (Num a) => Num (AD s a a) where
  (+) = op2Num $ \x y -> (x + y, id, id)
  (-) = op2Num $ \x y -> (x - y, id, negate)
  (*) = op2Num $ \x y -> (x * y, (y *), (x *))
  -- Defined directly instead of via the default @0 - x@. The default would
  -- allocate an extra constant node, and it computes @0 - x@ rather than the
  -- base type's own 'negate' (these differ for negative zero).
  negate = op1Num $ \x -> (negate x, negate)
  fromInteger x = auto (fromInteger x)
  abs = op1Num $ \x -> (abs x, (* signum x))
  signum = op1Num $ \x -> (signum x, const 0)
-- | Division and reciprocal on 'AD' values.
--
-- * @d(x/y)/dx = 1/y@ and @d(x/y)/dy = -x/y^2@.
-- * @d(1/x)/dx = -1/x^2@.
instance (Fractional a) => Fractional (AD s a a) where
  (/) = op2Num $ \x y -> (x / y, (/ y), \g -> -g * x / (y * y))
  fromRational x = auto (fromRational x)
  recip = op1Num $ \x -> (recip x, (/ (x * x)) . negate)
-- | Elementary functions on 'AD' values. Each pullback multiplies (or
-- divides) the incoming adjoint by the function's derivative at @x@.
-- Methods not listed ('(**)', 'log1p', ...) use the class defaults, which
-- are built from the ones here.
instance Floating a => Floating (AD s a a) where
  pi = auto pi
  -- share the primal between the result and the derivative
  exp = op1Num $ \x -> let ex = exp x in (ex, (ex *))
  log = op1Num $ \x -> (log x, (/ x))
  -- d(sqrt x)/dx = 1 / (2 * sqrt x); share the square root
  sqrt = op1Num $ \x -> let r = sqrt x in (r, (/ (2 * r)))
  -- logBase x y = log y / log x:
  --   d/dx = -logBase x y / (x * log x),  d/dy = 1 / (y * log x)
  logBase = op2Num $ \x y ->
    let dx = - logBase x y / (log x * x)
    in ( logBase x y
       , (* dx)
       , (/ (y * log x))
       )
  sin = op1Num $ \x -> (sin x, (* cos x))
  cos = op1Num $ \x -> (cos x, (* (- sin x)))
  tan = op1Num $ \x -> (tan x, (/ cos x ^ (2 :: Int)))
  asin = op1Num $ \x -> (asin x, (/ sqrt (1 - x * x)))
  acos = op1Num $ \x -> (acos x, (/ sqrt (1 - x * x)) . negate)
  atan = op1Num $ \x -> (atan x, (/ (x * x + 1)))
  sinh = op1Num $ \x -> (sinh x, (* cosh x))
  cosh = op1Num $ \x -> (cosh x, (* sinh x))
  tanh = op1Num $ \x -> (tanh x, (/ cosh x ^ (2 :: Int)))
  asinh = op1Num $ \x -> (asinh x, (/ sqrt (x * x + 1)))
  acosh = op1Num $ \x -> (acosh x, (/ sqrt (x * x - 1)))
  atanh = op1Num $ \x -> (atanh x, (/ (1 - x * x)))
-- | Evaluate a unary 'AD' function and the adjoint of its input, with
-- caller-supplied adjoint constants.
--
-- The steps are:
--
-- 1. Allocate the input variable with adjoint @zeroa@.
-- 2. Run the function inside 'resetT'. Every primitive captures the
--    remainder of this block, so the forward pass runs first.
-- 3. At the end of the forward pass, overwrite the output's adjoint with
--    @oneb@ (the seed).
-- 4. As the captured continuations return, each primitive adds its
--    contribution into its inputs' adjoints.
--
-- Returns the output primal and the input's final adjoint.
rad1g :: da                                     -- ^ initial adjoint of the input
      -> db                                     -- ^ seed adjoint for the output
      -> (forall s . AD s a da -> AD s b db)
      -> a                                      -- ^ function argument
      -> (b, da)                                -- ^ (result, input adjoint)
rad1g zeroa oneb f x = runST $ do
  xr <- var x zeroa
  outRef <- evalContT $ resetT $ do
    let z = f (AD0 (pure xr))
    zr <- unAD0 z
    -- seed the output adjoint before the backward updates run
    lift $ modifySTRef' zr (withD (const oneb))
    pure zr
  D result _ <- readSTRef outRef
  D _ x_bar <- readSTRef xr
  pure (result, x_bar)
-- | Evaluate a binary 'AD' function and the adjoints of both inputs, with
-- caller-supplied adjoint constants.
--
-- It works like the unary driver:
--
-- 1. Allocate both inputs with their zero adjoints.
-- 2. Run the function inside 'resetT'.
-- 3. Seed the output adjoint with @onec@.
-- 4. Let the captured continuations unwind to accumulate into the inputs.
--
-- Returns the output primal and the pair of input adjoints.
rad2g :: da                                                -- ^ initial adjoint of the first input
      -> db                                                -- ^ initial adjoint of the second input
      -> dc                                                -- ^ seed adjoint for the output
      -> (forall s . AD s a da -> AD s b db -> AD s c dc)
      -> a                                                 -- ^ first argument
      -> b                                                 -- ^ second argument
      -> (c, (da, db))                                     -- ^ (result, (first adjoint, second adjoint))
rad2g zeroa zerob onec f x y = runST $ do
  xr <- var x zeroa
  yr <- var y zerob
  outRef <- evalContT $ resetT $ do
    let z = f (AD0 (pure xr)) (AD0 (pure yr))
    zr <- unAD0 z
    -- seed the output adjoint before the backward updates run
    lift $ modifySTRef' zr (withD (const onec))
    pure zr
  D result _ <- readSTRef outRef
  D _ x_bar <- readSTRef xr
  D _ y_bar <- readSTRef yr
  pure (result, (x_bar, y_bar))
radNg :: Traversable t =>
da
-> db
-> (forall s . t (AD s a da) -> AD s b db)
-> t a
-> (b, t da)
radNg :: da
-> db -> (forall s. t (AD s a da) -> AD s b db) -> t a -> (b, t da)
radNg da
zeroa db
onea forall s. t (AD s a da) -> AD s b db
f t a
xs = (forall s. ST s (b, t da)) -> (b, t da)
forall a. (forall s. ST s a) -> a
runST ((forall s. ST s (b, t da)) -> (b, t da))
-> (forall s. ST s (b, t da)) -> (b, t da)
forall a b. (a -> b) -> a -> b
$ do
t (DVar s a da)
xrs <- (a -> ST s (DVar s a da)) -> t a -> ST s (t (DVar s a da))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (a -> da -> ST s (DVar s a da)
forall a da s. a -> da -> ST s (DVar s a da)
`var` da
zeroa) t a
xs
DVar s b db
zr' <- ContT (DVar s b db) (ST s) (DVar s b db) -> ST s (DVar s b db)
forall (m :: * -> *) r. Monad m => ContT r m r -> m r
evalContT (ContT (DVar s b db) (ST s) (DVar s b db) -> ST s (DVar s b db))
-> ContT (DVar s b db) (ST s) (DVar s b db) -> ST s (DVar s b db)
forall a b. (a -> b) -> a -> b
$
ContT (DVar s b db) (ST s) (DVar s b db)
-> ContT (DVar s b db) (ST s) (DVar s b db)
forall (m :: * -> *) r r'. Monad m => ContT r m r -> ContT r' m r
resetT (ContT (DVar s b db) (ST s) (DVar s b db)
-> ContT (DVar s b db) (ST s) (DVar s b db))
-> ContT (DVar s b db) (ST s) (DVar s b db)
-> ContT (DVar s b db) (ST s) (DVar s b db)
forall a b. (a -> b) -> a -> b
$ do
let
(AD0 forall x. ContT x (ST s) (DVar s b db)
z) = t (AD s a da) -> AD0 s (DVar s b db)
forall s. t (AD s a da) -> AD s b db
f ((DVar s a da -> AD s a da) -> t (DVar s a da) -> t (AD s a da)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap DVar s a da -> AD s a da
forall (f :: * -> *) a. Applicative f => a -> f a
pure t (DVar s a da)
xrs)
DVar s b db
zr <- ContT (DVar s b db) (ST s) (DVar s b db)
forall x. ContT x (ST s) (DVar s b db)
z
ST s () -> ContT (DVar s b db) (ST s) ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (ST s () -> ContT (DVar s b db) (ST s) ())
-> ST s () -> ContT (DVar s b db) (ST s) ()
forall a b. (a -> b) -> a -> b
$ DVar s b db -> (D b db -> D b db) -> ST s ()
forall s a. STRef s a -> (a -> a) -> ST s ()
modifySTRef' DVar s b db
zr ((db -> db) -> D b db -> D b db
forall b c a. (b -> c) -> D a b -> D a c
withD (db -> db -> db
forall a b. a -> b -> a
const db
onea))
DVar s b db -> ContT (DVar s b db) (ST s) (DVar s b db)
forall (f :: * -> *) a. Applicative f => a -> f a
pure DVar s b db
zr
(D b
z db
_) <- DVar s b db -> ST s (D b db)
forall s a. STRef s a -> ST s a
readSTRef DVar s b db
zr'
t (D a da)
xs_bar <- (DVar s a da -> ST s (D a da))
-> t (DVar s a da) -> ST s (t (D a da))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse DVar s a da -> ST s (D a da)
forall s a. STRef s a -> ST s a
readSTRef t (DVar s a da)
xrs
let xs_bar_d :: t da
xs_bar_d = D a da -> da
forall a da. D a da -> da
dual (D a da -> da) -> t (D a da) -> t da
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> t (D a da)
xs_bar
(b, t da) -> ST s (b, t da)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (b
z, t da
xs_bar_d)
rad1 :: (Num a, Num b) =>
(forall s . AD' s a -> AD' s b)
-> a
-> (b, a)
rad1 :: (forall s. AD' s a -> AD' s b) -> a -> (b, a)
rad1 = a -> b -> (forall s. AD' s a -> AD' s b) -> a -> (b, a)
forall da db a b.
da -> db -> (forall s. AD s a da -> AD s b db) -> a -> (b, da)
rad1g a
0 b
1
rad2 :: (Num a, Num b, Num c) =>
(forall s . AD' s a -> AD' s b -> AD' s c)
-> a
-> b
-> (c, (a, b))
rad2 :: (forall s. AD' s a -> AD' s b -> AD' s c) -> a -> b -> (c, (a, b))
rad2 = a
-> b
-> c
-> (forall s. AD' s a -> AD' s b -> AD' s c)
-> a
-> b
-> (c, (a, b))
forall da db dc a b c.
da
-> db
-> dc
-> (forall s. AD s a da -> AD s b db -> AD s c dc)
-> a
-> b
-> (c, (da, db))
rad2g a
0 b
0 c
1
grad :: (Traversable t, Num a, Num b) =>
(forall s . t (AD' s a) -> AD' s b)
-> t a
-> (b, t a)
grad :: (forall s. t (AD' s a) -> AD' s b) -> t a -> (b, t a)
grad = a -> b -> (forall s. t (AD' s a) -> AD' s b) -> t a -> (b, t a)
forall (t :: * -> *) da db a b.
Traversable t =>
da
-> db -> (forall s. t (AD s a da) -> AD s b db) -> t a -> (b, t da)
radNg a
0 b
1
data Backprop a da = Backprop {
Backprop a da -> a -> da
zero :: a -> da
, Backprop a da -> da -> da
one :: da -> da
, Backprop a da -> da -> da -> da
add :: da -> da -> da
}
bpNum :: (Num a, Num da) => Backprop a da
bpNum :: Backprop a da
bpNum = (a -> da) -> (da -> da) -> (da -> da -> da) -> Backprop a da
forall a da.
(a -> da) -> (da -> da) -> (da -> da -> da) -> Backprop a da
Backprop a -> da
forall da a. Num da => a -> da
zeroNum da -> da
forall da a. Num da => a -> da
oneNum da -> da -> da
forall a. Num a => a -> a -> a
addNum
zeroNum :: Num da => a -> da
zeroNum :: a -> da
zeroNum a
_ = da
0
{-# INLINE zeroNum #-}
addNum :: Num da => da -> da -> da
addNum :: da -> da -> da
addNum = da -> da -> da
forall a. Num a => a -> a -> a
(+)
{-# INLINE addNum #-}
oneNum :: Num da => a -> da
oneNum :: a -> da
oneNum a
_ = da
1
{-# INLINE oneNum #-}