ad-delcont-0.1.0.0: Reverse-mode automatic differentiation with delimited continuations
Safe Haskell: Safe-Inferred
Language: Haskell2010

Numeric.AD.DelCont

Description

Reverse-mode automatic differentiation using delimited continuations.

Quickstart

Most users will only need to import rad1 and rad2, and to rely on the Num, Fractional and Floating instances of the AD type.

Similarly to ad, a user supplies a polymorphic function to be differentiated, e.g.

f :: Num a => a -> a
f x = x + (x * x)

and the library takes care of the rest:

>>> rad1 f 1.2
(2.6399999999999997,3.4000000000000004)
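The Floating instance is used in the same way. The sketch below assumes ad-delcont-0.1.0.0 is installed; the names g, g' and gAt2 are hypothetical. It differentiates x * sin x and keeps a hand-written derivative, sin x + x cos x, for cross-checking the reverse-mode result.

```haskell
import Numeric.AD.DelCont (rad1)

-- Polymorphic in the number type: the Floating constraint lets rad1
-- instantiate g at AD' s Double (hypothetical example function).
g :: Floating a => a -> a
g x = x * sin x

-- Derivative worked out by hand, used only to cross-check rad1.
g' :: Double -> Double
g' x = sin x + x * cos x

-- (value, derivative) of g at x = 2, computed in reverse mode.
gAt2 :: (Double, Double)
gAt2 = rad1 g 2
```

Because both results are floating-point, the two derivatives are best compared up to a small tolerance rather than with (==).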

Note that the library cannot differentiate functions of concrete types, e.g. Double -> Double: the function must be polymorphic in its numeric type, so that it can be instantiated at AD. On the other hand, it's easy to experiment with other numerical interfaces that support one, zero and plus.
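To make the distinction concrete, here is a minimal sketch (assuming ad-delcont-0.1.0.0 is installed; cube, cubeD and cubeAt2 are hypothetical names) with the same formula written twice: once polymorphically, which rad1 accepts, and once at the concrete type Double, which it does not.

```haskell
import Numeric.AD.DelCont (rad1)

-- Polymorphic in the number type: rad1 can run it on AD values.
cube :: Num a => a -> a
cube x = x * x * x

-- Same formula fixed to Double: `rad1 cubeD 2` is rejected by the type
-- checker, because rad1 expects a function on AD' s a, not on Double.
cubeD :: Double -> Double
cubeD x = x * x * x

-- (value, derivative) of the polymorphic version at x = 2.
cubeAt2 :: (Double, Double)
cubeAt2 = rad1 cube 2
```

cubeAt2 should evaluate to (8.0,12.0), i.e. 2^3 and 3 * 2^2; cubeD remains usable for plain evaluation only.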

Advanced usage

The library is small and easily extensible.

For example, a user might want to supply their own numerical typeclass other than Num and build up a library of AD combinators on top of it, by specializing op1 and op2 (and rad1g, rad2g) with custom implementations of zero, one and plus. This insight first appeared in the user interface of backprop, as the Backprop typeclass.
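A minimal sketch of this approach, assuming ad-delcont-0.1.0.0 is installed. The class Additive and the helpers zeroA, plusA, lift1, lift2, grad, mulA, sinA and xSinX are all hypothetical and not part of the library; only op1, op2 and rad1g come from it.

```haskell
{-# LANGUAGE RankNTypes #-}
import Numeric.AD.DelCont (AD, op1, op2, rad1g)

-- Hypothetical user-defined class supplying the zero and plus that op1/op2 need.
class Additive t where
  zeroA :: t
  plusA :: t -> t -> t

instance Additive Double where
  zeroA = 0
  plusA = (+)

-- op1/op2 specialised to the class: zero of the output adjoint,
-- plus of the input adjoint(s).
lift1 :: (Additive da, Additive db) => (a -> (b, db -> da)) -> AD s a da -> AD s b db
lift1 = op1 zeroA plusA

lift2 :: (Additive da, Additive db, Additive dc)
      => (a -> b -> (c, dc -> da, dc -> db)) -> AD s a da -> AD s b db -> AD s c dc
lift2 = op2 zeroA plusA plusA

-- rad1g specialised to the class; the output seed ("one") is passed explicitly.
grad :: Additive da => db -> (forall s. AD s a da -> AD s b db) -> a -> (b, da)
grad one f x = rad1g zeroA one f x

-- Two combinators built on top of lift1 / lift2, with hand-written pullbacks.
mulA :: AD s Double Double -> AD s Double Double -> AD s Double Double
mulA = lift2 (\x y -> (x * y, \d -> d * y, \d -> d * x))

sinA :: AD s Double Double -> AD s Double Double
sinA = lift1 (\x -> (sin x, \d -> d * cos x))

-- x * sin x and its derivative at x = 2, without using the Num instance.
xSinX :: (Double, Double)
xSinX = grad 1 (\x -> mulA x (sinA x)) 2
```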

Exposing unconstrained AD combinators lets users specialize this library to e.g. exotic number-like types or discrete data structures such as dictionaries, automata etc.
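As an illustration of the dictionary case, the sketch below (assuming ad-delcont-0.1.0.0 and containers are installed; Params, sumSquares and example are hypothetical names) differentiates the sum of squares of the values stored in a Data.Map. Map.empty plays the role of the zero adjoint, and Map.unionWith (+) merges adjoint contributions, so no Num instance for maps is needed.

```haskell
import qualified Data.Map.Strict as Map
import Numeric.AD.DelCont (AD, op1, rad1g)

-- Parameters live in a dictionary; their adjoints use the same representation.
type Params = Map.Map String Double

-- Sum of squares of all values. The pullback returns one partial derivative
-- per key; contributions are merged with Map.unionWith (+).
sumSquares :: AD s Params Params -> AD s Double Double
sumSquares = op1 0 (Map.unionWith (+)) $ \m ->
  (sum (Map.map (\v -> v * v) m), \d -> Map.map (\v -> 2 * v * d) m)

-- Map.empty is the zero adjoint of the input; 1 seeds the scalar output.
example :: (Double, Params)
example = rad1g Map.empty 1 sumSquares (Map.fromList [("x", 3), ("y", 4)])
```

Here example should be (25.0, fromList [("x",6.0),("y",8.0)]), i.e. one partial derivative 2v per key.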

Implementation details and design choices

This is the first (known) Haskell implementation of the ideas presented in Wang et al. Here the role of variable mutation and delimited continuations is made explicit by the use of ST and ContT, as compared to the reference Scala implementation.

ad-delcont relies on a non-standard interpretation of the user-provided function. To compute the adjoint values (the sensitivities) of the function parameters, the function is first evaluated ("forwards") while keeping track of continuation points; all the intermediate adjoints are then accumulated upon returning from the respective continuations ("backwards"), via safe mutation in the ST monad.
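The following self-contained sketch illustrates the technique for Double only, using ST and ContT from base and transformers. It is a simplified illustration, not the library's internals: Node, Diff, binOp, add, mul, grad and example are hypothetical names. Each operation computes its value, hands the new node to the captured continuation (the rest of the forward pass and all of its backward pass), and only after that continuation returns pushes its adjoint to its inputs.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Cont (ContT, evalContT, resetT, shiftT)
import Data.STRef (STRef, modifySTRef', newSTRef, readSTRef, writeSTRef)

-- A node of the computation: its primal value and a mutable adjoint.
data Node s = Node { primal :: Double, adjoint :: STRef s Double }

-- Differentiable computations capture "the rest of the program" as a continuation.
type Diff s = ContT () (ST s)

newNode :: Double -> ST s (Node s)
newNode v = Node v <$> newSTRef 0

-- Generic binary operation. f returns the value and the two local partials.
binOp :: (Double -> Double -> (Double, Double, Double))
      -> Node s -> Node s -> Diff s (Node s)
binOp f a b = shiftT $ \k -> lift $ do
  let (v, da, db) = f (primal a) (primal b)
  c <- newNode v
  k c                                   -- forwards, then backwards, for everything after c
  dc <- readSTRef (adjoint c)           -- c's adjoint is now complete
  modifySTRef' (adjoint a) (+ dc * da)  -- accumulate into the inputs
  modifySTRef' (adjoint b) (+ dc * db)

add, mul :: Node s -> Node s -> Diff s (Node s)
add = binOp (\x y -> (x + y, 1, 1))
mul = binOp (\x y -> (x * y, y, x))

-- Evaluate forwards inside a reset, seed the output adjoint with 1, and read
-- the input adjoint once every continuation has returned.
grad :: (forall s. Node s -> Diff s (Node s)) -> Double -> (Double, Double)
grad f x = runST $ do
  xn <- newNode x
  out <- newSTRef 0
  evalContT $ resetT $ do
    y <- f xn
    lift $ writeSTRef out (primal y) >> writeSTRef (adjoint y) 1
  (,) <$> readSTRef out <*> readSTRef (adjoint xn)

-- f(x) = x * x + x
example :: Node s -> Diff s (Node s)
example x = do
  sq <- mul x x
  add sq x
```

With this sketch, grad example 3 should give (12.0,7.0), since f(3) = 12 and f'(3) = 2 * 3 + 1.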

As a result of this design, the main AD type cannot be given Eq and Ord instances (since it's unclear how equality and ordering predicates would apply to continuations and state threads).

The user interface is inspired by those of ad and backprop, but the internals are completely different: this library does not reify the function to be differentiated into a "tape" data structure.

Another point in common with backprop is that users can differentiate heterogeneous functions: the input and output types can be different. This makes it possible to differentiate functions of statically-typed vectors and matrices.
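A small heterogeneous example, assuming ad-delcont-0.1.0.0 is installed (plusPair, prodPair and prodAt are hypothetical names): the input is a pair of Doubles, the output is a single Double, and the adjoint of the input is again a pair.

```haskell
import Numeric.AD.DelCont (AD, op1, rad1g)

-- Componentwise addition of pair adjoints.
plusPair :: (Double, Double) -> (Double, Double) -> (Double, Double)
plusPair (a, b) (c, d) = (a + c, b + d)

-- Product of the two components: input type (Double, Double), output type Double.
-- The pullback maps the scalar output adjoint d to a pair of input adjoints.
prodPair :: AD s (Double, Double) (Double, Double) -> AD s Double Double
prodPair = op1 0 plusPair $ \(x, y) -> (x * y, \d -> (d * y, d * x))

-- (0, 0) is the zero adjoint of the pair input; 1 seeds the scalar output.
prodAt :: (Double, Double) -> (Double, (Double, Double))
prodAt = rad1g (0, 0) 1 prodPair
```

For instance, prodAt (3, 4) should return (12.0,(4.0,3.0)).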

References

Synopsis
  • rad1 :: (Num a, Num b) => (forall s. AD' s a -> AD' s b) -> a -> (b, a)
  • rad2 :: (Num a, Num b, Num c) => (forall s. AD' s a -> AD' s b -> AD' s c) -> a -> b -> (c, (a, b))
  • auto :: a -> da -> AD s a da
  • rad1g :: da -> db -> (forall s. AD s a da -> AD s b db) -> a -> (b, da)
  • rad2g :: da -> db -> dc -> (forall s. AD s a da -> AD s b db -> AD s c dc) -> a -> b -> (c, (da, db))
  • op1 :: db -> (da -> da -> da) -> (a -> (b, db -> da)) -> AD s a da -> AD s b db
  • op2 :: dc -> (da -> da -> da) -> (db -> db -> db) -> (a -> b -> (c, dc -> da, dc -> db)) -> AD s a da -> AD s b db -> AD s c dc
  • data AD s a da
  • type AD' s a = AD s a a

Quickstart

rad1

Arguments

:: (Num a, Num b) 
=> (forall s. AD' s a -> AD' s b)

function to be differentiated

-> a

function argument

-> (b, a)

(result, adjoint)

Evaluate (forward mode) and differentiate (reverse mode) a unary function

>>> rad1 (\x -> x * x) 1
(1,2)

rad2

Arguments

:: (Num a, Num b, Num c) 
=> (forall s. AD' s a -> AD' s b -> AD' s c)

function to be differentiated

-> a 
-> b 
-> (c, (a, b))

(result, adjoints)

Evaluate (forward mode) and differentiate (reverse mode) a binary function

>>> rad2 (\x y -> x + y + y) 1 1
(3,(1,2))
>>> rad2 (\x y -> (x + y) * x) 3 2
(15,(8,3))

auto

Arguments

:: a

primal

-> da

adjoint (in most cases this can be set to (0 :: a))

-> AD s a da 

Lift a constant into AD
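A typical use of auto is a parameter that is only known at runtime. The sketch below assumes ad-delcont-0.1.0.0 is installed; scaledSquare is a hypothetical name. It differentiates x -> c * x * x with respect to x, lifting the Double c into AD with a zero adjoint.

```haskell
import Numeric.AD.DelCont (auto, rad1)

-- Differentiate x -> c * x * x for a runtime constant c. auto lifts c into
-- AD with adjoint 0; c's adjoint is accumulated but never returned.
scaledSquare :: Double -> Double -> (Double, Double)
scaledSquare c = rad1 (\x -> auto c 0 * x * x)
```

For example, scaledSquare 5 3 should give (45.0,30.0), i.e. 5 * 3^2 and 2 * 5 * 3.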

Advanced usage

rad1g

Arguments

:: da

zero

-> db

one

-> (forall s. AD s a da -> AD s b db) 
-> a

function argument

-> (b, da)

(result, adjoint)

Evaluate (forward mode) and differentiate (reverse mode) a unary function, without committing to a specific numeric typeclass
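A minimal sketch, assuming ad-delcont-0.1.0.0 is installed (cube, viaRad1g, viaRad1 and scaled are hypothetical names). With zero = 0 and one = 1, rad1g should agree with rad1. Since the one argument seeds the output adjoint, and pullbacks are linear in that seed, passing another value should scale the returned adjoint; this is an assumption based on how reverse mode propagates adjoints.

```haskell
import Numeric.AD.DelCont (rad1, rad1g)

cube :: Num a => a -> a
cube x = x * x * x

-- Same computation through the general and the Num-specialised entry points.
viaRad1g, viaRad1 :: (Double, Double)
viaRad1g = rad1g 0 1 cube 2
viaRad1  = rad1 cube 2

-- Seeding the output adjoint with 10 instead of 1 scales the adjoint by 10.
scaled :: (Double, Double)
scaled = rad1g 0 10 cube 2
```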

rad2g

Arguments

:: da

zero

-> db

zero

-> dc

one

-> (forall s. AD s a da -> AD s b db -> AD s c dc) 
-> a 
-> b 
-> (c, (da, db))

(result, adjoints)

Evaluate (forward mode) and differentiate (reverse mode) a binary function, without committing to a specific numeric typeclass

Lift functions into AD

op1

Arguments

:: db

zero

-> (da -> da -> da)

plus

-> (a -> (b, db -> da))

returns: (function result, pullback)

-> AD s a da 
-> AD s b db 

Lift a unary function

The first two arguments (the zero of the output adjoint type db and the addition on the input adjoint type da) constrain the types of the adjoint values of the output and input variables respectively; see op1Num for an example.

The third argument is the most interesting: it specifies at once how to compute the function value and how to compute the sensitivity with respect to the function parameter.

Note: the type parameters are completely unconstrained.
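As a sketch of the third argument, assuming ad-delcont-0.1.0.0 is installed (logistic and logisticAt0 are hypothetical names), the logistic function can be lifted with a hand-written pullback that reuses the forward result y, since its derivative is y * (1 - y).

```haskell
import Numeric.AD.DelCont (AD, op1, rad1)

-- Logistic function lifted as a single AD node: 0 is the zero of the output
-- adjoint, (+) adds input adjoints, and the pullback reuses y.
logistic :: Floating a => AD s a a -> AD s a a
logistic = op1 0 (+) $ \x ->
  let y = 1 / (1 + exp (negate x))
  in (y, \d -> d * y * (1 - y))

-- Value and derivative at x = 0.
logisticAt0 :: (Double, Double)
logisticAt0 = rad1 logistic 0
```

Compared with composing exp, (+) and (/) through the Floating instance, this records one node instead of several; logisticAt0 should be (0.5,0.25).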

op2

Arguments

:: dc

zero

-> (da -> da -> da)

plus

-> (db -> db -> db)

plus

-> (a -> b -> (c, dc -> da, dc -> db))

returns: (function result, pullbacks)

-> AD s a da -> AD s b db -> AD s c dc 

Lift a binary function

See op1 for more information.
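For a binary example, assuming ad-delcont-0.1.0.0 is installed (hypotAD and hypotAt are hypothetical names), the Euclidean norm of two numbers can be lifted with op2, supplying one pullback per argument.

```haskell
import Numeric.AD.DelCont (AD, op2, rad2)

-- sqrt (x^2 + y^2) as a single node; the pullbacks are d * x / h and d * y / h.
hypotAD :: Floating a => AD s a a -> AD s a a -> AD s a a
hypotAD = op2 0 (+) (+) $ \x y ->
  let h = sqrt (x * x + y * y)
  in (h, \d -> d * x / h, \d -> d * y / h)

-- Value and gradient at (3, 4).
hypotAt :: (Double, (Double, Double))
hypotAt = rad2 hypotAD 3 4
```

At (3, 4) the value is 5 and the gradient is (3/5, 4/5).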

Num instances

Types

data AD s a da

Mutable references to dual numbers in the continuation monad

Here the a and da type parameters are respectively the primal and dual quantities tracked by the AD computation.

Instances

Floating a => Floating (AD s a a)

Defined in Numeric.AD.DelCont.Internal

Methods

pi :: AD s a a

exp :: AD s a a -> AD s a a

log :: AD s a a -> AD s a a

sqrt :: AD s a a -> AD s a a

(**) :: AD s a a -> AD s a a -> AD s a a

logBase :: AD s a a -> AD s a a -> AD s a a

sin :: AD s a a -> AD s a a

cos :: AD s a a -> AD s a a

tan :: AD s a a -> AD s a a

asin :: AD s a a -> AD s a a

acos :: AD s a a -> AD s a a

atan :: AD s a a -> AD s a a

sinh :: AD s a a -> AD s a a

cosh :: AD s a a -> AD s a a

tanh :: AD s a a -> AD s a a

asinh :: AD s a a -> AD s a a

acosh :: AD s a a -> AD s a a

atanh :: AD s a a -> AD s a a

log1p :: AD s a a -> AD s a a

expm1 :: AD s a a -> AD s a a

log1pexp :: AD s a a -> AD s a a

log1mexp :: AD s a a -> AD s a a

Fractional a => Fractional (AD s a a)

Defined in Numeric.AD.DelCont.Internal

Methods

(/) :: AD s a a -> AD s a a -> AD s a a

recip :: AD s a a -> AD s a a

fromRational :: Rational -> AD s a a

Num a => Num (AD s a a)

The numerical methods (of Num, Fractional, Floating etc.) can be read off their backprop counterparts: https://hackage.haskell.org/package/backprop-0.2.6.4/docs/src/Numeric.Backprop.Op.html#%2A.

Defined in Numeric.AD.DelCont.Internal

Methods

(+) :: AD s a a -> AD s a a -> AD s a a

(-) :: AD s a a -> AD s a a -> AD s a a

(*) :: AD s a a -> AD s a a -> AD s a a

negate :: AD s a a -> AD s a a

abs :: AD s a a -> AD s a a

signum :: AD s a a -> AD s a a

fromInteger :: Integer -> AD s a a

type AD' s a = AD s a a

Like AD, but the types of the primal and dual quantities coincide.