{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -Wno-unused-imports -Wno-unused-top-binds #-}
module Numeric.AD.DelCont.Internal
(rad1, rad2, grad,
auto,
rad1g, rad2g, radNg,
op1Num, op2Num,
op1, op2,
AD0, AD, AD')
where
import Control.Monad.ST (ST, runST)
import Data.Bifunctor (Bifunctor(..))
import Data.STRef (STRef, newSTRef, readSTRef, modifySTRef')
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Cont (ContT, shiftT, resetT, evalContT)
import Prelude hiding (read)
-- | A primal value paired with its dual (sensitivity) component.
--
-- Neither field carries a strictness annotation, so both are lazy.
-- The derived 'Functor' instance maps over the last type parameter,
-- i.e. the dual component.
data D a da = D
  { primal :: a
    -- ^ the value component
  , dual :: da
    -- ^ the sensitivity component
  }
  deriving (Show, Functor)
-- | Equality compares the primal components only; the duals are ignored.
instance Eq a => Eq (D a da) where
  (==) (D p _) (D q _) = p == q
-- | Ordering compares the primal components only; the duals are ignored.
instance Ord a => Ord (D a db) where
  D p _ `compare` D q _ = p `compare` q
-- | Map independently over the primal (first) and dual (second) components.
instance Bifunctor D where
  bimap onPrimal onDual = \case
    D p d -> D (onPrimal p) (onDual d)
-- | Apply a function to the dual component, leaving the primal untouched.
withD :: (da -> db) -> D a da -> D a db
withD onDual = bimap id onDual
-- | A mutable reference, living in the 'ST' state thread @s@, that holds a
-- primal\/dual pair @D a da@.
type DVar s a da = STRef s (D a da)
-- | Allocate a fresh 'DVar' initialised with the given primal and dual.
var :: a -> da -> ST s (DVar s a da)
var p = newSTRef . D p
-- | Lift a constant into the AD computation.
--
-- The signature places no constraint on @da@, so no zero sensitivity can be
-- produced here; the dual slot is filled with a placeholder instead.
-- The placeholder is stored in a lazy field of 'D' and is only a problem if
-- something actually evaluates the dual of the constant. In that case the
-- failure now carries a message naming this function, rather than a bare
-- 'undefined'.
auto :: a -> AD s a da
auto x = AD0 (lift (var x constantDual))
  where
    -- Never meant to be evaluated: a constant has no meaningful sensitivity.
    constantDual =
      error
        "Numeric.AD.DelCont.Internal.auto: the dual of a lifted constant was evaluated"
-- | Like 'auto', but the caller supplies the value stored in the dual slot
-- of the freshly allocated variable.
autoStrict :: a -> da -> AD0 s (DVar s a da)
autoStrict x dx = AD0 (lift (newSTRef (D x dx)))
-- | A computation in 'ContT' over 'ST' that is polymorphic in the answer
-- type @x@ of the continuation.
newtype AD0 s a = AD0
  { unAD0 :: forall x. ContT x (ST s) a
    -- ^ unwrap the underlying continuation computation
  }
  deriving (Functor)
-- | Delegates to the 'Applicative' instance of the wrapped 'ContT'.
instance Applicative (AD0 s) where
  pure x = AD0 (pure x)
  mf <*> mx = AD0 (unAD0 mf <*> unAD0 mx)
-- | An 'AD0' computation that yields a 'DVar' with primal type @a@ and dual
-- type @da@.
type AD s a da = AD0 s (DVar s a da)
-- | 'AD' specialised so that the dual has the same type as the primal.
type AD' s a = AD s a a
-- | Lift a unary function, given together with its pullback, to operate on
-- 'DVar' computations in 'ContT'.
--
-- Arguments, in order: the initial dual for the result variable, the
-- accumulation operator for the input's dual, the function returning the
-- result primal and a pullback from result dual to input dual, and the
-- computation producing the input variable.
op1_ :: db
     -> (da -> da -> da)
     -> (a -> (b, db -> da))
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
op1_ zeroOut plusIn f input = do
  inRef <- input
  D x _ <- lift (readSTRef inRef)
  let (y, pullback) = f x
  shiftT (lift . backward inRef y pullback)
  where
    -- Allocate the result variable, run the captured continuation @k@ on it,
    -- and only afterwards read the result's dual and add its pulled-back
    -- contribution to the input's dual.
    backward inRef y pullback k = do
      outRef <- var y zeroOut
      answer <- k outRef
      D _ yBar <- readSTRef outRef
      modifySTRef' inRef (withD (`plusIn` pullback yBar))
      pure answer
-- | Lift a unary function with its pullback to 'AD' computations.
--
-- This is 'op1_' wrapped in the 'AD0' newtype; see the argument order there:
-- initial result dual, accumulation operator for the input dual, the function
-- with its pullback, and the input.
op1 :: db
    -> (da -> da -> da)
    -> (a -> (b, db -> da))
    -> AD s a da
    -> AD s b db
op1 zeroOut plusIn f input = AD0 (op1_ zeroOut plusIn f (unAD0 input))
-- | 'op1' specialised to 'Num' duals: the result dual starts at @0@ and
-- contributions to the input dual are accumulated with @(+)@.
op1Num :: (Num da, Num db) =>
          (a -> (b, db -> da))
       -> AD s a da
       -> AD s b db
op1Num f input = op1 0 (+) f input
-- | Lift a binary function, given together with one pullback per argument,
-- to operate on 'DVar' computations in 'ContT'.
--
-- Arguments, in order: the initial dual for the result variable, the
-- accumulation operators for the first and second input duals, the function
-- returning the result primal and the two pullbacks, and the computations
-- producing the two input variables (run in that order).
op2_ :: dc
     -> (da -> da -> da)
     -> (db -> db -> db)
     -> (a -> b -> (c, dc -> da, dc -> db))
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
     -> ContT x (ST s) (DVar s c dc)
op2_ zeroOut plusA plusB f inputA inputB = do
  refA <- inputA
  refB <- inputB
  D x _ <- lift (readSTRef refA)
  D y _ <- lift (readSTRef refB)
  let (z, pullA, pullB) = f x y
  shiftT (lift . backward refA refB z pullA pullB)
  where
    -- Allocate the result variable, run the captured continuation @k@ on it,
    -- then read the result's dual and accumulate its pulled-back
    -- contributions into the first and then the second input.
    backward refA refB z pullA pullB k = do
      refC <- var z zeroOut
      answer <- k refC
      D _ zBar <- readSTRef refC
      modifySTRef' refA (withD (`plusA` pullA zBar))
      modifySTRef' refB (withD (`plusB` pullB zBar))
      pure answer
-- | Lift a binary function with its two pullbacks to 'AD' computations.
--
-- This is 'op2_' wrapped in the 'AD0' newtype: initial result dual,
-- accumulation operators for the two input duals, the function with its
-- pullbacks, and the two inputs.
op2 :: dc
    -> (da -> da -> da)
    -> (db -> db -> db)
    -> (a -> b -> (c, dc -> da, dc -> db))
    -> (AD s a da -> AD s b db -> AD s c dc)
op2 zeroOut plusA plusB f inputA inputB =
  AD0 (op2_ zeroOut plusA plusB f (unAD0 inputA) (unAD0 inputB))
-- | 'op2' specialised to 'Num' duals: the result dual starts at @0@ and
-- contributions to both input duals are accumulated with @(+)@.
op2Num :: (Num da, Num db, Num dc) =>
          (a -> b -> (c, dc -> da, dc -> db))
       -> AD s a da
       -> AD s b db
       -> AD s c dc
op2Num f inputA inputB = op2 0 (+) (+) f inputA inputB
-- | Arithmetic on 'AD' computations whose dual type equals the primal type.
--
-- Each method pairs the primal result with the pullback(s) that map the
-- result's sensitivity @g@ to the argument sensitivities. 'negate' is not
-- defined here and falls back to the class default.
instance (Num a) => Num (AD s a a) where
  (+) = op2Num (\x y -> (x + y, id, id))
  (-) = op2Num (\x y -> (x - y, id, negate))
  (*) = op2Num (\x y -> (x * y, \g -> y * g, \g -> x * g))
  -- Integer literals are constants.
  fromInteger = auto . fromInteger
  abs = op1Num (\x -> (abs x, \g -> g * signum x))
  -- The sensitivity of 'signum' is taken to be 0 everywhere.
  signum = op1Num (\x -> (signum x, \_ -> 0))
-- | Division on 'AD' computations whose dual type equals the primal type.
--
-- In the pullbacks, @g@ is the sensitivity of the result.
instance (Fractional a) => Fractional (AD s a a) where
  -- d(x/y)/dx = 1/y,  d(x/y)/dy = -x/y^2
  (/) = op2Num (\x y ->
    ( x / y
    , \g -> g / y
    , \g -> negate (g * x / (y * y))
    ))
  -- Rational literals are constants.
  fromRational = auto . fromRational
  -- d(1/x)/dx = -1/x^2
  recip = op1Num (\x -> (recip x, \g -> negate g / (x * x)))
-- | Elementary functions on 'AD' computations whose dual type equals the
-- primal type.
--
-- Every method pairs the primal result with the pullback(s) that scale the
-- result's sensitivity @g@ by the corresponding partial derivative.
instance Floating a => Floating (AD s a a) where
  pi = auto pi
  exp = op1Num (\x -> (exp x, \g -> exp x * g))
  log = op1Num (\x -> (log x, \g -> g / x))
  sqrt = op1Num (\x -> (sqrt x, \g -> g / (2 * sqrt x)))
  -- d(x**y)/dx = y * x**(y-1),  d(x**y)/dy = x**y * log x
  (**) = op2Num (\x y ->
    ( x ** y
    , \g -> g * (y * x ** (y - 1))
    , \g -> g * (x ** y * log x)
    ))
  -- First argument is the base: d/dx = -logBase x y / (x * log x),
  -- d/dy = 1 / (y * log x)
  logBase = op2Num (\x y ->
    let dx = negate (logBase x y / (log x * x))
    in ( logBase x y
       , \g -> g * dx
       , \g -> g / (y * log x)
       ))
  sin = op1Num (\x -> (sin x, \g -> g * cos x))
  cos = op1Num (\x -> (cos x, \g -> g * negate (sin x)))
  tan = op1Num (\x -> (tan x, \g -> g / cos x ^ (2 :: Int)))
  asin = op1Num (\x -> (asin x, \g -> g / sqrt (1 - x * x)))
  acos = op1Num (\x -> (acos x, \g -> negate g / sqrt (1 - x * x)))
  atan = op1Num (\x -> (atan x, \g -> g / (x * x + 1)))
  sinh = op1Num (\x -> (sinh x, \g -> g * cosh x))
  cosh = op1Num (\x -> (cosh x, \g -> g * sinh x))
  tanh = op1Num (\x -> (tanh x, \g -> g / cosh x ^ (2 :: Int)))
  asinh = op1Num (\x -> (asinh x, \g -> g / sqrt (x * x + 1)))
  acosh = op1Num (\x -> (acosh x, \g -> g / sqrt (x * x - 1)))
  atanh = op1Num (\x -> (atanh x, \g -> g / (1 - x * x)))
-- | Evaluate a unary function and the sensitivity of its input.
--
-- The input variable is created with dual @zeroa@. After the function has
-- produced its output variable, that variable's dual is overwritten with
-- @oneb@; the delimited computation then finishes, and the output's primal
-- and the input's accumulated dual are read back.
--
-- Returns @(result, input sensitivity)@.
rad1g :: da
      -> db
      -> (forall s . AD s a da -> AD s b db)
      -> a
      -> (b, da)
rad1g zeroa oneb f x = runST $ do
  xRef <- var x zeroa
  outRef <- evalContT . resetT $ do
    out <- unAD0 (f (AD0 (pure xRef)))
    -- Seed the output sensitivity.
    lift (modifySTRef' out (withD (const oneb)))
    pure out
  D y _ <- readSTRef outRef
  D _ xBar <- readSTRef xRef
  pure (y, xBar)
-- | Evaluate a binary function and the sensitivities of both inputs.
--
-- The input variables are created with duals @zeroa@ and @zerob@. After the
-- function has produced its output variable, that variable's dual is
-- overwritten with @onec@; the delimited computation then finishes, and the
-- output's primal and both inputs' accumulated duals are read back.
--
-- Returns @(result, (first input sensitivity, second input sensitivity))@.
rad2g :: da
      -> db
      -> dc
      -> (forall s . AD s a da -> AD s b db -> AD s c dc)
      -> a -> b
      -> (c, (da, db))
rad2g zeroa zerob onec f x y = runST $ do
  xRef <- var x zeroa
  yRef <- var y zerob
  outRef <- evalContT . resetT $ do
    out <- unAD0 (f (AD0 (pure xRef)) (AD0 (pure yRef)))
    -- Seed the output sensitivity.
    lift (modifySTRef' out (withD (const onec)))
    pure out
  D z _ <- readSTRef outRef
  D _ xBar <- readSTRef xRef
  D _ yBar <- readSTRef yRef
  pure (z, (xBar, yBar))
-- | Evaluate a function of a 'Traversable' container of inputs and the
-- sensitivity of every input.
--
-- One variable with dual @zeroa@ is allocated per element. After the function
-- has produced its output variable, that variable's dual is overwritten with
-- the second argument; the delimited computation then finishes, and the
-- output's primal and every input's accumulated dual are read back.
--
-- Returns @(result, container of input sensitivities)@.
radNg :: Traversable t =>
         da
      -> db
      -> (forall s . t (AD s a da) -> AD s b db)
      -> t a
      -> (b, t da)
radNg zeroa oneb f xs = runST $ do
  xRefs <- traverse (\x -> var x zeroa) xs
  outRef <- evalContT . resetT $ do
    out <- unAD0 (f (pure <$> xRefs))
    -- Seed the output sensitivity.
    lift (modifySTRef' out (withD (const oneb)))
    pure out
  D y _ <- readSTRef outRef
  xsBar <- traverse (fmap dual . readSTRef) xRefs
  pure (y, xsBar)
-- | 'traverse' with its arguments swapped: container first, action second.
for :: (Applicative f, Traversable t) => t a -> (a -> f b) -> f (t b)
for xs action = traverse action xs
-- | Evaluate a unary function together with the sensitivity of its input,
-- for 'Num' types: 'rad1g' with input dual @0@ and output seed @1@.
rad1 :: (Num a, Num b) =>
        (forall s . AD' s a -> AD' s b)
     -> a
     -> (b, a)
rad1 f x = rad1g 0 1 f x
-- | Evaluate a binary function together with the sensitivities of both
-- inputs, for 'Num' types: 'rad2g' with input duals @0@ and output seed @1@.
rad2 :: (Num a, Num b, Num c) =>
        (forall s . AD' s a -> AD' s b -> AD' s c)
     -> a
     -> b
     -> (c, (a, b))
rad2 f x y = rad2g 0 0 1 f x y
-- | Evaluate a function of a 'Traversable' container of inputs together with
-- the sensitivity of every input, for 'Num' types: 'radNg' with input duals
-- @0@ and output seed @1@.
grad :: (Traversable t, Num a, Num b) =>
        (forall s . t (AD' s a) -> AD' s b)
     -> t a
     -> (b, t a)
grad f xs = radNg 0 1 f xs
-- | A record of the three operations on duals: producing a dual from a
-- primal ('zero'), a dual-to-dual map ('one'), and combining two duals
-- ('add').
data Backprop a da = Backprop
  { zero :: a -> da
  , one :: da -> da
  , add :: da -> da -> da
  }
-- | The 'Backprop' record for 'Num' duals, built from 'zeroNum', 'oneNum'
-- and 'addNum'.
bpNum :: (Num a, Num da) => Backprop a da
bpNum = Backprop
  { zero = zeroNum
  , one = oneNum
  , add = addNum
  }
-- | Ignore the argument and return @0@.
zeroNum :: Num da => a -> da
zeroNum = const 0
{-# INLINE zeroNum #-}
-- | Addition of two 'Num' values.
addNum :: Num da => da -> da -> da
addNum u v = u + v
{-# INLINE addNum #-}
-- | Ignore the argument and return @1@.
oneNum :: Num da => a -> da
oneNum = const 1
{-# INLINE oneNum #-}