{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK show-extensions #-}

-- | Module    :  InfBackprop.Common
-- Copyright   :  (C) 2023 Alexey Tochin
-- License     :  BSD3 (see the file LICENSE)
-- Maintainer  :  Alexey Tochin <Alexey.Tochin@gmail.com>
--
-- Provides base types and methods for backpropagation category morphism.
module InfBackprop.Common
  ( -- * Basic
    Backprop (MkBackprop),
    call,
    forward,
    backward,
    StartBackprop,
    startBackprop,
    forwardBackward,
    numba,
    numbaN,
    derivative,
    derivativeN,

    -- * Differentiable functions
    BackpropFunc,
    const,

    -- * Differentiable monadic functions
    pureBackprop,
  )
where

import Control.Arrow (Kleisli (Kleisli))
import Control.CatBifunctor (CatBiFunctor, first, (***))
import Control.Category (Category, id, (.), (>>>))
import GHC.Natural (Natural)
import IsomorphismClass (IsomorphicTo)
import IsomorphismClass.Extra ()
import IsomorphismClass.Isomorphism (Isomorphism, iso)
import NumHask (one, zero)
import NumHask.Algebra.Additive (Additive)
import NumHask.Algebra.Ring (Distributive)
import NumHask.Extra ()
import Prelude (Monad, flip, fromIntegral, iterate, pure, (!!), ($))
import qualified Prelude as P

-- | Backprop morphism.
-- #backprop#
-- Base type for an infinitely differentiable object.
-- It depends on a categorical type @cat@ that is most commonly @(->)@;
-- see 'BackpropFunc', which by its definition is equivalent to
--
-- @
-- data BackpropFunc input output = forall cache. MkBackpropFunc {
--  call     :: input -> output,
--  forward  :: BackpropFunc input (output, cache),
--  backward :: BackpropFunc (output, cache) input
-- }
-- @
--
-- The diagram below illustrates how it works for the first derivative.
-- Consider the role of function @f@ in the derivative of the composition @g(f(h(...)))@.
-- #backprop_func#
--
-- @
--   h        ·                  f                   ·        g
--            ·                                      ·
--            ·               forward                ·
--            · --- input  >-----+-----> output >--- ·
--            ·                  V                   ·
--  ...       ·                  |                   ·       ...
--            ·                  | cache             ·
--            ·                  |                   ·
--            ·                  V                   ·
--            · --< dInput <-----+-----< dOutput <-- ·
--            ·               backward               ·
-- @
--
-- Notice that 'forward' and 'backward' are of type 'BackpropFunc' rather than @(->)@.
-- This is needed for further differentiation.
-- However, for the first derivative this difference can be ignored.
--
-- The return type of 'forward' contains an additional term @cache@.
-- It is needed to save and transfer data calculated in the forward step to the backward step for reuse.
-- See an example in the
--
-- [Differentiation with logging](#differentiation_with_logging)
--
-- section.
--
-- Because @cache@ is existentially quantified, 'forward' and 'backward' cannot be used as
-- ordinary selector functions; they are accessed by pattern matching on 'MkBackprop'.
--
-- == __Remark__
-- Mathematically speaking we have to distinguish the types for 'forward' and for 'backward' methods because the second
-- acts on the cotangent bundle.
-- However, for simplicity and due to technical reasons we identify the types @input@ and @dInput@
-- as well as @output@ and @dOutput@, which is enough for our purposes because these types are usually real numbers
-- or arrays of real numbers.
data Backprop cat input output = forall cache.
  MkBackprop
  { -- | Simple internal category object extraction.
    call :: cat input output,
    -- | Returns forward category.
    -- In the case @cat = (->)@, the method coincides with 'Backprop'@ cat input output@ itself
    -- but the output contains an additional data term @cache@ with some calculation result that can be reused in
    -- 'backward'.
    forward :: Backprop cat input (output, cache),
    -- | Returns backward category. In the case @cat = (->)@, the method takes the additional data term @cache@ that is
    -- calculated in 'forward'.
    backward :: Backprop cat (output, cache) input
  }

-- | Sequential composition of two backprop morphisms: first @f@ (@x -> y@), then @g@ (@y -> z@).
--
-- * The call parts are composed directly.
-- * The forward parts are chained; the two caches are collected into the pair @(hG, hF)@.
-- * The backward parts consume that pair in reverse order.
--
-- The forward and backward parts of the result are themselves built with 'composition'',
-- so they are again composable 'Backprop' morphisms.
composition' ::
  forall cat x y z.
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Backprop cat x y ->
  Backprop cat y z ->
  Backprop cat x z
composition'
  (MkBackprop callF (forwardF :: Backprop cat x (y, hF)) (backwardF :: Backprop cat (y, hF) x))
  (MkBackprop callG (forwardG :: Backprop cat y (z, hG)) (backwardG :: Backprop cat (z, hG) y)) =
    MkBackprop call_ forward_ backward_
    where
      -- Plain composition of the underlying morphisms.
      call_ :: cat x z
      call_ = callF >>> callG

      -- x ~> (y, hF) ~> ((z, hG), hF) ~> (z, (hG, hF))
      forward_ :: Backprop cat x (z, (hG, hF))
      forward_ =
        (forwardF `composition'` first forwardG)
          `composition'` (iso :: Backprop cat ((z, hG), hF) (z, (hG, hF)))

      -- (z, (hG, hF)) ~> ((z, hG), hF) ~> (y, hF) ~> x
      backward_ :: Backprop cat (z, (hG, hF)) x
      backward_ =
        ((iso :: Backprop cat (z, (hG, hF)) ((z, hG), hF)) `composition'` first backwardG)
          `composition'` backwardF

-- | Backprop morphism for an isomorphism between @x@ and @y@.
--
-- An isomorphism needs no cache, so the cache type is @()@. The forward and backward
-- parts are built from 'iso' at type 'Backprop', whose instance is defined via @iso'@.
iso' ::
  forall cat x y.
  (IsomorphicTo x y, Isomorphism cat, CatBiFunctor (,) cat) =>
  Backprop cat x y
iso' = MkBackprop call_ forward_ backward_
  where
    -- The underlying isomorphism in @cat@.
    call_ :: cat x y
    call_ = iso

    -- x ~> y ~> (y, ())
    forward_ :: Backprop cat x (y, ())
    forward_ = (iso :: Backprop cat x y) `composition'` (iso :: Backprop cat y (y, ()))

    -- (y, ()) ~> y ~> x
    backward_ :: Backprop cat (y, ()) x
    backward_ = (iso :: Backprop cat (y, ()) y) `composition'` (iso :: Backprop cat y x)

-- | Backprop morphisms form a category.
-- The identity is the trivial isomorphism @iso'@.
-- Composition is 'composition'' with its arguments flipped, to match the argument order of @(.)@.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Category (Backprop cat)
  where
  id = iso'
  (.) = flip composition'

-- | Every isomorphism @x <-> y@ lifts to a backprop morphism through @iso'@.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Isomorphism (Backprop cat)
  where
  iso = iso'

-- | Parallel (product) composition of backprop morphisms.
--
-- The underlying morphisms are combined with @(***)@. On the forward side, the two caches
-- are paired by regrouping @((y1, h1), (y2, h2))@ into @((y1, y2), (h1, h2))@. The backward
-- side undoes that regrouping before running both backward parts in parallel.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  CatBiFunctor (,) (Backprop cat)
  where
  (***)
    (MkBackprop call1 (forward1 :: Backprop cat x1 (y1, h1)) (backward1 :: Backprop cat (y1, h1) x1))
    (MkBackprop call2 (forward2 :: Backprop cat x2 (y2, h2)) (backward2 :: Backprop cat (y2, h2) x2)) =
      MkBackprop call12 forward12 backward12
      where
        call12 :: cat (x1, x2) (y1, y2)
        call12 = call1 *** call2

        -- (x1, x2) ~> ((y1, h1), (y2, h2)) ~> ((y1, y2), (h1, h2))
        forward12 :: Backprop cat (x1, x2) ((y1, y2), (h1, h2))
        forward12 =
          (forward1 *** forward2)
            >>> (iso :: Backprop cat ((y1, h1), (y2, h2)) ((y1, y2), (h1, h2)))

        -- ((y1, y2), (h1, h2)) ~> ((y1, h1), (y2, h2)) ~> (x1, x2)
        backward12 :: Backprop cat ((y1, y2), (h1, h2)) (x1, x2)
        backward12 =
          (iso :: Backprop cat ((y1, y2), (h1, h2)) ((y1, h1), (y2, h2)))
            >>> (backward1 *** backward2)

-- | Implementation of the process illustrated in the
-- [diagram](#backprop_func).
--
-- * The first argument is a backprop morphism @y -> dy@.
-- * The second argument is a backprop morphism @x -> y@.
-- * The output is the backprop morphism @x -> dx@, built according to the
--   [diagram](#backprop_func).
--
-- It runs the forward part of @f@, then applies @dy@ to the output while passing the cache
-- through unchanged, and finally runs the backward part of @f@.
forwardBackward ::
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  -- | backprop morphism between @y@ and @dy@
  -- that is inferred after the forward step for @f@ and before the backward step for @f@
  Backprop cat y y ->
  -- | some backprop morphism @f@ between @x@ and @y@
  Backprop cat x y ->
  -- | the output backprop morphism from @x@ to @dx@ that is the composition.
  Backprop cat x x
forwardBackward dy (MkBackprop _ forward_ backward_) =
  forward_ >>> first dy >>> backward_

-- | Categories @cat@ and value types @x@ for which backpropagation can be started.
--
-- For @(->)@ and @Float@, for instance, the chain @f(g(x))@ is started as
-- @1 · f'(g(x)) · ...@. This works because @f@ is @Float@-valued (scalar).
-- Jacobians (non-scalar outputs) are not currently implemented.
class (Distributive x) => StartBackprop cat x where
  -- | The morphism that joins the forward chain to the backward chain.
  -- It is typically the constant @1@, which multiplies the derivative of the outermost function.
  startBackprop :: Backprop cat x x

-- | Backpropagation derivative in terms of backprop morphisms.
-- The backward pass is seeded with 'startBackprop' for the output type @y@ (see 'forwardBackward').
numba ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat y) =>
  Backprop cat x y ->
  Backprop cat x x
numba = forwardBackward startBackprop

-- | Backpropagation derivative of order @n@ in terms of backprop morphisms.
-- It is 'numba' applied @n@ times; @n = 0@ returns the morphism unchanged.
numbaN ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat x) =>
  Natural ->
  Backprop cat x x ->
  Backprop cat x x
numbaN n f = iterate numba f !! fromIntegral n

-- | Backpropagation derivative as a categorical object.
-- It extracts the 'call' part of 'numba'.
-- If @cat@ is @(->)@, the output is simply a function.
--
-- ==== __Examples of usage__
--
-- >>> import InfBackprop (sin)
-- >>> import Prelude (Float)
-- >>> derivative sin (0 :: Float)
-- 1.0
derivative ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat y) =>
  Backprop cat x y ->
  cat x x
derivative = call . numba

-- | Backpropagation derivative of order @n@ as a categorical object.
-- It extracts the 'call' part of 'numbaN'.
-- If @cat@ is @(->)@, the output is simply a function.
--
-- ==== __Examples of usage__
--
-- >>> import InfBackprop (pow, const)
-- >>> import Prelude (Float, fmap)
-- >>> myFunc = (pow 2) :: Backprop (->) Float Float
--
-- >>> fmap (derivativeN 0 myFunc) [-3, -2, -1, 0, 1, 2, 3]
-- [9.0,4.0,1.0,0.0,1.0,4.0,9.0]
--
-- >>> fmap (derivativeN 1 myFunc) [-3, -2, -1, 0, 1, 2, 3]
-- [-6.0,-4.0,-2.0,0.0,2.0,4.0,6.0]
--
-- >>> fmap (derivativeN 2 myFunc) [-3, -2, -1, 0, 1, 2, 3]
-- [2.0,2.0,2.0,2.0,2.0,2.0,2.0]
--
-- >>> fmap (derivativeN 3 myFunc) [-3, -2, -1, 0, 1, 2, 3]
-- [0.0,0.0,0.0,0.0,0.0,0.0,0.0]
derivativeN ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat x) =>
  Natural ->
  Backprop cat x x ->
  cat x x
