{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE BangPatterns #-}
-----------------------------------------------------------------------------
-- |
-- Copyright   : (c) Edward Kmett 2010-2021
-- License     : BSD3
-- Maintainer  : ekmett@gmail.com
-- Stability   : experimental
-- Portability : GHC only
--
-- Higher order derivatives via a \"dual number tower\".
--
-----------------------------------------------------------------------------

module Numeric.AD.Rank1.Tower
  ( Tower
  , auto
  -- * Taylor Series
  , taylor
  , taylor0
  -- * Maclaurin Series
  , maclaurin
  , maclaurin0
  -- * Derivatives
  , diff    -- first derivative of (a -> a)
  , diff'   -- answer and first derivative of (a -> a)
  , diffs   -- answer and all derivatives of (a -> a)
  , diffs0  -- zero padded derivatives of (a -> a)
  , diffsF  -- answer and all derivatives of (a -> f a)
  , diffs0F -- zero padded derivatives of (a -> f a)
  -- * Directional Derivatives
  , du      -- directional derivative of (f a -> a)
  , du'     -- answer and directional derivative of (f a -> a)
  , dus     -- answer and all directional derivatives of (f a -> a)
  , dus0    -- answer and all zero padded directional derivatives of (f a -> a)
  , duF     -- directional derivative of (f a -> g a)
  , duF'    -- answer and directional derivative of (f a -> g a)
  , dusF    -- answer and all directional derivatives of (f a -> g a)
  , dus0F   -- answer and all zero padded directional derivatives of (f a -> g a)
  ) where

import Numeric.AD.Internal.Tower
import Numeric.AD.Mode

-- | Compute the answer and all derivatives of a function @(a -> a)@.
--
-- @diffs f a@ evaluates @f@ at @a@ in 'Tower' mode (via 'apply') and
-- unpacks the resulting tower with 'getADTower': the answer comes first,
-- followed by successive derivatives. The list may end early; see 'diffs0'
-- for the zero-padded variant.
diffs
  :: Num a
  => (Tower a -> Tower a)
  -> a
  -> [a]
diffs f a = getADTower (apply f a)
{-# INLINE diffs #-}

-- | Compute the zero-padded derivatives of a function @(a -> a)@.
--
-- Identical to 'diffs', except the result is passed through 'zeroPad'.
diffs0
  :: Num a
  => (Tower a -> Tower a)
  -> a
  -> [a]
diffs0 f a = zeroPad (diffs f a)
{-# INLINE diffs0 #-}

-- | Compute the answer and all derivatives of a function @(a -> f a)@.
--
-- Like 'diffs', but @f@ returns a 'Functor' of towers; each tower in the
-- result is unpacked independently with 'getADTower'.
diffsF
  :: (Functor f, Num a)
  => (Tower a -> f (Tower a))
  -> a
  -> f [a]
diffsF f a = getADTower <$> apply f a
{-# INLINE diffsF #-}

-- | Compute the zero-padded derivatives of a function @(a -> f a)@.
--
-- Like 'diffsF', but every unpacked tower is passed through 'zeroPad'.
diffs0F
  :: (Functor f, Num a)
  => (Tower a -> f (Tower a))
  -> a
  -> f [a]
diffs0F f a = (zeroPad . getADTower) <$> apply f a
{-# INLINE diffs0F #-}

-- | @taylor f x dx@ computes the terms of the Taylor series of @f@ around
-- @x@, evaluated at the offset @dx@.
--
-- The @k@-th element of the result is the @k@-th element of @'diffs' f x@
-- multiplied by @dx^k / k!@. The result has exactly as many elements as
-- @'diffs' f x@; see 'taylor0' for a zero-padded variant.
taylor
  :: Fractional a
  => (Tower a -> Tower a)
  -> a
  -> a
  -> [a]
taylor f x dx = go 1 1 (diffs f x)
  where
    -- n   : divisor used to advance to the next coefficient (k + 1 for term k)
    -- acc : the coefficient dx^k / k! for the current term k
    -- Bang patterns keep the running coefficient from building a thunk chain.
    go !n !acc (a:as) = a * acc : go (n + 1) (acc * dx / n) as
    go _ _ [] = []

-- | @taylor0 f x dx@ computes the Taylor series of @f@ around @x@ at
-- offset @dx@, zero-padded.
--
-- Identical to 'taylor', except the result is passed through 'zeroPad'.
taylor0
  :: Fractional a
  => (Tower a -> Tower a)
  -> a
  -> a
  -> [a]
taylor0 f x dx = zeroPad (taylor f x dx)
{-# INLINE taylor0 #-}

-- | @maclaurin f dx@ computes the Maclaurin series of @f@ at offset @dx@.
--
-- This is 'taylor' expanded around @0@.
maclaurin
  :: Fractional a
  => (Tower a -> Tower a)
  -> a
  -> [a]
maclaurin f = taylor f 0
{-# INLINE maclaurin #-}

-- | @maclaurin0 f dx@ computes the Maclaurin series of @f@ at offset @dx@,
-- zero-padded.
--
-- This is 'taylor0' expanded around @0@.
maclaurin0
  :: Fractional a
  => (Tower a -> Tower a)
  -> a
  -> [a]
maclaurin0 f = taylor0 f 0
{-# INLINE maclaurin0 #-}

-- | Compute the first derivative of a function @(a -> a)@.
--
-- Extracts the derivative from the list produced by 'diffs' using 'd'.
diff
  :: Num a
  => (Tower a -> Tower a)
  -> a
  -> a
diff f = d . diffs f
{-# INLINE diff #-}

-- | Compute the answer and first derivative of a function @(a -> a)@.
--
-- Extracts the @(answer, derivative)@ pair from the list produced by
-- 'diffs' using 'd''.
diff'
  :: Num a
  => (Tower a -> Tower a)
  -> a
  -> (a, a)
diff' f = d' . diffs f
{-# INLINE diff' #-}

-- | Compute a directional derivative of a function @(f a -> a)@.
--
-- Each input pair is turned into a 'Tower' with 'withD' (presumably
-- @(point, direction component)@ — confirm against 'withD'). The function
-- is applied, and 'd' extracts the derivative from the result tower.
du
  :: (Functor f, Num a)
  => (f (Tower a) -> Tower a)
  -> f (a, a)
  -> a
du f = d . getADTower . f . fmap withD
{-# INLINE du #-}

-- | Compute the answer and a directional derivative of a function
-- @(f a -> a)@.
--
-- Like 'du', but uses 'd'' to return the @(answer, derivative)@ pair.
du'
  :: (Functor f, Num a)
  => (f (Tower a) -> Tower a)
  -> f (a, a)
  -> (a, a)
du' f = d' . getADTower . f . fmap withD
{-# INLINE du' #-}

-- | Compute a directional derivative of a function @(f a -> g a)@.
--
-- Like 'du', but the function returns a 'Functor' @g@ of towers, and the
-- derivative is extracted from each one.
duF
  :: (Functor f, Functor g, Num a)
  => (f (Tower a) -> g (Tower a))
  -> f (a, a)
  -> g a
duF f = fmap (d . getADTower) . f . fmap withD
{-# INLINE duF #-}

-- | Compute the answer and a directional derivative of a function
-- @(f a -> g a)@.
--
-- Like 'duF', but each output tower yields an @(answer, derivative)@ pair
-- via 'd''.
duF'
  :: (Functor f, Functor g, Num a)
  => (f (Tower a) -> g (Tower a))
  -> f (a, a)
  -> g (a, a)
duF' f = fmap (d' . getADTower) . f . fmap withD
{-# INLINE duF' #-}

-- | Given a function @(f a -> a)@ and a tower of derivatives for each
-- input, compute the corresponding directional derivatives.
--
-- Each input list is converted to a 'Tower' with 'tower'. The function is
-- applied, and the resulting tower is unpacked with 'getADTower'.
dus
  :: (Functor f, Num a)
  => (f (Tower a) -> Tower a)
  -> f [a]
  -> [a]
dus f = getADTower . f . fmap tower
{-# INLINE dus #-}

-- | Given a function @(f a -> a)@ and a tower of derivatives for each
-- input, compute the corresponding directional derivatives, zero-padded.
--
-- Identical to 'dus', except the result is passed through 'zeroPad'.
dus0
  :: (Functor f, Num a)
  => (f (Tower a) -> Tower a)
  -> f [a]
  -> [a]
dus0 f = zeroPad . getADTower . f . fmap tower
{-# INLINE dus0 #-}

-- | Given a function @(f a -> g a)@ and a tower of derivatives for each
-- input, compute the corresponding directional derivatives.
--
-- Like 'dus', but the function returns a 'Functor' @g@ of towers, and each
-- one is unpacked with 'getADTower'.
dusF
  :: (Functor f, Functor g, Num a)
  => (f (Tower a) -> g (Tower a))
  -> f [a]
  -> g [a]
dusF f = fmap getADTower . f . fmap tower
{-# INLINE dusF #-}

-- | Given a function @(f a -> g a)@ and a tower of derivatives for each
-- input, compute the corresponding directional derivatives, zero-padded.
--
-- Like 'dusF', but every output tower is passed through 'zeroPad'. (The
-- previous body omitted 'zeroPad' and so behaved exactly like 'dusF',
-- contradicting this function's documented contract and the behaviour of
-- its siblings 'dus0' and 'diffs0F'.)
dus0F
  :: (Functor f, Functor g, Num a)
  => (f (Tower a) -> g (Tower a))
  -> f [a]
  -> g [a]
dus0F f = fmap (zeroPad . getADTower) . f . fmap tower
{-# INLINE dus0F #-}