---
title: Manual Gradients
---
Providing Hand-Written Gradients
================================
```haskell top hide
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
import Data.Functor.Identity
import qualified Data.List
import GHC.Generics (Generic)
import GHC.TypeNats
import Inliterate.Import
import Lens.Micro
import Lens.Micro.TH
import Numeric.Backprop
import Numeric.Backprop.Class
import Numeric.LinearAlgebra.Static (L, R, konst)
import System.Random
import qualified Data.Vector as V
import qualified Numeric.LinearAlgebra.Static as H
import qualified Numeric.LinearAlgebra as HU
```
Providing hand-written gradients for operations can be useful if you are
[performing low-level optimizations][performance] or [equipping your library
for backprop][equipping].
[performance]: https://backprop.jle.im/07-performance.html
[equipping]: https://backprop.jle.im/08-equipping-your-library.html
Ideally, as an *end user*, you should never have to do this. The whole point
of the *backprop* library is to let you use backpropagatable functions as
normal functions, and to let you build complicated functions by simply
composing normal Haskell functions, with the gradients computed for you
automatically.
However, if you are writing a library, you probably need to provide "primitive"
backpropagatable functions (like matrix-vector multiplication for a linear
algebra library) for your users, so your users can then use those primitive
functions to write their own code, without ever having to be aware of any
gradients.
If you are writing code and recognize some bottlenecks related to library
overhead, as [described in this post][performance], then you might also want to
provide manual gradients. However, this should always be a last resort:
*figuring out* manual gradients is a tedious and error-prone process that can
introduce subtle bugs in ways that don't always show up in testing. It also
makes your code much more fragile and difficult to refactor and shuffle around
(since you aren't using normal function composition and application anymore),
and much harder to read. Only proceed if you decide that the huge cognitive
costs are worth it.
The Lifted Function
-------------------
A lifted function of type
```haskell
myFunc :: Reifies s W => BVar s a -> BVar s b
```
represents a backpropagatable function taking an `a` and returning a `b`. It
is written as a function taking a `BVar` containing an `a` and returning a
`BVar` containing a `b`; the `BVar s` (with the `Reifies s W` constraint) is
what allows backpropagation to be tracked.
A `BVar s a -> BVar s b` is really, actually, under the hood:
```haskell
-- conceptually, a lifted function
BVar s a -> BVar s b
-- corresponds to a plain function
a -> (b, b -> a)
```
That is, given an input `a`, you get:
1. A `b`, the result (the "forward pass")
2. A `b -> a`, the "scaled gradient" function.
A full technical description is given in the documentation for [Numeric.Backprop.Op][op].
[op]: http://hackage.haskell.org/package/backprop/docs/Numeric-Backprop-Op.html
The `b` result is simple enough; it's the result of your function. The "scaled
gradient" function requires some elaboration. Let's say you are writing a
lifted version of your function \\(y = f(x)\\) (whose derivative is
\\(\frac{dy}{dx}\\)), and that your *final result* at the end of your
computation is \\(z = g(f(x))\\) (whose derivative is \\(\frac{dz}{dx}\\)). In
that case, because of the chain rule, \\(\frac{dz}{dx} = \frac{dz}{dy}
\frac{dy}{dx}\\).
The scaled gradient `b -> a` is the function which, *given*
\\(\frac{dz}{dy}\\) `:: b`, *returns* \\(\frac{dz}{dx}\\) `:: a` (that is,
returns \\(\frac{dz}{dy} \frac{dy}{dx}\\) `:: a`).
For example, for the mathematical operation \\(y = f(x) = x^2\\), then,
considering \\(z = g(f(x))\\), \\(\frac{dz}{dx} = \frac{dz}{dy} 2x\\). In fact,
for all functions taking and returning scalars (just normal single numbers),
\\(\frac{dz}{dx} = \frac{dz}{dy} f'(x)\\).
Simple Example
--------------
With that in mind, let's write a lifted "squared" operation that takes `x` and
returns `x^2`:
```haskell top
square
    :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
square = liftOp1 . op1 $ \x ->
    ( x^2                      -- actual result
    , \dzdy -> dzdy * 2 * x    -- scaled gradient function
    )
```
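As a quick sanity check, we can run this through `evalBP` and `gradBP` from
*backprop* (shown here as a hypothetical GHCi session; the expected values
follow from \\(\frac{d}{dx} x^2 = 2x\\)):

```haskell
-- hypothetical GHCi session
-- >>> evalBP square (3 :: Double)   -- forward pass
-- 9.0
-- >>> gradBP square (3 :: Double)   -- gradient, 2 * 3
-- 6.0
```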
We can write one for `sin`, as well. For \\(y = f(x) = \sin(x)\\), we consider
\\(z = g(f(x))\\) to see \\(\frac{dz}{dx} = \frac{dz}{dy} \cos(x)\\). So, we
have:
```haskell top
liftedSin
    :: (Floating a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedSin = liftOp1 . op1 $ \x ->
    ( sin x, \dzdy -> dzdy * cos x )
```
In general, for a scalar function `f` whose derivative is `dfdx`, the pattern
looks like this:
```haskell
liftedF
    :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedF = liftOp1 . op1 $ \x ->
    ( f x, \dzdy -> dzdy * dfdx x )
```
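For instance, here is that pattern instantiated for the exponential function
(the name `liftedExp` is just for illustration); since \\(\frac{d}{dx} e^x =
e^x\\), the derivative function is `exp` itself:

```haskell
-- a lifted exponential, following the scalar pattern above
liftedExp
    :: (Floating a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
liftedExp = liftOp1 . op1 $ \x ->
    ( exp x, \dzdy -> dzdy * exp x )
```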
For examples covering every numeric function in base Haskell, see [the source
of Op.hs][opsource], which contains the `Op` definitions for every method of
`Num`, `Fractional`, and `Floating`.
[opsource]: https://github.com/mstksg/backprop/blob/a7651b4549048a3aca73c79c6fbe07c3e8ee500e/src/Numeric/Backprop/Op.hs#L646-L787
Non-trivial example
-------------------
A simple non-trivial example is `sumElements`, which we can define to take the
*hmatrix* library's `R n` type (an n-vector of `Double`). Here, we have to
think about \\(g(\mathrm{sum}(\mathbf{x}))\\), and the types guide our
thinking:
```haskell
sumElements :: R n -> Double
sumElementsScaledGrad :: R n -> Double -> R n
```
For me, the simplest way to do this is to just take it element by element:
1. *Write out the functions in question, in a simple example*

    In our case:

    * \\(y = f(\langle a, b, c \rangle) = a + b + c\\)
    * \\(z = g(y) = g(a + b + c)\\)
2. *Identify the components in your gradient*

    In our case, we have to return a gradient \\(\langle \frac{\partial z}{\partial a},
    \frac{\partial z}{\partial b}, \frac{\partial z}{\partial c} \rangle\\).
3. *Work out each component of the gradient until you start to notice a
   pattern*

    Let's start with \\(\frac{\partial z}{\partial a}\\). We need to find
    \\(\frac{\partial z}{\partial a}\\) in terms of \\(\frac{dz}{dy}\\):

    * Through the chain rule, \\(\frac{\partial z}{\partial a} =
      \frac{dz}{dy} \frac{\partial y}{\partial a}\\).
    * Because \\(y = a + b + c\\), we know that \\(\frac{\partial y}{\partial
      a} = 1\\).
    * Because \\(\frac{\partial y}{\partial a} = 1\\), we know that
      \\(\frac{\partial z}{\partial a} = \frac{dz}{dy} \times 1 =
      \frac{dz}{dy}\\).

    So, our expression of \\(\frac{\partial z}{\partial a}\\) in terms of
    \\(\frac{dz}{dy}\\) is simple -- it's just \\(\frac{\partial z}{\partial
    a} = \frac{dz}{dy}\\).

    Now, let's look at \\(\frac{\partial z}{\partial b}\\). We need to find
    \\(\frac{\partial z}{\partial b}\\) in terms of \\(\frac{dz}{dy}\\).

    * Through the chain rule, \\(\frac{\partial z}{\partial b} =
      \frac{dz}{dy} \frac{\partial y}{\partial b}\\).
    * Because \\(y = a + b + c\\), we know that \\(\frac{\partial y}{\partial
      b} = 1\\).
    * Because \\(\frac{\partial y}{\partial b} = 1\\), we know that
      \\(\frac{\partial z}{\partial b} = \frac{dz}{dy} \times 1 =
      \frac{dz}{dy}\\).

    It looks like \\(\frac{\partial z}{\partial b} = \frac{dz}{dy}\\), as well.

    At this point, we start to notice a pattern. We can apply the same logic
    to see that \\(\frac{\partial z}{\partial c} = \frac{dz}{dy}\\).
4. *Write out the pattern*

    Extrapolating the pattern, \\(\frac{\partial z}{\partial q}\\), where
    \\(q\\) is *any* component, is always going to be a constant --
    \\(\frac{dz}{dy}\\).

So in the end:
```haskell top hide
instance KnownNat n => Backprop (R n) where
    zero = zeroNum
    add  = addNum
    one  = oneNum

instance (KnownNat n, KnownNat m) => Backprop (L n m) where
    zero = zeroNum
    add  = addNum
    one  = oneNum

sumElements :: KnownNat n => R n -> Double
sumElements = HU.sumElements . H.extract
```
```haskell top
liftedSumElements
    :: (KnownNat n, Reifies s W)
    => BVar s (R n)
    -> BVar s Double
liftedSumElements = liftOp1 . op1 $ \xs ->
    ( sumElements xs, \dzdy -> konst dzdy )   -- the gradient is a constant vector
```
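As a rough sanity check (a hypothetical GHCi session; the exact `Show` output
for *hmatrix* vectors may differ), the gradient of the sum is a vector of ones,
no matter what the input is:

```haskell
-- hypothetical GHCi session; output formatting may vary
-- >>> gradBP liftedSumElements (H.vec3 10 20 30)
-- vector [1.0,1.0,1.0]
```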
### Multiple-argument functions
Lifting multiple-argument functions works the same way, except using `liftOp2`
and `op2`, or `liftOpN` and `opN`.
A `BVar s a -> BVar s b -> BVar s c` is, really, under the hood:
```haskell
-- conceptually, a lifted two-argument function
BVar s a -> BVar s b -> BVar s c
-- corresponds to a plain function
a -> b -> (c, c -> (a, b))
```
That is, given an `a` and a `b` as inputs, you get:
1. A `c`, the result (the "forward pass")
2. A `c -> (a, b)`, the "scaled gradient" function returning the gradient of
both inputs.
The `c` parameter of the scaled gradient is again \\(\frac{dz}{dy}\\), and the
final `(a, b)` is the tuple of \\(\frac{\partial z}{\partial x_1}\\) and
\\(\frac{\partial z}{\partial x_2}\\): the gradients of the final result with
respect to each of the two inputs.
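By analogy with the single-argument pattern from before, a two-argument lift
generally looks like the following sketch, where `f`, `dfdx1`, and `dfdx2` are
placeholders for the function and its partial derivatives with respect to each
argument:

```haskell
liftedF2
    :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
liftedF2 = liftOp2 . op2 $ \x1 x2 ->
    ( f x1 x2, \dzdy -> (dzdy * dfdx1 x1 x2, dzdy * dfdx2 x1 x2) )
```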
For a simple example, let's look at addition, \\(x_1 + x_2\\). Working it out:

* \\(y = f(x_1, x_2) = x_1 + x_2\\)
* \\(z = g(f(x_1, x_2)) = g(x_1 + x_2)\\)
* Looking first for \\(\frac{\partial z}{\partial x_1}\\) in terms of
  \\(\frac{dz}{dy}\\):
    * \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial
      y}{\partial x_1}\\) (chain rule)
    * From \\(y = x_1 + x_2\\), we see that \\(\frac{\partial y}{\partial
      x_1} = 1\\)
    * Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \times 1
      = \frac{dz}{dy}\\).
* Looking second for \\(\frac{\partial z}{\partial x_2}\\) in terms of
  \\(\frac{dz}{dy}\\):
    * \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial
      y}{\partial x_2}\\) (chain rule)
    * From \\(y = x_1 + x_2\\), we see that \\(\frac{\partial y}{\partial
      x_2} = 1\\)
    * Therefore, \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \times 1
      = \frac{dz}{dy}\\).
* Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy}\\), and also
  \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy}\\).
Putting it into code:
```haskell top
add :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
add = liftOp2 . op2 $ \x1 x2 ->
    ( x1 + x2, \dzdy -> (dzdy, dzdy) )
```
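A quick check with `gradBP2` (the two-argument version of `gradBP`) gives the
expected gradient of \\((1, 1)\\), no matter the inputs:

```haskell
-- hypothetical GHCi session
-- >>> gradBP2 add (3 :: Double) 5
-- (1.0,1.0)
```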
Let's try our hand at multiplication, \\(x_1 x_2\\):

* \\(y = f(x_1, x_2) = x_1 x_2\\)
* \\(z = g(f(x_1, x_2)) = g(x_1 x_2)\\)
* Looking first for \\(\frac{\partial z}{\partial x_1}\\) in terms of
  \\(\frac{dz}{dy}\\):
    * \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} \frac{\partial
      y}{\partial x_1}\\) (chain rule)
    * From \\(y = x_1 x_2\\), we see that \\(\frac{\partial y}{\partial x_1}
      = x_2\\)
    * Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} x_2\\).
* Looking second for \\(\frac{\partial z}{\partial x_2}\\) in terms of
  \\(\frac{dz}{dy}\\):
    * \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} \frac{\partial
      y}{\partial x_2}\\) (chain rule)
    * From \\(y = x_1 x_2\\), we see that \\(\frac{\partial y}{\partial x_2}
      = x_1\\)
    * Therefore, \\(\frac{\partial z}{\partial x_2} = \frac{dz}{dy} x_1\\).
* Therefore, \\(\frac{\partial z}{\partial x_1} = \frac{dz}{dy} x_2\\), and
  \\(\frac{\partial z}{\partial x_2} = x_1 \frac{dz}{dy}\\).
In code:
```haskell top
mul :: (Num a, Backprop a, Reifies s W)
    => BVar s a
    -> BVar s a
    -> BVar s a
mul = liftOp2 . op2 $ \x1 x2 ->
    ( x1 * x2, \dzdy -> (dzdy * x2, x1 * dzdy) )
```
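Again, a quick check: the gradient with respect to each input is the *other*
input, scaled by the incoming gradient. (This also matches what the library's
built-in `Num` instance for `BVar` provides for `*`.)

```haskell
-- hypothetical GHCi session
-- >>> gradBP2 mul (3 :: Double) 5
-- (5.0,3.0)
```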
For non-trivial examples involving linear algebra, see the source for the *[hmatrix-backprop][]* library.
[hmatrix-backprop]: http://hackage.haskell.org/package/hmatrix-backprop
Here are two examples: the dot product between two vectors, and matrix-vector
multiplication:
```haskell top
-- import qualified Numeric.LinearAlgebra.Static as H

-- | dot product between two vectors
dot :: (KnownNat n, Reifies s W)
    => BVar s (R n)
    -> BVar s (R n)
    -> BVar s Double
dot = liftOp2 . op2 $ \u v ->
    ( u `H.dot` v
    , \dzdy -> (H.konst dzdy * v, u * H.konst dzdy)
    )

-- | matrix-vector multiplication
(#>) :: (KnownNat m, KnownNat n, Reifies s W)
     => BVar s (L m n)
     -> BVar s (R n)
     -> BVar s (R m)
(#>) = liftOp2 . op2 $ \mat vec ->
    ( mat H.#> vec
    , \dzdy -> (dzdy `H.outer` vec, H.tr mat H.#> dzdy)
    )
```
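As a rough check (a hypothetical GHCi session; vector `Show` formatting may
vary), the gradient of a dot product with respect to each vector is the other
vector:

```haskell
-- hypothetical GHCi session
-- >>> gradBP2 dot (H.vec2 1 2) (H.vec2 3 4)
-- the gradient w.r.t. the first vector is [3.0,4.0];
-- w.r.t. the second, [1.0,2.0]
```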
Possibilities
-------------
That's it for this introductory tutorial on lifting single operations. More
information on the ways to apply these techniques to fully equip your library
for backpropagation (including functions with multiple results, taking
advantage of isomorphisms, and providing non-gradient functions) can be [found
here][equipping]!