derivativeN n = call . numbaN n

-- | Infinitely differentiable function: a 'Backprop' morphism over plain functions @(->)@.
-- The definition of the type synonym is equivalent to
--
-- @
-- data BackpropFunc input output = forall cache. MkBackpropFunc {
--    call     :: input -> output,
--    forward  :: BackpropFunc input (output, cache),
--    backward :: BackpropFunc (output, cache) input
--  }
-- @
--
-- See 'Backprop' for details.
-- The synonym is eta-reduced (it takes no type parameters itself), so it may also appear partially applied.
--
-- ==== __Examples of usage__
--
-- >>> import Prelude (fmap, Float)
-- >>> import InfBackprop (pow, call, derivative)
-- >>> myFunc = pow 2 :: BackpropFunc Float Float
-- >>> f = call myFunc :: Float -> Float
-- >>> fmap f [-3, -2, -1, 0, 1, 2, 3]
-- [9.0,4.0,1.0,0.0,1.0,4.0,9.0]
-- >>> df = derivative myFunc :: Float -> Float
-- >>> fmap df [-3, -2, -1, 0, 1, 2, 3]
-- [-6.0,-4.0,-2.0,0.0,2.0,4.0,6.0]
type BackpropFunc = Backprop (->)

-- | For plain functions, the backward pass is seeded with the constant 'one'.
instance forall x. (Distributive x) => StartBackprop (->) x where
  startBackprop = const one

