Copyright | (c) Edward Kmett 2010-2015 |
---|---|

License | BSD3 |

Maintainer | ekmett@gmail.com |

Stability | experimental |

Portability | GHC only |

Safe Haskell | None |

Language | Haskell2010 |

This module provides reverse-mode Automatic Differentiation using post-hoc linear time topological sorting.

For reverse mode AD we use `StableName` to recover sharing information from
the tape to avoid combinatorial explosion, and thus run asymptotically faster
than it could without such sharing information, but the use of side-effects
contained herein is benign.

- data AD s a
- data Kahn a
- auto :: Mode t => Scalar t -> t
- grad :: (Traversable f, Num a) => (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> f a
- grad' :: (Traversable f, Num a) => (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> (a, f a)
- gradWith :: (Traversable f, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> f b
- gradWith' :: (Traversable f, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> (a, f b)
- jacobian :: (Traversable f, Functor g, Num a) => (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (f a)
- jacobian' :: (Traversable f, Functor g, Num a) => (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (a, f a)
- jacobianWith :: (Traversable f, Functor g, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (f b)
- jacobianWith' :: (Traversable f, Functor g, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (a, f b)
- hessian :: (Traversable f, Num a) => (forall s. f (AD s (On (Kahn (Kahn a)))) -> AD s (On (Kahn (Kahn a)))) -> f a -> f (f a)
- hessianF :: (Traversable f, Functor g, Num a) => (forall s. f (AD s (On (Kahn (Kahn a)))) -> g (AD s (On (Kahn (Kahn a))))) -> f a -> g (f (f a))
- diff :: Num a => (forall s. AD s (Kahn a) -> AD s (Kahn a)) -> a -> a
- diff' :: Num a => (forall s. AD s (Kahn a) -> AD s (Kahn a)) -> a -> (a, a)
- diffF :: (Functor f, Num a) => (forall s. AD s (Kahn a) -> f (AD s (Kahn a))) -> a -> f a
- diffF' :: (Functor f, Num a) => (forall s. AD s (Kahn a) -> f (AD s (Kahn a))) -> a -> f (a, a)

# Documentation

Bounded a => Bounded (AD s a) Source | |

Enum a => Enum (AD s a) Source | |

Eq a => Eq (AD s a) Source | |

Floating a => Floating (AD s a) Source | |

Fractional a => Fractional (AD s a) Source | |

Num a => Num (AD s a) Source | |

Ord a => Ord (AD s a) Source | |

Read a => Read (AD s a) Source | |

Real a => Real (AD s a) Source | |

RealFloat a => RealFloat (AD s a) Source | |

RealFrac a => RealFrac (AD s a) Source | |

Show a => Show (AD s a) Source | |

Erf a => Erf (AD s a) Source | |

InvErf a => InvErf (AD s a) Source | |

Mode a => Mode (AD s a) Source | |

type Scalar (AD s a) = Scalar a Source |

`Kahn` is a `Mode` using reverse-mode automatic differentiation that provides fast `diffFU`, `diff2FU`, `grad`, `grad2` and a fast `jacobian` when you have a significantly smaller number of outputs than inputs.

(Num a, Bounded a) => Bounded (Kahn a) | |

(Num a, Enum a) => Enum (Kahn a) | |

(Num a, Eq a) => Eq (Kahn a) | |

Floating a => Floating (Kahn a) | |

Fractional a => Fractional (Kahn a) | |

Num a => Num (Kahn a) | |

(Num a, Ord a) => Ord (Kahn a) | |

Real a => Real (Kahn a) | |

RealFloat a => RealFloat (Kahn a) | |

RealFrac a => RealFrac (Kahn a) | |

Show a => Show (Kahn a) Source | |

MuRef (Kahn a) Source | |

Erf a => Erf (Kahn a) | |

InvErf a => InvErf (Kahn a) | |

Num a => Mode (Kahn a) Source | |

Num a => Jacobian (Kahn a) Source | |

Num a => Grad (Kahn a) [a] (a, [a]) a Source | |

Grad i o o' a => Grad (Kahn a -> i) (a -> o) (a -> o') a Source | |

type DeRef (Kahn a) = Tape a Source | |

type Scalar (Kahn a) = a Source | |

type D (Kahn a) = Id a Source |

# Gradient

grad :: (Traversable f, Num a) => (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> f a Source

The `grad` function calculates the gradient of a non-scalar-to-scalar function with kahn-mode AD in a single pass.

```
>>> grad (\[x,y,z] -> x*y+z) [1,2,3]
[2,1,1]
```

grad' :: (Traversable f, Num a) => (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> (a, f a) Source

The `grad'` function calculates the result and gradient of a non-scalar-to-scalar function with kahn-mode AD in a single pass.

```
>>> grad' (\[x,y,z] -> 4*x*exp y+cos z) [1,2,3]
(28.566231899122155,[29.5562243957226,29.5562243957226,-0.1411200080598672])
```

gradWith :: (Traversable f, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> f b Source

gradWith' :: (Traversable f, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> AD s (Kahn a)) -> f a -> (a, f b) Source

# Jacobian

jacobian :: (Traversable f, Functor g, Num a) => (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (f a) Source

The `jacobian` function calculates the Jacobian of a non-scalar-to-non-scalar function with kahn AD lazily in `m` passes for `m` outputs.

```
>>> jacobian (\[x,y] -> [y,x,x*y]) [2,1]
[[0,1],[1,0],[1,2]]
```

```
>>> jacobian (\[x,y] -> [exp y,cos x,x+y]) [1,2]
[[0.0,7.38905609893065],[-0.8414709848078965,0.0],[1.0,1.0]]
```

jacobian' :: (Traversable f, Functor g, Num a) => (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (a, f a) Source

The `jacobian'` function calculates both the result and the Jacobian of a nonscalar-to-nonscalar function, using `m` invocations of kahn AD,
where `m` is the output dimensionality. Applying `fmap snd` to the result will recover the result of `jacobian`.

An alias for `gradF'`.

```
ghci> jacobian' (\[x,y] -> [y,x,x*y]) [2,1]
[(1,[0,1]),(2,[1,0]),(2,[1,2])]
```

jacobianWith :: (Traversable f, Functor g, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (f b) Source

`jacobianWith g f` calculates the Jacobian of a non-scalar-to-non-scalar function `f` with kahn AD lazily in `m` passes for `m` outputs.

Instead of returning the Jacobian matrix, the elements of the matrix are combined with the input using the function `g`.

```
jacobian = jacobianWith (\_ dx -> dx)
jacobianWith const = (\f x -> const x <$> f x)
```

jacobianWith' :: (Traversable f, Functor g, Num a) => (a -> a -> b) -> (forall s. f (AD s (Kahn a)) -> g (AD s (Kahn a))) -> f a -> g (a, f b) Source

`jacobianWith' g f` calculates both the result and the Jacobian of a nonscalar-to-nonscalar function `f`, using `m` invocations of kahn AD,
where `m` is the output dimensionality. Applying `fmap snd` to the result will recover the result of `jacobianWith`.

Instead of returning the Jacobian matrix, the elements of the matrix are combined with the input using the function `g`.

```
jacobian' == jacobianWith' (\_ dx -> dx)
```

# Hessian

hessian :: (Traversable f, Num a) => (forall s. f (AD s (On (Kahn (Kahn a)))) -> AD s (On (Kahn (Kahn a)))) -> f a -> f (f a) Source

Compute the `hessian` via the `jacobian` of the gradient. The gradient is computed in `Kahn` mode and then the `jacobian` is computed in `Kahn` mode.

However, since `grad f :: f a -> f a` is square, this is not as fast as using the forward-mode `jacobian` of a reverse-mode gradient provided by `Numeric.AD.hessian`.

```
>>> hessian (\[x,y] -> x*y) [1,2]
[[0,1],[1,0]]
```

hessianF :: (Traversable f, Functor g, Num a) => (forall s. f (AD s (On (Kahn (Kahn a)))) -> g (AD s (On (Kahn (Kahn a))))) -> f a -> g (f (f a)) Source

Compute the order 3 Hessian tensor on a non-scalar-to-non-scalar function via the `Kahn`-mode Jacobian of the `Kahn`-mode Jacobian of the function.

Less efficient than `hessianF`.

```
>>> hessianF (\[x,y] -> [x*y,x+y,exp x*cos y]) [1,2]
[[[0.0,1.0],[1.0,0.0]],[[0.0,0.0],[0.0,0.0]],[[-1.1312043837568135,-2.4717266720048188],[-2.4717266720048188,1.1312043837568135]]]
```

# Derivatives

diff :: Num a => (forall s. AD s (Kahn a) -> AD s (Kahn a)) -> a -> a Source

Compute the derivative of a function.

```
>>> diff sin 0
1.0
```

```
>>> cos 0
1.0
```

diff' :: Num a => (forall s. AD s (Kahn a) -> AD s (Kahn a)) -> a -> (a, a) Source

The `diff'` function calculates the value and derivative, as a
pair, of a scalar-to-scalar function.

```
>>> diff' sin 0
(0.0,1.0)
```