backprop-0.0.3.0: Heterogeneous, type-safe automatic backpropagation in Haskell

Copyright: (c) Justin Le 2017
License: BSD3
Maintainer: justin@jle.im
Stability: experimental
Portability: non-portable
Safe Haskell: None
Language: Haskell2010

Numeric.Backprop.Mono.Implicit

Description

Offers full functionality for implicit-graph back-propagation with monomorphic inputs. The intended usage is to write a BPOp, which is a normal Haskell function from BVars to a result BVar. These BVars can be manipulated using their Num, Fractional, and Floating instances.

The library can then perform back-propagation on the function (using backprop or grad) by using an implicitly built graph.
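For example, a minimal sketch of that workflow (dist2 is just an illustrative name):

dist2 :: BPOp N2 Double Double
dist2 (x :* y :* ØV) = x*x + y*y

>>> backprop dist2 (3 :+ 4 :+ ØV)
(25.0, 6.0 :+ 8.0 :+ ØV)
>>> grad dist2 (3 :+ 4 :+ ØV)
6.0 :+ 8.0 :+ ØV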

This is an "implicit-only" version of Numeric.Backprop.Mono, and a monomorphic version of Numeric.Backprop.Implicit, monomorphic in the sense that all of the inputs are of the same type.

As with Numeric.Backprop.Implicit, this should actually be powerful enough for most use cases, but it falls short because, without explicit graph capabilities, recomputation can sometimes be unavoidable. If the result of a function on BVars is used twice (like z in let z = x * y in z + z), a new, redundant graph node is allocated for every usage site of z (sketched below). You can explicitly force z, but only with an explicit graph description, using Numeric.Backprop.Mono.
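A sketch of the problematic situation (shareMe is an illustrative name):

shareMe :: BPOp N2 Double Double
shareMe (x :* y :* ØV) =
  let z = x * y       -- z is a deferred result
  in  z + z           -- each use of z creates a new, redundant graph node

Using the explicit-graph interface in Numeric.Backprop.Mono, z could instead be forced once with bindVar and then reused freely.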

Like Numeric.Backprop.Implicit, this can't handle sum types, but neither can Numeric.Backprop.Mono, so no loss here :)

This module implements pretty much the same functionality as Numeric.AD and Numeric.AD.Mode.Reverse from the ad package, because it uses the same implicit-graph back-propagation method. It can't compute Jacobians/generalized gradients, however. This isn't a fundamental limitation of the implementation, though, but rather a conscious design decision for this module's API.

Types

Backprop types

type BVar s n a = BVar s (Replicate n a) Source #

The basic unit of manipulation inside BP (or inside an implicit-graph backprop function). Instead of directly working with values, you work with BVars containing those values. When you work with a BVar, the backprop library can keep track of what values refer to which other values, and so can perform back-propagation to compute gradients.

A BVar s n r a refers to a value of type a, with an environment of n values of type r. The phantom parameter s is used to ensure that stray BVars don't leak outside of the backprop process.

(That is, if you're using implicit backprop, it ensures that you interact with BVars in a polymorphic way. And, if you're using explicit backprop, it ensures that a BVar s n r a never leaves the BP s n r that it was created in.)

BVars have Num, Fractional, Floating, etc. instances, so they can be manipulated using polymorphic functions and numeric functions in Haskell. You can add them, subtract them, etc., in "implicit" backprop style.

(However, note that if you directly manipulate BVars using those instances or using liftB, it delays evaluation, so every usage site has to re-compute the result/create a new node. If you want to re-use a BVar you created using + or - or liftB, use bindVar to force it first. See documentation for bindVar for more details.)

type BPOp n a b = forall s. VecT n (BVar s n a) a -> BVar s n a b Source #

An operation on BVars that can be backpropagated. A value of type:

BPOp n a b

takes a vector (VecT) of BVars containing n values of type a and uses them to (purely) produce a BVar containing a b.

foo :: BPOp N2 Double Double
foo (x :* y :* ØV) = x + sqrt y

BPOp here is related to BPOpI from the normal explicit-graph backprop module Numeric.Backprop.Mono.

type Op n a b = Op (Replicate n a) b Source #

An Op n a b describes a differentiable function from n values of type a to a value of type b.

For example, a value of type

Op N2 Int Double

is a function that takes two Ints and returns a Double. It can be differentiated to give a gradient of two Ints, if given a total derivative for the Double. Mathematically, it is akin to a:

\[ f : \mathbb{Z}^2 \rightarrow \mathbb{R} \]

See runOp, gradOp, and gradOpWith for examples on how to run it, and Op for instructions on creating it.

This type is abstracted over using the pattern synonym with constructor Op, so you can create one from scratch with it. However, it's simplest to create it using the op1', op2', and op3' helper smart constructors. And, if your function is a numeric function, they can even be created automatically using op1, op2, op3, and opN, with a little help from Numeric.AD from the ad library.

Note that this type is a subset or subtype of OpM (and also of OpB). So, if a function ever expects an OpM m as a (or an OpB), you can always provide an Op as a instead.

Many functions in this library will expect an OpM m as a (or an OpB s as a), and in all of these cases, you can provide an Op as a.
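A short sketch of creating and running an Op (myOp is an illustrative name; the numbers match the op2 example further below):

myOp :: Op N2 Double Double
myOp = op2 (\x y -> x * sqrt y)

>>> gradOp' myOp (3 :+ 4 :+ ØV)
(6.0, 2.0 :+ 0.75 :+ ØV)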

type OpB s n a b = OpB s (Replicate n a) b Source #

A subclass of OpM (and superclass of Op), representing Ops that the backprop library uses to perform backpropagation.

An

OpB s n a b

represents a differentiable function that takes n values of type a and produces a b, which can be run on BVar s values and also inside BP s. For example, an OpB s N2 Double Bool takes two Doubles and produces a Bool, and does so in a differentiable way.

OpB is a superset of Op, so, if you see any function that expects an OpB (like opVar' and ~$, for example), you can give them an Op, as well.

You can think of OpB as a superclass/parent class of Op in this sense, and of Op as a subclass of OpB.
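A minimal sketch of this in action (square is an illustrative name; liftB1 and op1 are documented below):

square :: BPOp N1 Double Double
square (x :* ØV) = liftB1 (op1 (\a -> a * a)) x    -- an Op given where an OpB is expected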

Vectors

See Numeric.Backprop.Mono for a mini-tutorial on VecT and Vec
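A quick sketch of constructing them (xs and ys are illustrative names):

-- Vec n a is an n-element vector of as, built with ØV and the :+ pattern
xs :: Vec N3 Double
xs = 1 :+ 2 :+ 3 :+ ØV

-- VecT n f a holds n values of type f a, built with ØV and :*
ys :: VecT N2 Maybe Int
ys = Just 1 :* Nothing :* ØV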

data VecT k n f a :: forall k. N -> (k -> *) -> k -> * where #

Constructors

ØV :: VecT k Z f a 
(:*) :: f a -> VecT k n1 f a -> VecT k (S n1) f a infixr 4 

Instances

Functor1 l l (VecT l n) 

Methods

map1 :: (forall a. f a -> g a) -> t f b -> t g b #

Foldable1 l l (VecT l n) 

Methods

foldMap1 :: Monoid m => (forall a. f a -> m) -> t f b -> m #

Traversable1 l l (VecT l n) 

Methods

traverse1 :: Applicative h => (forall a. f a -> h (g a)) -> t f b -> h (t g b) #

Witness ØC (Known N Nat n) (VecT k n f a) 

Associated Types

type WitnessC (ØC :: Constraint) (Known N Nat n :: Constraint) (VecT k n f a) :: Constraint #

Methods

(\\) :: ØC => (Known N Nat n -> r) -> VecT k n f a -> r #

(Monad f, Known N Nat n) => Monad (VecT * n f) 

Methods

(>>=) :: VecT * n f a -> (a -> VecT * n f b) -> VecT * n f b #

(>>) :: VecT * n f a -> VecT * n f b -> VecT * n f b #

return :: a -> VecT * n f a #

fail :: String -> VecT * n f a #

Functor f => Functor (VecT * n f) 

Methods

fmap :: (a -> b) -> VecT * n f a -> VecT * n f b #

(<$) :: a -> VecT * n f b -> VecT * n f a #

(Applicative f, Known N Nat n) => Applicative (VecT * n f) 

Methods

pure :: a -> VecT * n f a #

(<*>) :: VecT * n f (a -> b) -> VecT * n f a -> VecT * n f b #

(*>) :: VecT * n f a -> VecT * n f b -> VecT * n f b #

(<*) :: VecT * n f a -> VecT * n f b -> VecT * n f a #

Foldable f => Foldable (VecT * n f) 

Methods

fold :: Monoid m => VecT * n f m -> m #

foldMap :: Monoid m => (a -> m) -> VecT * n f a -> m #

foldr :: (a -> b -> b) -> b -> VecT * n f a -> b #

foldr' :: (a -> b -> b) -> b -> VecT * n f a -> b #

foldl :: (b -> a -> b) -> b -> VecT * n f a -> b #

foldl' :: (b -> a -> b) -> b -> VecT * n f a -> b #

foldr1 :: (a -> a -> a) -> VecT * n f a -> a #

foldl1 :: (a -> a -> a) -> VecT * n f a -> a #

toList :: VecT * n f a -> [a] #

null :: VecT * n f a -> Bool #

length :: VecT * n f a -> Int #

elem :: Eq a => a -> VecT * n f a -> Bool #

maximum :: Ord a => VecT * n f a -> a #

minimum :: Ord a => VecT * n f a -> a #

sum :: Num a => VecT * n f a -> a #

product :: Num a => VecT * n f a -> a #

Traversable f => Traversable (VecT * n f) 

Methods

traverse :: Applicative f => (a -> f b) -> VecT * n f a -> f (VecT * n f b) #

sequenceA :: Applicative f => VecT * n f (f a) -> f (VecT * n f a) #

mapM :: Monad m => (a -> m b) -> VecT * n f a -> m (VecT * n f b) #

sequence :: Monad m => VecT * n f (m a) -> m (VecT * n f a) #

Eq (f a) => Eq (VecT k n f a) 

Methods

(==) :: VecT k n f a -> VecT k n f a -> Bool #

(/=) :: VecT k n f a -> VecT k n f a -> Bool #

(Num (f a), Known N Nat n) => Num (VecT k n f a) 

Methods

(+) :: VecT k n f a -> VecT k n f a -> VecT k n f a #

(-) :: VecT k n f a -> VecT k n f a -> VecT k n f a #

(*) :: VecT k n f a -> VecT k n f a -> VecT k n f a #

negate :: VecT k n f a -> VecT k n f a #

abs :: VecT k n f a -> VecT k n f a #

signum :: VecT k n f a -> VecT k n f a #

fromInteger :: Integer -> VecT k n f a #

Ord (f a) => Ord (VecT k n f a) 

Methods

compare :: VecT k n f a -> VecT k n f a -> Ordering #

(<) :: VecT k n f a -> VecT k n f a -> Bool #

(<=) :: VecT k n f a -> VecT k n f a -> Bool #

(>) :: VecT k n f a -> VecT k n f a -> Bool #

(>=) :: VecT k n f a -> VecT k n f a -> Bool #

max :: VecT k n f a -> VecT k n f a -> VecT k n f a #

min :: VecT k n f a -> VecT k n f a -> VecT k n f a #

Show (f a) => Show (VecT k n f a) 

Methods

showsPrec :: Int -> VecT k n f a -> ShowS #

show :: VecT k n f a -> String #

showList :: [VecT k n f a] -> ShowS #

type WitnessC ØC (Known N Nat n) (VecT k n f a) 
type WitnessC ØC (Known N Nat n) (VecT k n f a) = ØC

type Vec n = VecT * n I #

newtype I a :: * -> * #

Constructors

I 

Fields

getI :: a

Instances

Monad I 

Methods

(>>=) :: I a -> (a -> I b) -> I b #

(>>) :: I a -> I b -> I b #

return :: a -> I a #

fail :: String -> I a #

Functor I 

Methods

fmap :: (a -> b) -> I a -> I b #

(<$) :: a -> I b -> I a #

Applicative I 

Methods

pure :: a -> I a #

(<*>) :: I (a -> b) -> I a -> I b #

(*>) :: I a -> I b -> I b #

(<*) :: I a -> I b -> I a #

Foldable I 

Methods

fold :: Monoid m => I m -> m #

foldMap :: Monoid m => (a -> m) -> I a -> m #

foldr :: (a -> b -> b) -> b -> I a -> b #

foldr' :: (a -> b -> b) -> b -> I a -> b #

foldl :: (b -> a -> b) -> b -> I a -> b #

foldl' :: (b -> a -> b) -> b -> I a -> b #

foldr1 :: (a -> a -> a) -> I a -> a #

foldl1 :: (a -> a -> a) -> I a -> a #

toList :: I a -> [a] #

null :: I a -> Bool #

length :: I a -> Int #

elem :: Eq a => a -> I a -> Bool #

maximum :: Ord a => I a -> a #

minimum :: Ord a => I a -> a #

sum :: Num a => I a -> a #

product :: Num a => I a -> a #

Traversable I 

Methods

traverse :: Applicative f => (a -> f b) -> I a -> f (I b) #

sequenceA :: Applicative f => I (f a) -> f (I a) #

mapM :: Monad m => (a -> m b) -> I a -> m (I b) #

sequence :: Monad m => I (m a) -> m (I a) #

Witness p q a => Witness p q (I a) 

Associated Types

type WitnessC (p :: Constraint) (q :: Constraint) (I a) :: Constraint #

Methods

(\\) :: p => (q -> r) -> I a -> r #

Eq a => Eq (I a) 

Methods

(==) :: I a -> I a -> Bool #

(/=) :: I a -> I a -> Bool #

Num a => Num (I a) 

Methods

(+) :: I a -> I a -> I a #

(-) :: I a -> I a -> I a #

(*) :: I a -> I a -> I a #

negate :: I a -> I a #

abs :: I a -> I a #

signum :: I a -> I a #

fromInteger :: Integer -> I a #

Ord a => Ord (I a) 

Methods

compare :: I a -> I a -> Ordering #

(<) :: I a -> I a -> Bool #

(<=) :: I a -> I a -> Bool #

(>) :: I a -> I a -> Bool #

(>=) :: I a -> I a -> Bool #

max :: I a -> I a -> I a #

min :: I a -> I a -> I a #

Show a => Show (I a) 

Methods

showsPrec :: Int -> I a -> ShowS #

show :: I a -> String #

showList :: [I a] -> ShowS #

type WitnessC p q (I a) 
type WitnessC p q (I a) = Witness p q a

Back-propagation

backprop :: forall n a b. (Num a, Known Nat n) => BPOp n a b -> Vec n a -> (b, Vec n a) Source #

Run back-propagation on a BPOp function, getting both the result and the gradient of the result with respect to the inputs.

foo :: BPOp N2 Double Double
foo (x :* y :* ØV) =
  let z = x * sqrt y
  in  z + x ** y
>>> backprop foo (2 :+ 3 :+ ØV)
(11.46, 13.73 :+ 6.12 :+ ØV)

grad :: forall n a b. (Num a, Known Nat n) => BPOp n a b -> Vec n a -> Vec n a Source #

Run the BPOp on an input vector and return the gradient of the result with respect to that input vector.

foo :: BPOp N2 Double Double
foo (x :* y :* ØV) =
  let z = x * sqrt y
  in  z + x ** y
>>> grad foo (2 :+ 3 :+ ØV)
13.73 :+ 6.12 :+ ØV

eval :: forall n a b. (Num a, Known Nat n) => BPOp n a b -> Vec n a -> b Source #

Simply run the BPOp on an input vector, getting the result without bothering with the gradient or with back-propagation.

foo :: BPOp N2 Double Double
foo (x :* y :* ØV) =
  let z = x * sqrt y
  in  z + x ** y
>>> eval foo (2 :+ 3 :+ ØV)
11.46

Var manipulation

constVar :: a -> BVar s n r a Source #

Create a BVar that represents just a specific value, that doesn't depend on any other BVars.
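A small sketch of its use (addConst is an illustrative name):

addConst :: Double -> BPOp N1 Double Double
addConst c (x :* ØV) = x + constVar c

>>> backprop (addConst 5) (3 :+ ØV)
(8.0, 1.0 :+ ØV)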

liftB :: forall s m n a b r. OpB s m a b -> VecT m (BVar s n r) a -> BVar s n r b Source #

Apply an OpB to a VecT of BVars as inputs. Provides "implicit" back-propagation, with deferred evaluation.

If you have an OpB s N3 a b, this function will expect a vector of three BVar s n r as, and the result will be a BVar s n r b:

myOp :: OpB s N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

x :* y :* z :* ØV              :: VecT N3 (BVar s n r) a
liftB myOp (x :* y :* z :* ØV) :: BVar s n r b

Note that OpB is a superclass of Op, so you can provide any Op here, as well (like those created by op1, op2, constOp, op0 etc.)

liftB has an infix alias, .$, so the above example can also be written as:

myOp .$ (x :* y :* z :* ØV) :: BVar s n r b

to let you pretend that you're applying the myOp function to three inputs.

The result is a new deferred BVar. This should be fine in most cases, unless you use the result in more than one location. This will cause evaluation to be duplicated and multiple redundant graph nodes to be created. If you need to use it in two locations, you should use opVar instead of liftB, or use bindVar:

opVar o xs = bindVar (liftB o xs)

liftB can be thought of as a "deferred evaluation" version of opVar.
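A concrete sketch, combining liftB with an Op built by op2 (hypot is an illustrative name):

hypot :: BPOp N2 Double Double
hypot (x :* y :* ØV) = liftB (op2 (\a b -> sqrt (a*a + b*b))) (x :* y :* ØV)

>>> eval hypot (3 :+ 4 :+ ØV)
5.0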

(.$) :: forall s m n a b r. OpB s m a b -> VecT m (BVar s n r) a -> BVar s n r b Source #

Infix synonym for liftB, which lets you pretend that you're applying OpBs as if they were functions:

myOp :: OpB s N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

x :* y :* z :* ØV           :: VecT N3 (BVar s n r) a
myOp .$ (x :* y :* z :* ØV) :: BVar s n r b

Note that OpB is a superclass of Op, so you can pass in any Op here, as well (like those created by op1, op2, constOp, op0 etc.)

See the documentation for liftB for all the caveats of this usage.

.$ can also be thought of as a "deferred evaluation" version of ~$:

o ~$ xs = bindVar (o .$ xs)

liftB1 :: OpB s N1 a b -> BVar s n r a -> BVar s n r b Source #

Convenient wrapper over liftB that takes an OpB with one argument and a single BVar argument. Lets you not have to type out the entire VecT.

liftB1 o x = liftB o (x :* ØV)

myOp :: Op N1 a b
x    :: BVar s n r a

liftB1 myOp x :: BVar s n r b

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op1) as well.

See the documentation for liftB for caveats and potential problematic situations with this.

liftB2 :: OpB s N2 a b -> BVar s n r a -> BVar s n r a -> BVar s n r b Source #

Convenient wrapper over liftB that takes an OpB with two arguments and two BVar arguments. Lets you not have to type out the entire VecT.

liftB2 o x y = liftB o (x :* y :* ØV)

myOp :: Op N2 a b
x    :: BVar s n r a
y    :: BVar s n r a

liftB2 myOp x y :: BVar s n r b

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op2) as well.

See the documentation for liftB for caveats and potential problematic situations with this.

liftB3 :: OpB s N3 a b -> BVar s n r a -> BVar s n r a -> BVar s n r a -> BVar s n r b Source #

Convenient wrapper over liftB that takes an OpB with three arguments and three BVar arguments. Lets you not have to type out the entire VecT.

liftB3 o x y z = liftB o (x :* y :* z :* ØV)

myOp :: Op N3 a b
x    :: BVar s n r a
y    :: BVar s n r a
z    :: BVar s n r a

liftB3 myOp x y z :: BVar s n r b

Note that OpB is a superclass of Op, so you can pass in an Op here (like one made with op3) as well.

See the documentation for liftB for caveats and potential problematic situations with this.

Op

op1 :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> Op N1 a a Source #

Automatically create an Op of a numerical function taking one argument. Uses diff, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (op1 (recip . negate)) (5 :+ ØV)
(-0.2, 0.04 :+ ØV)

op2 :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a -> Reverse s a) -> Op N2 a a Source #

Automatically create an Op of a numerical function taking two arguments. Uses grad, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (op2 (\x y -> x * sqrt y)) (3 :+ 4 :+ ØV)
(6.0, 2.0 :+ 0.75 :+ ØV)

op3 :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a -> Reverse s a -> Reverse s a) -> Op N3 a a Source #

Automatically create an Op of a numerical function taking three arguments. Uses grad, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (op3 (\x y z -> (x * sqrt y)**z)) (3 :+ 4 :+ 2 :+ ØV)
(36.0, 24.0 :+ 9.0 :+ 64.503 :+ ØV)

opN :: (Num a, Known Nat n) => (forall s. Reifies s Tape => Vec n (Reverse s a) -> Reverse s a) -> Op n a a Source #

Automatically create an Op of a numerical function taking multiple arguments. Uses grad, and so can take any numerical function polymorphic over the standard numeric types.

>>> gradOp' (opN (\(x :+ y :+ ØV) -> x * sqrt y)) (3 :+ 4 :+ ØV)
(6.0, 2.0 :+ 0.75 :+ ØV)

Utility

pattern (:+) :: forall a n. a -> Vec n a -> Vec (S n) a infixr 4 #

(*:) :: f a -> f a -> VecT k (S (S Z)) f a infix 5 #

(+:) :: a -> a -> Vec (S (S Z)) a infix 5 #

head' :: VecT k (S n) f a -> f a #
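A few sketches of these helpers (the names on the left are illustrative):

twoV :: Vec N2 Int
twoV = 1 +: 2                  -- a two-element Vec from two plain values

twoT :: VecT N2 Maybe Int
twoT = Just 1 *: Just 2        -- a two-element VecT from two f a values

firstOf :: VecT N2 Maybe Int -> Maybe Int
firstOf = head'                -- the first element of a non-empty VecT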

Nat type synonyms

type N0 = Z #

Convenient aliases for low-value Peano numbers.

type N1 = S N0 #

type N2 = S N1 #

type N3 = S N2 #

type N4 = S N3 #

type N5 = S N4 #

type N6 = S N5 #

type N7 = S N6 #

type N8 = S N7 #

type N9 = S N8 #

type N10 = S N9 #

Numeric Ops

Optimized ops for numeric functions. See Numeric.Backprop.Op.Mono for more information.
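For example, a sketch using one of them with gradOp' (the gradient of x * y is (y, x)):

>>> gradOp' ((*.) :: Op N2 Double Double) (3 :+ 4 :+ ØV)
(12.0, 4.0 :+ 3.0 :+ ØV)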

(+.) :: Num a => Op N2 a a Source #

Optimized version of op2 (+).

(-.) :: Num a => Op N2 a a Source #

Optimized version of op2 (-).

(*.) :: Num a => Op N2 a a Source #

Optimized version of op2 (*).

negateOp :: Num a => Op N1 a a Source #

Optimized version of op1 negate.

absOp :: Num a => Op N1 a a Source #

Optimized version of op1 abs.

signumOp :: Num a => Op N1 a a Source #

Optimized version of op1 signum.

(/.) :: Fractional a => Op N2 a a Source #

Optimized version of op2 (/).

recipOp :: Fractional a => Op N1 a a Source #

Optimized version of op1 recip.

expOp :: Floating a => Op N1 a a Source #

Optimized version of op1 exp.

logOp :: Floating a => Op N1 a a Source #

Optimized version of op1 log.

sqrtOp :: Floating a => Op N1 a a Source #

Optimized version of op1 sqrt.

(**.) :: Floating a => Op N2 a a Source #

Optimized version of op2 (**).

logBaseOp :: Floating a => Op N2 a a Source #

Optimized version of op2 logBase.

sinOp :: Floating a => Op N1 a a Source #

Optimized version of op1 sin.

cosOp :: Floating a => Op N1 a a Source #

Optimized version of op1 cos.

tanOp :: Floating a => Op N1 a a Source #

Optimized version of op1 tan.

asinOp :: Floating a => Op N1 a a Source #

Optimized version of op1 asin.

acosOp :: Floating a => Op N1 a a Source #

Optimized version of op1 acos.

atanOp :: Floating a => Op N1 a a Source #

Optimized version of op1 atan.

sinhOp :: Floating a => Op N1 a a Source #

Optimized version of op1 sinh.

coshOp :: Floating a => Op N1 a a Source #

Optimized version of op1 cosh.

tanhOp :: Floating a => Op N1 a a Source #

Optimized version of op1 tanh.

asinhOp :: Floating a => Op N1 a a Source #

Optimized version of op1 asinh.

acoshOp :: Floating a => Op N1 a a Source #

Optimized version of op1 acosh.

atanhOp :: Floating a => Op N1 a a Source #

Optimized version of op1 atanh.