-- | Infinitely differentiable constant function.
-- It ignores its input, carries no cache, and its backward part always yields 'zero'.
--
-- === __Examples of usage__
--
-- >>> import Prelude (Float)
-- >>> import InfBackprop (call, derivative, derivativeN)
--
-- >>> call (const 5) ()
-- 5
--
-- >>> derivative (const (5 :: Float)) 42
-- 0
--
-- >>> derivativeN 2 (const (5 :: Float)) 42
-- 0.0
const ::
  forall c x.
  (Additive c, Additive x) =>
  c ->
  BackpropFunc x c
const c = MkBackprop call' forward' backward'
  where
    -- Ordinary constant function.
    call' :: x -> c
    call' = P.const c

    -- Forward pass: the same constant, paired with an empty cache.
    forward' :: BackpropFunc x (c, ())
    forward' = const c >>> (iso :: BackpropFunc c (c, ()))

    -- Backward pass: drop the empty cache; the gradient of a constant is zero.
    backward' :: BackpropFunc (c, ()) x
    backward' = (iso :: BackpropFunc (c, ()) c) >>> const zero

-- | Lifts a backprop function morphism to the corresponding pure Kleisli morphism.
--
-- The call part is wrapped as @'pure' . f@. The forward and backward parts are lifted
-- recursively in the same way, so no monadic effects are introduced at any order.
pureBackprop :: forall a b m. Monad m => Backprop (->) a b -> Backprop (Kleisli m) a b
pureBackprop
  ( MkBackprop
      (call'' :: a -> b)
      (forward'' :: Backprop (->) a (b, c))
      (backward'' :: Backprop (->) (b, c) a)
    ) =
    MkBackprop call' forward' backward'
    where
      call' :: Kleisli m a b
      call' = Kleisli $ pure . call''

      forward' :: Backprop (Kleisli m) a (b, c)
      forward' = pureBackprop forward''

      backward' :: Backprop (Kleisli m) (b, c) a
      backward' = pureBackprop backward''

-- | For Kleisli morphisms, the seed is the pure lift ('pureBackprop') of the @(->)@ seed.
instance (Distributive x, Monad m) => StartBackprop (Kleisli m) x where
  startBackprop = pureBackprop startBackprop