{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
module Numeric.AD.DelCont.Internal
(rad1, rad2,
auto,
rad1g, rad2g,
op1Num, op2Num,
op1, op2,
AD, AD')
where
import Control.Monad.ST (ST, runST)
import Data.Bifunctor (Bifunctor(..))
import Data.STRef (STRef, newSTRef, readSTRef, modifySTRef')
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Cont (ContT, shiftT, resetT, evalContT)
import Prelude hiding (read)
-- | A dual number: a primal value paired with its sensitivity (adjoint).
--
-- Neither field carries a strictness annotation.
data D a da
  = D a da
  deriving (Show, Functor)
-- | Equality of dual numbers looks only at the primal part; the
-- sensitivity is ignored.
instance Eq a => Eq (D a da) where
  (==) (D lhs _) (D rhs _) = lhs == rhs
-- | Dual numbers are ordered by their primal part only; the sensitivity is
-- ignored.
instance Ord a => Ord (D a da) where
  compare (D p _) (D q _) = p `compare` q
-- | The first function maps the primal value, the second the sensitivity.
instance Bifunctor D where
  bimap onPrimal onDual (D p d) = D (onPrimal p) (onDual d)
-- | Transform the sensitivity of a dual number, keeping its primal value.
withD :: (da -> db) -> D a da -> D a db
withD g (D x dx) = D x (g dx)
-- | A differentiable variable: a mutable 'STRef' holding a dual number.
type DVar s a da = STRef s (D a da)

-- | Allocate a fresh 'DVar' holding the given primal value and sensitivity.
var :: a -> da -> ST s (DVar s a da)
var x = newSTRef . D x
-- | Lift a constant into 'AD'.
--
-- The constant's value takes part in the forward computation, but its
-- sensitivity has no meaning. @da@ carries no 'Num' constraint here, so
-- there is no zero available and a placeholder is stored instead.
--
-- The operators in this module only accumulate into (never read) the
-- sensitivity of their inputs, and 'D' has lazy fields, so they do not
-- force the placeholder. Should anything ever demand it, the error message
-- names this function instead of giving 'undefined''s generic message.
auto :: a -> AD s a da
auto x = AD $ lift $ var x noSensitivity
  where
    noSensitivity =
      error "Numeric.AD.DelCont.auto: the sensitivity of a constant was demanded"
-- | A reverse-mode AD computation: a continuation-passing computation in
-- 'ST' that yields a 'DVar' holding the primal value (@a@) and the
-- sensitivity (@da@) of a quantity.
--
-- The answer type of the continuation is quantified over any 'DVar', so
-- the same 'AD' value can be run whatever the type of the final result.
newtype AD s a da = AD
  { unAD :: forall x dx. ContT (DVar s x dx) (ST s) (DVar s a da)
  }

-- | 'AD' for the common case where primal and sensitivity share one type.
type AD' s a = AD s a a
-- | Continuation-level core of 'op1'.
--
-- 1. Run the input computation and read the input's primal value.
-- 2. Apply @f@ to obtain the output value and the pullback.
-- 3. Capture the rest of the computation @k@ with 'shiftT'. Allocate a
--    fresh output variable whose sensitivity starts at @zero@, and run @k@
--    on it.
-- 4. When @k@ returns, read the output's (possibly updated) sensitivity.
--    Push it through the pullback and add the result into the input's
--    sensitivity with @plusa@.
-- 5. Return @k@'s answer unchanged.
op1_ :: db                      -- ^ zero of the output sensitivity
     -> (da -> da -> da)        -- ^ addition on input sensitivities
     -> (a -> (b, db -> da))    -- ^ returns: (function result, pullback)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
op1_ zero plusa f ioa = do
  inRef <- ioa
  D xIn _ <- lift (readSTRef inRef)
  let (xOut, pullback) = f xIn
  shiftT $ \k -> lift $ do
    outRef <- var xOut zero
    answer <- k outRef
    D _ dOut <- readSTRef outRef
    modifySTRef' inRef (withD (`plusa` pullback dOut))
    pure answer
-- | Lift a unary function into 'AD'.
--
-- The function returns its result together with a pullback, which maps a
-- sensitivity of the output to a sensitivity of the input.
op1 :: db                       -- ^ zero
    -> (da -> da -> da)         -- ^ plus
    -> (a -> (b, db -> da))     -- ^ returns: (function result, pullback)
    -> AD s a da
    -> AD s b db
op1 zero plus f (AD ioa) = AD (op1_ zero plus f ioa)
-- | 'op1' specialised to 'Num' sensitivities: the zero is @0@ and
-- accumulation uses '(+)'.
op1Num :: (Num da, Num db)
       => (a -> (b, db -> da))  -- ^ returns: (function result, pullback)
       -> AD s a da
       -> AD s b db
op1Num f = op1 0 (+) f
-- | Continuation-level core of 'op2'.
--
-- Runs both input computations (first, then second) and reads their primal
-- values. Applies @f@ to get the output value and one pullback per input.
-- It then captures the rest of the computation with 'shiftT' and runs it on
-- a fresh output variable whose sensitivity starts at @zero@. Afterwards it
-- reads back the output's sensitivity and adds each pullback's
-- contribution into the matching input's sensitivity.
op2_ :: dc                                      -- ^ zero of the output sensitivity
     -> (da -> da -> da)                        -- ^ addition on first-input sensitivities
     -> (db -> db -> db)                        -- ^ addition on second-input sensitivities
     -> (a -> b -> (c, dc -> da, dc -> db))     -- ^ returns: (result, pullback a, pullback b)
     -> ContT x (ST s) (DVar s a da)
     -> ContT x (ST s) (DVar s b db)
     -> ContT x (ST s) (DVar s c dc)
op2_ zero plusa plusb f ioa iob = do
  refA <- ioa
  refB <- iob
  D xa _ <- lift (readSTRef refA)
  D xb _ <- lift (readSTRef refB)
  let (xc, pullA, pullB) = f xa xb
  shiftT $ \k -> lift $ do
    refC <- var xc zero
    answer <- k refC
    D _ dc <- readSTRef refC
    modifySTRef' refA (withD (`plusa` pullA dc))
    modifySTRef' refB (withD (`plusb` pullB dc))
    pure answer
-- | Lift a binary function into 'AD'.
--
-- The function returns its result together with one pullback per argument.
op2 :: dc                                       -- ^ zero
    -> (da -> da -> da)                         -- ^ plus on first-argument sensitivities
    -> (db -> db -> db)                         -- ^ plus on second-argument sensitivities
    -> (a -> b -> (c, dc -> da, dc -> db))      -- ^ returns: (result, pullback a, pullback b)
    -> AD s a da
    -> AD s b db
    -> AD s c dc
op2 zero plusa plusb f (AD ioa) (AD iob) = AD (op2_ zero plusa plusb f ioa iob)
-- | 'op2' specialised to 'Num' sensitivities: the zero is @0@ and both
-- inputs accumulate with '(+)'.
op2Num :: (Num da, Num db, Num dc)
       => (a -> b -> (c, dc -> da, dc -> db))   -- ^ returns: (result, pullback a, pullback b)
       -> AD s a da
       -> AD s b db
       -> AD s c dc
op2Num f = op2 0 (+) (+) f
-- | Reverse-mode arithmetic on 'AD' values whose primal and sensitivity
-- types coincide. Each method pairs the primal result with its pullbacks
-- (the incoming sensitivity times the local partial derivative).
instance (Num a) => Num (AD s a a) where
  (+) = op2Num $ \x y -> (x + y, id, id)
  (-) = op2Num $ \x y -> (x - y, id, negate)
  (*) = op2Num $ \x y -> (x * y, (y *), (x *))
  -- Defined explicitly. The class default @negate x = 0 - x@ allocates an
  -- extra constant node plus a binary node, and its primal can differ from
  -- the underlying 'negate' (for 'Double', @0 - 0.0@ is @0.0@ but
  -- @negate 0.0@ is @-0.0@).
  negate = op1Num $ \x -> (negate x, negate)
  fromInteger n = auto (fromInteger n)
  abs = op1Num $ \x -> (abs x, (* signum x))
  signum = op1Num $ \x -> (signum x, const 0)
-- | Reverse-mode division on 'AD' values whose primal and sensitivity types
-- coincide.
instance (Fractional a) => Fractional (AD s a a) where
  (/) = op2Num divide
    where
      -- d(x/y)/dx = 1/y ; d(x/y)/dy = -x/y^2
      divide x y = (x / y, \g -> g / y, \g -> negate (g * x / (y * y)))
  recip = op1Num reciprocal
    where
      -- d(1/x)/dx = -1/x^2
      reciprocal x = (recip x, \g -> negate g / (x * x))
  fromRational r = auto (fromRational r)
-- | Reverse-mode elementary functions on 'AD' values whose primal and
-- sensitivity types coincide. Each method pairs the primal result with its
-- pullback (the incoming sensitivity times the local derivative).
instance Floating a => Floating (AD s a a) where
  pi = auto pi
  exp = op1Num $ \x -> (exp x, (exp x *))
  log = op1Num $ \x -> (log x, (/ x))
  sqrt = op1Num $ \x -> (sqrt x, (/ (2 * sqrt x)))
  -- Defined explicitly. The class default @x ** y = exp (log x * y)@ is NaN
  -- for a negative base even when the underlying '(**)' is not (for
  -- 'Double', @(-3) ** 2@ is @9.0@). Using the underlying '(**)' keeps the
  -- primal value consistent with it.
  (**) = op2Num $ \x y ->
    ( x ** y
    , (* (y * x ** (y - 1)))    -- d/dx x**y = y * x**(y-1)
    , (* (x ** y * log x))      -- d/dy x**y = x**y * log x
    )
  logBase = op2Num $ \x y ->
    let dx = - logBase x y / (log x * x)
    in ( logBase x y
       , (* dx)
       , (/ (y * log x))
       )
  sin = op1Num $ \x -> (sin x, (* cos x))
  cos = op1Num $ \x -> (cos x, (* (- sin x)))
  tan = op1Num $ \x -> (tan x, (/ cos x ^ (2 :: Int)))
  asin = op1Num $ \x -> (asin x, (/ sqrt (1 - x * x)))
  acos = op1Num $ \x -> (acos x, (/ sqrt (1 - x * x)) . negate)
  atan = op1Num $ \x -> (atan x, (/ (x * x + 1)))
  sinh = op1Num $ \x -> (sinh x, (* cosh x))
  cosh = op1Num $ \x -> (cosh x, (* sinh x))
  tanh = op1Num $ \x -> (tanh x, (/ cosh x ^ (2 :: Int)))
  asinh = op1Num $ \x -> (asinh x, (/ sqrt (x * x + 1)))
  acosh = op1Num $ \x -> (acosh x, (/ sqrt (x * x - 1)))
  atanh = op1Num $ \x -> (atanh x, (/ (1 - x * x)))
-- | Evaluate a unary function and its sensitivity in reverse mode, for
-- arbitrary primal and sensitivity types.
--
-- The input variable starts with sensitivity @zero@. Inside the 'resetT'
-- scope, after @f@'s computation has produced its output variable, that
-- output's sensitivity is overwritten with @one@. The pullbacks captured by
-- the operators then run as their continuations return. Finally the
-- output's primal value and the input's accumulated sensitivity are read.
rad1g :: da                                 -- ^ zero of the input sensitivity
      -> db                                 -- ^ seed for the output sensitivity
      -> (forall s. AD s a da -> AD s b db) -- ^ function to differentiate
      -> a                                  -- ^ point of evaluation
      -> (b, da)                            -- ^ (function value, input sensitivity)
rad1g zero one f x = runST $ do
  inRef <- var x zero
  outRef <- evalContT $ resetT $ do
    ref <- unAD (f (AD (pure inRef)))
    lift $ modifySTRef' ref (withD (const one))
    pure ref
  D value _ <- readSTRef outRef
  D _ grad <- readSTRef inRef
  pure (value, grad)
-- | Evaluate a binary function and its two sensitivities in reverse mode,
-- for arbitrary primal and sensitivity types.
--
-- Both input variables start with their respective zero sensitivities. The
-- output's sensitivity is set to @one@ inside the 'resetT' scope once @f@'s
-- computation has produced it. The result is the output's primal value and
-- the accumulated sensitivities of both inputs.
rad2g :: da                                             -- ^ zero of the first input sensitivity
      -> db                                             -- ^ zero of the second input sensitivity
      -> dc                                             -- ^ seed for the output sensitivity
      -> (forall s. AD s a da -> AD s b db -> AD s c dc) -- ^ function to differentiate
      -> a                                              -- ^ first argument
      -> b                                              -- ^ second argument
      -> (c, (da, db))                                  -- ^ (function value, (sensitivity a, sensitivity b))
rad2g zeroa zerob one f x y = runST $ do
  refX <- var x zeroa
  refY <- var y zerob
  refZ <- evalContT $ resetT $ do
    out <- unAD (f (AD (pure refX)) (AD (pure refY)))
    lift $ modifySTRef' out (withD (const one))
    pure out
  D z _ <- readSTRef refZ
  D _ xBar <- readSTRef refX
  D _ yBar <- readSTRef refY
  pure (z, (xBar, yBar))
-- | Evaluate a unary function and its derivative at a point using
-- reverse-mode AD. The input sensitivity starts at @0@ and the output is
-- seeded with @1@.
rad1 :: (Num a, Num b)
     => (forall s. AD' s a -> AD' s b)  -- ^ function to differentiate
     -> a                               -- ^ point of evaluation
     -> (b, a)                          -- ^ (function value, derivative)
rad1 f = rad1g 0 1 f
-- | Evaluate a binary function and its partial derivatives at a point using
-- reverse-mode AD. Both input sensitivities start at @0@ and the output is
-- seeded with @1@.
rad2 :: (Num a, Num b, Num c)
     => (forall s. AD' s a -> AD' s b -> AD' s c)   -- ^ function to differentiate
     -> a                                         -- ^ first argument
     -> b                                         -- ^ second argument
     -> (c, (a, b))                               -- ^ (function value, (d/da, d/db))
rad2 f = rad2g 0 0 1 f