backprop-0.0.1.0: Heterogeneous, type-safe automatic backpropagation in Haskell

Copyright: (c) Justin Le 2017
License: BSD3
Maintainer: justin@jle.im
Stability: experimental
Portability: non-portable
Safe Haskell: None
Language: Haskell2010

Numeric.Backprop.Mono


Description

Provides a monomorphic interface to the library and to the Numeric.Backprop module.

The interface is monomorphic in the sense that all of the inputs have to be of the same type. So, something like

BP s '[Double, Double, Double] Int

from Numeric.Backprop would, in this module, be:

BP s N3 Double Int

Instead of dealing with Prods and Tuples, this module works with VecTs and Vecs, respectively. These are fixed-length vectors whose lengths are encoded in their types, constructed with :* (for VecT) or :+ (for Vec).

Most of the concepts in normal heterogeneous backprop (for Numeric.Backprop) should apply here as well, so you can look at any of the tutorials or examples and repurpose them to work here. Just remember to convert something like Op '[a, a] b to Op N2 a b.

As a comparison, this module implements something similar in functionality to Numeric.AD and Numeric.AD.Mode.Reverse from the ad package, in that they both offer monomorphic automatic differentiation through back-propagation. This module doesn't allow the computation of Jacobians or generalized gradients for \(\mathbb{R}^N \rightarrow \mathbb{R}^M\) functions; it only computes gradients for \(\mathbb{R}^N \rightarrow \mathbb{R}\)-like functions. This is a conscious design decision in the API of this module rather than a fundamental limitation of the implementation.

This module also allows you to build explicit data dependency graphs so the library can reduce duplication and perform optimizations, which may or may not provide advantages over Numeric.AD.Mode.Reverse's unsafePerformIO-based implicit graph building.
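
As a brief sketch of what explicit-graph usage in this module looks like (a hypothetical example, assuming the API documented below; the numbers in the comment are only illustrative):

dist :: BPOp s N2 Double Double
dist = withInps $ \(x :* y :* ØV) ->
    bindVar (sqrt (x * x + y * y))

-- backprop dist (3 :+ 4 :+ ØV)
--   should give (5.0, 0.6 :+ 0.8 :+ ØV)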


Types

Backprop types

type BP s n r = BP s (Replicate n r) Source #

A Monad allowing you to explicitly build data dependency graphs that the library can perform back-propagation on.

A BP s n r a is a BP action that uses an environment of n values of type r, and returns an a. When "run", it will compute a gradient that is a vector (Vec) of n rs. (The phantom parameter s is used to ensure that no BVars are leaked out of the monad.)

Note that you can only "run" a BP s n r that produces a BVar -- that is, things of the form

BP s n r (BVar s n r a)

The above is a BP action that returns a BVar containing an a. When run, it'll produce a result of type a and a gradient that is a vector of n values of type r. (This form has a type synonym, BPOp, for convenience.)

For example, BP s N3 Double is a monad that represents a computation with three Doubles as inputs. And, if you ran a

BP s N3 Double (BVar s N3 Double Int)

Or, using the BPOp type synonym:

BPOp s N3 Double Int

with backprop or gradBPOp, it'll return a gradient on the inputs (a vector of three Doubles) and produce a value of type Int.

Now, one powerful thing about this type is that a BP is itself an Op (or more precisely, an OpM). So, once you create your fancy BP computation, you can transform it into an OpM using bpOp.

type BPOp s n r a = BP s n r (BVar s n r a) Source #

A handy type synonym representing a BP action that returns a BVar. This is handy because this is the form of BP actions that backprop and gradBPOp (etc.) expect.

A value of type:

BPOp s n r a

is an action that takes an input environment of n values of type r and produces a BVar containing a value of type a. Because it returns a BVar, the library can track the data dependencies between the BVar and the input environment and perform back-propagation.

See documentation for BP for an explanation of the phantom type parameter s.

type BPOpI s n r a = VecT n (BVar s n r) r -> BVar s n r a Source #

An "implicit" operation on BVars that can be backpropagated. A value of type:

BPOpI s n r a

takes a vector (VecT) of n BVars containing rs and uses them to (purely) produce a BVar containing an a.

foo :: BPOpI s N2 Double Double
foo (x :* y :* ØV) = x + sqrt y

If you are exclusively doing implicit back-propagation by combining BVars and using BPOpIs, you are probably better off just importing Numeric.Backprop.Mono.Implicit, which provides better tools. This type synonym exists in Numeric.Backprop.Mono just for the implicitly function, which can convert an "implicit" backprop function like a BPOpI s n r a into an "explicit" graph backprop function, a BPOp s n r a.
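
For instance, the foo above can be converted with implicitly (documented below) and then run with backprop; a minimal sketch (the results in the comment are only illustrative):

fooExplicit :: BPOp s N2 Double Double
fooExplicit = implicitly foo

-- backprop fooExplicit (9 :+ 16 :+ ØV)
--   should give (13.0, 1.0 :+ 0.125 :+ ØV)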

type BVar s n a = BVar s (Replicate n a) Source #

The basic unit of manipulation inside BP (or inside an implicit-graph backprop function). Instead of directly working with values, you work with BVars containing those values. When you work with a BVar, the backprop library can keep track of what values refer to which other values, and so can perform back-propagation to compute gradients.

A BVar s n r a refers to a value of type a, with an environment of n values of type r. The phantom parameter s is used to ensure that stray BVars don't leak outside of the backprop process.

(That is, if you're using implicit backprop, it ensures that you interact with BVars in a polymorphic way. And, if you're using explicit backprop, it ensures that a BVar s n r a never leaves the BP s n r that it was created in.)

BVars have Num, Fractional, Floating, etc. instances, so they can be manipulated using polymorphic functions and numeric functions in Haskell. You can add them, subtract them, etc., in "implicit" backprop style.

(However, note that if you directly manipulate BVars using those instances or using liftB, it delays evaluation, so every usage site has to re-compute the result/create a new node. If you want to re-use a BVar you created using + or - or liftB, use bindVar to force it first. See documentation for bindVar for more details.)

type Op n a b = Op (Replicate n a) b Source #

An Op n a b describes a differentiable function from n values of type a to a value of type b.

For example, a value of type

Op N2 Int Double

is a function that takes two Ints and returns a Double. It can be differentiated to give a gradient of two Ints, if given a total derivative for the Double. Mathematically, it is akin to a:

\[ f : \mathbb{Z}^2 \rightarrow \mathbb{R} \]

See runOp, gradOp, and gradOpWith for examples on how to run it, and Op for instructions on creating it.

This type is abstracted over using the pattern synonym with constructor Op, so you can create one from scratch with it. However, it's simplest to create one using the op1', op2', and op3' helper smart constructors. And, if your function is a numeric function, they can even be created automatically using op1, op2, op3, and opN with a little help from Numeric.AD from the ad library.

Note that this type is a subset or subtype of OpM (and also of OpB). So, if a function ever expects an OpM m n a b (or an OpB), you can always provide an Op n a b instead.

Many functions in this library will expect an OpM m n a b (or an OpB s n a b), and in all of these cases, you can provide an Op n a b.
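
As a short sketch of creating and running one (using gradOp', as in the examples further below; the numbers are illustrative):

addSq :: Op N2 Double Double
addSq = op2 $ \x y -> (x + y) ** 2

-- gradOp' addSq (3 :+ 4 :+ ØV)
--   should give (49.0, 14.0 :+ 14.0 :+ ØV)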

type OpB s n a b = OpB s (Replicate n a) b Source #

A subclass of OpM (and superclass of Op), representing Ops that the backprop library uses to perform back-propagation.

An

OpB s n a b

represents a differentiable function that takes n values of type a and produces a b, and which can be run on BVars and also inside BP computations. For example, an OpB s N2 Double Bool takes two Doubles and produces a Bool, and does it in a differentiable way.

OpB is a superset of Op, so, if you see any function that expects an OpB (like opVar' and ~$, for example), you can give them an Op, as well.

You can think of OpB as a superclass/parent class of Op in this sense, and of Op as a subclass of OpB.

Vectors

A VecT is a fixed-length list of a given type. It's basically the "monomorphic" version of a Prod (see the mini-tutorial in Numeric.Backprop).

A VecT n f a is a list of n f as, and is constructed by consing them together with :* (using ØV as nil):

I "hello" :* I "world" :* I "ok" :* ØV :: VecT N3 I String
[1,2,3] :* [4,5,6,7] :* ØV             :: VecT N2 [] Int

(I is the identity functor)

So, in general:

x :: f a
y :: f a
z :: f a
k :: f a
x :* y :* z :* k :* ØV :: VecT N4 f a

Vec is provided as a convenient type synonym for VecT I, and has a convenient pattern synonym :+, which can also be used for pattern matching:

x :: a
y :: a
z :: a
k :: a

x :+ y :+ z :+ k :+ ØV :: Vec N4 a

data VecT k n f a :: forall k. N -> (k -> *) -> k -> * where #

Constructors

ØV :: VecT k Z f a 
(:*) :: f a -> VecT k n1 f a -> VecT k (S n1) f a infixr 4 

Instances

Functor1 l l (VecT l n) 

Methods

map1 :: (forall a. f a -> g a) -> t f b -> t g b #

Foldable1 l l (VecT l n) 

Methods

foldMap1 :: Monoid m => (forall a. f a -> m) -> t f b -> m #

Traversable1 l l (VecT l n) 

Methods

traverse1 :: Applicative h => (forall a. f a -> h (g a)) -> t f b -> h (t g b) #

Witness ØC (Known N Nat n) (VecT k n f a) 

Associated Types

type WitnessC (ØC :: Constraint) (Known N Nat n :: Constraint) (VecT k n f a) :: Constraint #

Methods

(\\) :: ØC => (Known N Nat n -> r) -> VecT k n f a -> r #

(Monad f, Known N Nat n) => Monad (VecT * n f) 

Methods

(>>=) :: VecT * n f a -> (a -> VecT * n f b) -> VecT * n f b #

(>>) :: VecT * n f a -> VecT * n f b -> VecT * n f b #

return :: a -> VecT * n f a #

fail :: String -> VecT * n f a #

Functor f => Functor (VecT * n f) 

Methods

fmap :: (a -> b) -> VecT * n f a -> VecT * n f b #

(<$) :: a -> VecT * n f b -> VecT * n f a #

(Applicative f, Known N Nat n) => Applicative (VecT * n f) 

Methods

pure :: a -> VecT * n f a #

(<*>) :: VecT * n f (a -> b) -> VecT * n f a -> VecT * n f b #

(*>) :: VecT * n f a -> VecT * n f b -> VecT * n f b #

(<*) :: VecT * n f a -> VecT * n f b -> VecT * n f a #

Foldable f => Foldable (VecT * n f) 

Methods

fold :: Monoid m => VecT * n f m -> m #

foldMap :: Monoid m => (a -> m) -> VecT * n f a -> m #

foldr :: (a -> b -> b) -> b -> VecT * n f a -> b #

foldr' :: (a -> b -> b) -> b -> VecT * n f a -> b #

foldl :: (b -> a -> b) -> b -> VecT * n f a -> b #

foldl' :: (b -> a -> b) -> b -> VecT * n f a -> b #

foldr1 :: (a -> a -> a) -> VecT * n f a -> a #

foldl1 :: (a -> a -> a) -> VecT * n f a -> a #

toList :: VecT * n f a -> [a] #

null :: VecT * n f a -> Bool #

length :: VecT * n f a -> Int #

elem :: Eq a => a -> VecT * n f a -> Bool #

maximum :: Ord a => VecT * n f a -> a #

minimum :: Ord a => VecT * n f a -> a #

sum :: Num a => VecT * n f a -> a #

product :: Num a => VecT * n f a -> a #

Traversable f => Traversable (VecT * n f) 

Methods

traverse :: Applicative f => (a -> f b) -> VecT * n f a -> f (VecT * n f b) #

sequenceA :: Applicative f => VecT * n f (f a) -> f (VecT * n f a) #

mapM :: Monad m => (a -> m b) -> VecT * n f a -> m (VecT * n f b) #

sequence :: Monad m => VecT * n f (m a) -> m (VecT * n f a) #

Eq (f a) => Eq (VecT k n f a) 

Methods

(==) :: VecT k n f a -> VecT k n f a -> Bool #

(/=) :: VecT k n f a -> VecT k n f a -> Bool #

(Num (f a), Known N Nat n) => Num (VecT k n f a) 

Methods

(+) :: VecT k n f a -> VecT k n f a -> VecT k n f a #

(-) :: VecT k n f a -> VecT k n f a -> VecT k n f a #

(*) :: VecT k n f a -> VecT k n f a -> VecT k n f a #

negate :: VecT k n f a -> VecT k n f a #

abs :: VecT k n f a -> VecT k n f a #

signum :: VecT k n f a -> VecT k n f a #

fromInteger :: Integer -> VecT k n f a #

Ord (f a) => Ord (VecT k n f a) 

Methods

compare :: VecT k n f a -> VecT k n f a -> Ordering #

(<) :: VecT k n f a -> VecT k n f a -> Bool #

(<=) :: VecT k n f a -> VecT k n f a -> Bool #

(>) :: VecT k n f a -> VecT k n f a -> Bool #

(>=) :: VecT k n f a -> VecT k n f a -> Bool #

max :: VecT k n f a -> VecT k n f a -> VecT k n f a #

min :: VecT k n f a -> VecT k n f a -> VecT k n f a #

Show (f a) => Show (VecT k n f a) 

Methods

showsPrec :: Int -> VecT k n f a -> ShowS #

show :: VecT k n f a -> String #

showList :: [VecT k n f a] -> ShowS #

type WitnessC ØC (Known N Nat n) (VecT k n f a) 
type WitnessC ØC (Known N Nat n) (VecT k n f a) = ØC

type Vec n = VecT * n I #

newtype I a :: * -> * #

Constructors

I 

Fields

getI :: a

Instances

Monad I 

Methods

(>>=) :: I a -> (a -> I b) -> I b #

(>>) :: I a -> I b -> I b #

return :: a -> I a #

fail :: String -> I a #

Functor I 

Methods

fmap :: (a -> b) -> I a -> I b #

(<$) :: a -> I b -> I a #

Applicative I 

Methods

pure :: a -> I a #

(<*>) :: I (a -> b) -> I a -> I b #

(*>) :: I a -> I b -> I b #

(<*) :: I a -> I b -> I a #

Foldable I 

Methods

fold :: Monoid m => I m -> m #

foldMap :: Monoid m => (a -> m) -> I a -> m #

foldr :: (a -> b -> b) -> b -> I a -> b #

foldr' :: (a -> b -> b) -> b -> I a -> b #

foldl :: (b -> a -> b) -> b -> I a -> b #

foldl' :: (b -> a -> b) -> b -> I a -> b #

foldr1 :: (a -> a -> a) -> I a -> a #

foldl1 :: (a -> a -> a) -> I a -> a #

toList :: I a -> [a] #

null :: I a -> Bool #

length :: I a -> Int #

elem :: Eq a => a -> I a -> Bool #

maximum :: Ord a => I a -> a #

minimum :: Ord a => I a -> a #

sum :: Num a => I a -> a #

product :: Num a => I a -> a #

Traversable I 

Methods

traverse :: Applicative f => (a -> f b) -> I a -> f (I b) #

sequenceA :: Applicative f => I (f a) -> f (I a) #

mapM :: Monad m => (a -> m b) -> I a -> m (I b) #

sequence :: Monad m => I (m a) -> m (I a) #

Witness p q a => Witness p q (I a) 

Associated Types

type WitnessC (p :: Constraint) (q :: Constraint) (I a) :: Constraint #

Methods

(\\) :: p => (q -> r) -> I a -> r #

Eq a => Eq (I a) 

Methods

(==) :: I a -> I a -> Bool #

(/=) :: I a -> I a -> Bool #

Num a => Num (I a) 

Methods

(+) :: I a -> I a -> I a #

(-) :: I a -> I a -> I a #

(*) :: I a -> I a -> I a #

negate :: I a -> I a #

abs :: I a -> I a #

signum :: I a -> I a #

fromInteger :: Integer -> I a #

Ord a => Ord (I a) 

Methods

compare :: I a -> I a -> Ordering #

(<) :: I a -> I a -> Bool #

(<=) :: I a -> I a -> Bool #

(>) :: I a -> I a -> Bool #

(>=) :: I a -> I a -> Bool #

max :: I a -> I a -> I a #

min :: I a -> I a -> I a #

Show a => Show (I a) 

Methods

showsPrec :: Int -> I a -> ShowS #

show :: I a -> String #

showList :: [I a] -> ShowS #

type WitnessC p q (I a) 
type WitnessC p q (I a) = Witness p q a

BP

Backprop

backprop :: forall n r a. Num r => (forall s. BPOp s n r a) -> Vec n r -> (a, Vec n r) Source #

Perform back-propagation on the given BPOp. Returns the result of the operation it represents, as well as the gradient of the result with respect to its inputs. See module header for Numeric.Backprop.Mono and package documentation for examples and usages.

evalBPOp :: forall n r a. (forall s. BPOp s n r a) -> Vec n r -> a Source #

Simply run the BPOp on an input vector, getting the result without bothering with the gradient or with back-propagation.

gradBPOp :: forall n r a. Num r => (forall s. BPOp s n r a) -> Vec n r -> Vec n r Source #

Run the BPOp on an input vector and return the gradient of the result with respect to the input vector
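
To illustrate how the three runners relate, here is a small hypothetical example (the numeric results are for illustration):

prodPlus :: BPOp s N2 Double Double
prodPlus = withInps $ \(x :* y :* ØV) -> bindVar (x * y + x)

-- backprop prodPlus (2 :+ 3 :+ ØV)   should give (8.0, 4.0 :+ 2.0 :+ ØV)
-- evalBPOp prodPlus (2 :+ 3 :+ ØV)   should give 8.0
-- gradBPOp prodPlus (2 :+ 3 :+ ØV)   should give 4.0 :+ 2.0 :+ ØV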

Utility combinators

withInps :: Known Nat n => (VecT n (BVar s n r) r -> BP s n r a) -> BP s n r a Source #

Runs a continuation on a Vec of all of the input BVars.

Handy for bringing the environment into scope and doing stuff with it:

foo :: BPOp s N2 Double Int
foo = withInps $ \(x :* y :* ØV) -> do
    -- do stuff with inputs

Looks kinda like foo (x :* y :* ØV) = -- ..., don't it?

Note that the above is the same as

foo :: BPOp s N2 Double Int
foo = do
    case inpVars of
      x :* y :* ØV -> do
        -- do stuff with inputs

But just a little nicer!

implicitly :: Known Nat n => BPOpI s n r a -> BPOp s n r a Source #

Convert a BPOpI into a BPOp. That is, convert a function on a bundle of BVars (generating an implicit graph) into a fully fledged BPOp that you can run backprop on. See BPOpI for more information.

If you are going to write exclusively using implicit BVar operations, it might be more convenient to use Numeric.Backprop.Mono.Implicit instead, which is geared around that use case.

Vars

constVar :: a -> BVar s n r a Source #

Create a BVar that represents just a specific value, that doesn't depend on any other BVars.
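
A tiny hypothetical sketch, mixing a constant into a computation on the input:

addPi :: BPOp s N1 Double Double
addPi = withInps $ \(x :* ØV) -> bindVar (x + constVar pi)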

inpVar :: Fin n -> BVar s n r r Source #

Create a BVar given an index (Fin) into the input environment. For an example,

inpVar FZ

would refer to the first input variable (the first of the n rs in the environment), and

inpVar (FS FZ)

would refer to the second input variable.

Typically, there shouldn't be any reason to use inpVar directly. It's cleaner to get all of your input BVars together using withInps or inpVars.

inpVars :: Known Nat n => VecT n (BVar s n r) r Source #

Get a VecT (vector) of BVars for all of the input environment (the n rs) of the BP s n r

For example, if your BP has two Doubles inside its input environment (a BP s N2 Double), this would return two BVars, pointing to each input Double.

case (inpVars :: VecT N2 (BVar s N2 Double) Double) of
  x :* y :* ØV -> do
    -- the first item, x, is a var to the first input
    x :: BVar s N2 Double Double
    -- the second item, y, is a var to the second input
    y :: BVar s N2 Double Double

bpOp :: forall s n r a. (Num r, Known Nat n) => BPOp s n r a -> OpB s n r a Source #

Turn a BPOp into an OpB. Basically converts a BP taking n rs and producing an a into an Op taking n rs and returning an a, with all of the powers and utility of an Op, including all of its gradient-finding glory.

Really just reveals the fact that any BPOp s n r a is itself an Op (an OpB s n r a), which makes it a differentiable function.

Handy because an OpB can be used with almost all of the Op-related functions in this module, including opVar, ~$, etc.
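
A quick sketch of that conversion (inner is a hypothetical sub-computation):

inner :: BPOp s N2 Double Double
inner = withInps $ \(x :* y :* ØV) -> bindVar (x * y)

innerOp :: OpB s N2 Double Double
innerOp = bpOp inner

-- innerOp can now be applied to BVars with ~$ or opVar, like any other Op.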

bindVar :: forall s n r a. Num a => BVar s n r a -> BP s n r (BVar s n r a) Source #

Concretizes a delayed BVar. If you build up a BVar using numeric functions like + or * or using liftB, it'll defer the evaluation, and all of its usage sites will create a separate graph node.

Use bindVar if you ever intend to use a BVar in more than one location.

-- bad
errSquared :: Num a => BPOp s N2 a a
errSquared = withInps $ \(r :* t :* ØV) -> do
    let err = r - t
    return (err * err)   -- err is used twice!

-- good
errSquared :: Num a => BPOp s N2 a a
errSquared = withInps $ \(r :* t :* ØV) -> do
    let err = r - t
    e <- bindVar err     -- force e, so that it's safe to use twice!
    return (e * e)

-- better
errSquared :: Num a => BPOp s N2 a a
errSquared = withInps $ \(r :* t :* ØV) -> do
    let err = r - t
    e <- bindVar err
    bindVar (e * e)      -- result is forced so user doesn't have to worry

Note the relation between opVar / ~$ and liftB / .$:

opVar o xs    = bindVar (liftB o xs)
o ~$ xs       = bindVar (o .$ xs)
op2 (*) ~$ (x :* y :* ØV) = bindVar (x * y)

So you can avoid bindVar altogether if you use the explicitly binding ~$, opVar, etc.

Note that bindVar on BVars that are already forced is a no-op.

From Ops

opVar :: forall s m n r a b. Num b => OpB s m a b -> VecT m (BVar s n r) a -> BP s n r (BVar s n r b) Source #

Apply an OpB to a VecT (vector) of BVars.

If you had an OpB s N3 a b, this function will expect a vector of three BVar s n r as, and the result will be a BVar s n r b:

myOp :: OpB s N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

x :* y :* z :* ØV              :: VecT N3 (BVar s n r) a
opVar myOp (x :* y :* z :* ØV) :: BP s n r (BVar s n r b)

Note that OpB is a superclass of Op, so you can provide any Op here, as well (like those created by op1, op2, constOp, op0 etc.)

opVar has an infix alias, ~$, so the above example can also be written as:

myOp ~$ (x :* y :* z :* ØV) :: BP s n r (BVar s n r b)

to let you pretend that you're applying the myOp function to three inputs.

Also note the relation between opVar and liftB and bindVar:

opVar o xs = bindVar (liftB o xs)

opVar can be thought of as a "binding" version of liftB.

(~$) :: forall s m n r a b. Num b => OpB s m a b -> VecT m (BVar s n r) a -> BP s n r (BVar s n r b) infixr 5 Source #

Infix synonym for opVar, which lets you pretend that you're applying OpBs as if they were functions:

myOp :: OpB s N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

x :* y :* z :* ØV           :: VecT N3 (BVar s n r) a
myOp ~$ (x :* y :* z :* ØV) :: BP s n r (BVar s n r b)

Note that OpB is a superclass of Op, so you can pass in any Op here, as well (like those created by op1, op2, constOp, op0 etc.)

~$ can also be thought of as a "binding" version of .$:

o ~$ xs = bindVar (o .$ xs)

opVar1 :: forall s n r a b. Num b => OpB s N1 a b -> BVar s n r a -> BP s n r (BVar s n r b) Source #

Convenient wrapper over opVar that takes an OpB with one argument and a single BVar argument. Lets you not have to type out the entire VecT.

opVar1 o x = opVar o (x :* ØV)

myOp :: Op N1 a b
x    :: BVar s n r a

opVar1 myOp x :: BP s n r (BVar s n r b)

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op1) as well.

opVar2 :: forall s n r a b. Num b => OpB s N2 a b -> BVar s n r a -> BVar s n r a -> BP s n r (BVar s n r b) Source #

Convenient wrapper over opVar that takes an OpB with two arguments and two BVar arguments. Lets you not have to type out the entire VecT.

opVar2 o x y = opVar o (x :* y :* ØV)

myOp :: Op N2 a b
x    :: BVar s n r a
y    :: BVar s n r a

opVar2 myOp x y :: BP s n r (BVar s n r b)

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op2) as well.

opVar3 :: forall s n r a b. Num b => OpB s N3 a b -> BVar s n r a -> BVar s n r a -> BVar s n r a -> BP s n r (BVar s n r b) Source #

Convenient wrapper over opVar that takes an OpB with three arguments and three BVar arguments. Lets you not have to type out the entire VecT.

opVar3 o x y z = opVar o (x :* y :* z :* ØV)

myOp :: Op N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

opVar3 myOp x y z :: BP s n r (BVar s n r b)

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op3) as well.

(-$) :: forall s m n r a b. (Num a, Num b, Known Nat m) => BPOp s m a b -> VecT m (BVar s n r) a -> BP s n r (BVar s n r b) infixr 5 Source #

Lets you treat a BPOp s n a b as an Op n a b, and "apply" arguments to it just like you would with an Op and ~$ / opVar.

Basically a convenient wrapper over bpOp and ~$:

o -$ xs = bpOp o ~$ xs

So for a BPOp s n a b, you can "plug in" BVars to each a, and get a b as a result.

Useful for running a BPOp s n a b that you got from a different function, and "plugging in" its a inputs with BVars from your current environment.
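
For example, a hypothetical sketch of reusing a two-input sub-computation inside a three-input graph:

hypot :: BPOp s N2 Double Double
hypot = withInps $ \(x :* y :* ØV) -> bindVar (sqrt (x * x + y * y))

scaledHypot :: BPOp s N3 Double Double
scaledHypot = withInps $ \(x :* y :* k :* ØV) -> do
    h <- hypot -$ (x :* y :* ØV)
    bindVar (k * h)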

Combining

liftB :: forall s m n a b r. OpB s m a b -> VecT m (BVar s n r) a -> BVar s n r b Source #

Apply an OpB to a VecT (vector) of BVars as inputs. Provides "implicit" back-propagation, with deferred evaluation.

If you had an OpB s N3 a b, this function will expect a vector of three BVar s n r as, and the result will be a BVar s n r b:

myOp :: OpB s N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

x :* y :* z :* ØV              :: VecT N3 (BVar s n r) a
liftB myOp (x :* y :* z :* ØV) :: BVar s n r b

Note that OpB is a superclass of Op, so you can provide any Op here, as well (like those created by op1, op2, constOp, op0 etc.)

liftB has an infix alias, .$, so the above example can also be written as:

myOp .$ (x :* y :* z :* ØV) :: BVar s n r b

to let you pretend that you're applying the myOp function to three inputs.

The result is a new deferred BVar. This should be fine in most cases, unless you use the result in more than one location. This will cause evaluation to be duplicated and multiple redundant graph nodes to be created. If you need to use it in two locations, you should use opVar instead of liftB, or use bindVar:

opVar o xs = bindVar (liftB o xs)

liftB can be thought of as a "deferred evaluation" version of opVar.

(.$) :: forall s m n a b r. OpB s m a b -> VecT m (BVar s n r) a -> BVar s n r b Source #

Infix synonym for liftB, which lets you pretend that you're applying OpBs as if they were functions:

myOp :: OpB s N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

x :* y :* z :* ØV           :: VecT N3 (BVar s n r) a
myOp .$ (x :* y :* z :* ØV) :: BVar s n r b

Note that OpB is a superclass of Op, so you can pass in any Op here, as well (like those created by op1, op2, constOp, op0 etc.)

See the documentation for liftB for all the caveats of this usage.

.$ can also be thought of as a "deferred evaluation" version of ~$:

o ~$ xs = bindVar (o .$ xs)

liftB1 :: OpB s N1 a a -> BVar s n r a -> BVar s n r a Source #

Convenient wrapper over liftB that takes an OpB with one argument and a single BVar argument. Lets you not have to type out the entire VecT.

liftB1 o x = liftB o (x :* ØV)

myOp :: Op N1 a a
x    :: BVar s n r a

liftB1 myOp x :: BVar s n r a

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op1) as well.

See the documentation for liftB for caveats and potential problematic situations with this.

liftB2 :: OpB s N2 a a -> BVar s n r a -> BVar s n r a -> BVar s n r a Source #

Convenient wrapper over liftB that takes an OpB with two arguments and two BVar arguments. Lets you not have to type out the entire VecT.

liftB2 o x y = liftB o (x :* y :* ØV)

myOp :: Op N2 a a
x    :: BVar s n r a
y    :: BVar s n r a

liftB2 myOp x y :: BVar s n r a

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op2) as well.

See the documentation for liftB for caveats and potential problematic situations with this.

liftB3 :: OpB s N3 a a -> BVar s n r a -> BVar s n r a -> BVar s n r a -> BVar s n r a Source #

Convenient wrapper over liftB that takes an OpB with three arguments and three BVar arguments. Lets you not have to type out the entire VecT.

liftB3 o x y z = liftB o (x :* y :* z :* ØV)

myOp :: Op N3 a a
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

liftB3 myOp x y z :: BVar s n r a

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op3) as well.

See the documentation for liftB for caveats and potential problematic situations with this.

Op

op1 :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> Op N1 a a Source #

Automatically create an Op of a numerical function taking one argument. Uses diff, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (op1 (recip . negate)) (5 :+ ØV)
(-0.2, 0.04 :+ ØV)

op2 :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a -> Reverse s a) -> Op N2 a a Source #

Automatically create an Op of a numerical function taking two arguments. Uses grad, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (op2 (\x y -> x * sqrt y)) (3 :+ 4 :+ ØV)
(6.0, 2.0 :+ 0.75 :+ ØV)

op3 :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a -> Reverse s a -> Reverse s a) -> Op N3 a a Source #

Automatically create an Op of a numerical function taking three arguments. Uses grad, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (op3 (\x y z -> (x * sqrt y)**z)) (3 :+ 4 :+ 2 :+ ØV)
(36.0, 24.0 :+ 9.0 :+ 64.503 :+ ØV)

opN :: (Num a, Known Nat n) => (forall s. Reifies s Tape => Vec n (Reverse s a) -> Reverse s a) -> Op n a a Source #

Automatically create an Op of a numerical function taking multiple arguments. Uses grad, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (opN (\(x :+ y :+ ØV) -> x * sqrt y)) (3 :+ 4 :+ ØV)
(6.0, 2.0 :+ 0.75 :+ ØV)

composeOp :: forall m n o a b c. (Monad m, Num a, Known Nat n) => VecT o (OpM m n a) b -> OpM m o b c -> OpM m n a c Source #

Compose OpMs together, similar to (.). But, because all OpMs are \(\mathbb{R}^N \rightarrow \mathbb{R}\), this is more like sequence for functions, or liftAN.

That is, given a vector of o OpM m n a bs, it can compose them with an OpM m o b c to create an OpM m n a c.

composeOp1 :: forall m n a b c. (Monad m, Num a, Known Nat n) => OpM m n a b -> OpM m N1 b c -> OpM m n a c Source #

Convenient wrapper over composeOp for the case where the second function only takes one input, so the two OpMs can be directly piped together, like with (.).

(~.) :: forall m n a b c. (Monad m, Num a, Known Nat n) => OpM m N1 b c -> OpM m n a b -> OpM m n a c infixr 9 Source #

Convenient infix synonym for (flipped) composeOp1. Meant to be used just like (.):

op1 negate            :: Op N1 a a
op2 (+)               :: Op N2 a a

op1 negate ~. op2 (+) :: Op N2 a a

op1' :: (a -> (b, Maybe b -> a)) -> Op N1 a b Source #

Create an Op of a function taking one input, by giving its explicit derivative. The function should return a tuple containing the result of the function, and also a function taking the derivative of the result and returning the derivative of the input.

If we have

\[ \eqalign{ f &: \mathbb{R} \rightarrow \mathbb{R}\cr y &= f(x)\cr z &= g(y) } \]

Then the derivative \( \frac{dz}{dx} \) would be:

\[ \frac{dz}{dx} = \frac{dz}{dy} \frac{dy}{dx} \]

If our Op represents \(f\), then the second item in the resulting tuple should be a function that takes \(\frac{dz}{dy}\) and returns \(\frac{dz}{dx}\).

If the input is Nothing, then \(\frac{dz}{dy}\) should be taken to be \(1\).

As an example, here is an Op that squares its input:

square :: Num a => Op N1 a a
square = op1' $ \x -> (x*x, \case Nothing -> 2 * x
                                  Just d  -> 2 * d * x
                      )

Remember that, generally, end users shouldn't directly construct Ops; they should be provided by libraries or generated automatically.

For numeric functions, single-input Ops can be generated automatically using op1.

op2' :: (a -> a -> (b, Maybe b -> (a, a))) -> Op N2 a b Source #

Create an Op of a function taking two inputs, by giving its explicit gradient. The function should return a tuple containing the result of the function, and also a function taking the derivative of the result and returning the derivatives of the inputs.

If we have

\[ \eqalign{ f &: \mathbb{R}^2 \rightarrow \mathbb{R}\cr z &= f(x, y)\cr k &= g(z) } \]

Then the gradient \( \left< \frac{\partial k}{\partial x}, \frac{\partial k}{\partial y} \right> \) would be:

\[ \left< \frac{\partial k}{\partial x}, \frac{\partial k}{\partial y} \right> = \left< \frac{dk}{dz} \frac{\partial z}{\partial x}, \frac{dk}{dz} \frac{\partial z}{\partial y} \right> \]

If our Op represents \(f\), then the second item in the resulting tuple should be a function that takes \(\frac{dk}{dz}\) and returns \( \left< \frac{\partial k}{\partial x}, \frac{\partial k}{\partial y} \right> \).

If the input is Nothing, then \(\frac{dk}{dz}\) should be taken to be \(1\).

As an example, here is an Op that multiplies its inputs:

mul :: Num a => Op N2 a a
mul = op2' $ \x y -> (x*y, \case Nothing -> (y  , x  )
                                 Just d  -> (d*y, x*d)
                     )

Remember that, generally, end users shouldn't directly construct Ops; they should be provided by libraries or generated automatically.

For numeric functions, two-input Ops can be generated automatically using op2.

op3' :: (a -> a -> a -> (b, Maybe b -> (a, a, a))) -> Op N3 a b Source #

Create an Op of a function taking three inputs, by giving its explicit gradient. See documentation for op2' for more details.
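
As a hypothetical example in the spirit of mul above, an Op that multiplies its three inputs:

mul3 :: Num a => Op N3 a a
mul3 = op3' $ \x y z ->
    ( x*y*z
    , \case Nothing -> (y*z  , x*z  , x*y  )
            Just d  -> (d*y*z, x*d*z, x*y*d)
    )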

Utility

pattern (:+) :: forall a n. a -> Vec n a -> Vec (S n) a infixr 4 #

(*:) :: f a -> f a -> VecT k (S (S Z)) f a infix 5 #

(+:) :: a -> a -> Vec (S (S Z)) a infix 5 #

head' :: VecT k (S n) f a -> f a #

Nat type synonyms

type N0 = Z #

Convenient aliases for low-value Peano numbers.

type N1 = S N0 #

type N2 = S N1 #

type N3 = S N2 #

type N4 = S N3 #

type N5 = S N4 #

type N6 = S N5 #

type N7 = S N6 #

type N8 = S N7 #

type N9 = S N8 #

type N10 = S N9 #