module Numeric.AD.Reverse
(
grad
, grad2
, jacobian
, jacobian2
, diffUU
, diff2UU
, diffFU
, diff2FU
, diffUF
, diff2UF
, diff
, diff2
, AD(..)
, Mode(..)
) where
import Control.Applicative ((<$>))
import Data.Traversable (Traversable)
import Numeric.AD.Classes
import Numeric.AD.Internal
import Numeric.AD.Internal.Reverse
-- | Apply @f@ to the variables that 'bind' produces from @as@, then read the
-- partials of the result back into the shape of the input container using
-- 'partialArray' and 'unbind'.
--
-- NOTE(review): presumably this is the gradient of @f@ at @as@; the exact
-- semantics depend on the imported helpers -- confirm against their docs.
grad
  :: (Traversable f, Num a)
  => (forall s. Mode s => f (AD s a) -> AD s a)
  -> f a
  -> f a
grad f as = unbind inputs partials
  where
    -- 'bind' yields the input variables plus the value later handed to
    -- 'partialArray'.
    (inputs, bounds) = bind as
    partials = partialArray bounds (f inputs)
-- | Like 'grad', but pairs the 'primal' of the result of @f@ with the
-- partials read back through 'partialArray' and 'unbind'.
--
-- The function @f@ is applied once; its result is shared by both components
-- of the returned pair.
grad2
  :: (Traversable f, Num a)
  => (forall s. Mode s => f (AD s a) -> AD s a)
  -> f a
  -> (a, f a)
grad2 f as =
  let (inputs, bounds) = bind as
      out = f inputs
   in (primal out, unbind inputs (partialArray bounds out))
-- | Apply @f@ to the variables bound from @as@ and, for every element of the
-- resulting container, read its partials back into the shape of the input.
--
-- NOTE(review): presumably each element of the result is one row of the
-- Jacobian of @f@ at @as@ -- confirm against the imported helpers.
jacobian
  :: (Traversable f, Functor g, Num a)
  => (forall s. Mode s => f (AD s a) -> g (AD s a))
  -> f a
  -> g (f a)
jacobian f as = fmap partialsOf (f inputs)
  where
    (inputs, bounds) = bind as
    -- Partials of a single output with respect to every bound input.
    partialsOf out = unbind inputs (partialArray bounds out)
-- | Like 'jacobian', but every element of the result is a pair of the
-- 'primal' of that output and the partials read back for it.
jacobian2
  :: (Traversable f, Functor g, Num a)
  => (forall s. Mode s => f (AD s a) -> g (AD s a))
  -> f a
  -> g (a, f a)
jacobian2 f as =
  let (inputs, bounds) = bind as
   in fmap
        (\out -> (primal out, unbind inputs (partialArray bounds out)))
        (f inputs)
-- | Apply @f@ to the variable built by @var a 0@ and take the 'derivative'
-- of the result.
--
-- NOTE(review): presumably the derivative of a scalar-to-scalar function at
-- @a@; the meaning of 'derivative' is defined by the imported modules.
diffUU :: Num a => (forall s. Mode s => AD s a -> AD s a) -> a -> a
diffUU f a = (derivative . f) (var a 0)
-- | Apply @f@ to the variable built by @var a 0@ and map 'derivative' over
-- every element of the resulting container.
diffUF
  :: (Functor f, Num a)
  => (forall s. Mode s => AD s a -> f (AD s a))
  -> a
  -> f a
diffUF f a = fmap derivative (f (var a 0))
-- | Apply @f@ to the variable built by @var a 0@ and pass the result to
-- 'derivative2', which yields a pair.
--
-- NOTE(review): presumably the pair is (value, derivative) -- confirm
-- against the definition of 'derivative2'.
diff2UU :: Num a => (forall s. Mode s => AD s a -> AD s a) -> a -> (a, a)
diff2UU f a = derivative2 (f (var a 0))
-- | Apply @f@ to the variable built by @var a 0@ and map 'derivative2' over
-- every element of the resulting container, giving one pair per output.
diff2UF
  :: (Functor f, Num a)
  => (forall s. Mode s => AD s a -> f (AD s a))
  -> a
  -> f (a, a)
diff2UF f a = fmap derivative2 . f $ var a 0
-- | Apply @f@ to the variables bound from @as@ and read the partials of the
-- result back into the shape of the input container.
--
-- The computation is the same one 'grad' performs.
diffFU
  :: (Traversable f, Num a)
  => (forall s. Mode s => f (AD s a) -> AD s a)
  -> f a
  -> f a
diffFU f as =
  let (inputs, bounds) = bind as
   in unbind inputs (partialArray bounds (f inputs))
-- | Apply @f@ to the variables bound from @as@ and return the 'primal' of
-- the result together with its partials in the shape of the input container.
--
-- The computation is the same one 'grad2' performs.
diff2FU
  :: (Traversable f, Num a)
  => (forall s. Mode s => f (AD s a) -> AD s a)
  -> f a
  -> (a, f a)
diff2FU f as = (value, partials)
  where
    (inputs, bounds) = bind as
    -- Evaluated once and shared by both components of the pair.
    out = f inputs
    value = primal out
    partials = unbind inputs (partialArray bounds out)
-- | Alias for 'diffUU'.
diff :: Num a => (forall s. Mode s => AD s a -> AD s a) -> a -> a
diff f a = diffUU f a
-- | Alias for 'diff2UU'.
diff2 :: Num a => (forall s. Mode s => AD s a -> AD s a) -> a -> (a, a)
diff2 f a = diff2UU f a