{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module InfBackprop.Common
(
Backprop (MkBackprop),
call,
forward,
backward,
StartBackprop,
startBackprop,
forwardBackward,
numba,
numbaN,
derivative,
derivativeN,
BackpropFunc,
const,
pureBackprop,
)
where
import Control.Arrow (Kleisli (Kleisli))
import Control.CatBifunctor (CatBiFunctor, first, (***))
import Control.Category (Category, id, (.), (>>>))
import GHC.Natural (Natural)
import IsomorphismClass (IsomorphicTo)
import IsomorphismClass.Extra ()
import IsomorphismClass.Isomorphism (Isomorphism, iso)
import NumHask (one, zero)
import NumHask.Algebra.Additive (Additive)
import NumHask.Algebra.Ring (Distributive)
import NumHask.Extra ()
import Prelude (Monad, flip, fromIntegral, iterate, pure, (!!), ($))
import qualified Prelude as P
-- | A morphism of the category @cat@ from @input@ to @output@, bundled with
-- two further 'Backprop' values that share a hidden @cache@ type.
--
-- The @cache@ type is existentially quantified: it is chosen when a value is
-- built with 'MkBackprop' and is only visible again by pattern matching on
-- the constructor. Because 'forward' and 'backward' mention @cache@, they
-- can be used as field labels in construction and pattern matching but not
-- as ordinary selector functions.
--
-- 'forward' and 'backward' are themselves 'Backprop' values, so the structure
-- nests to arbitrary depth.
data Backprop cat input output = forall cache.
  MkBackprop
  { -- | The plain @cat@ morphism from @input@ to @output@.
    call :: cat input output,
    -- | Maps @input@ to the @output@ paired with a value of the hidden
    -- @cache@ type.
    forward :: Backprop cat input (output, cache),
    -- | Maps an @(output, cache)@ pair back to a value of type @input@.
    backward :: Backprop cat (output, cache) input
  }
-- | Left-to-right composition of two 'Backprop' values: first the one from
-- @x@ to @y@, then the one from @y@ to @z@.
--
-- * The resulting 'call' is the @cat@ composition of the two 'call's.
-- * The resulting 'forward' runs the first 'forward', then the second
--   'forward' on the first component, and stores both caches as the pair
--   @(hG, hF)@.
-- * The resulting 'backward' re-associates the tuple, runs the second
--   'backward' on the @(z, hG)@ part and then the first 'backward'.
--
-- The definition is recursive: 'forward' and 'backward' are assembled with
-- @composition'@ itself, and with 'first' and 'iso' taken at type
-- @'Backprop' cat@.
composition' ::
  forall cat x y z.
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Backprop cat x y ->
  Backprop cat y z ->
  Backprop cat x z
composition'
  -- The pattern signatures give names (hF, hG) to the hidden cache types so
  -- that the local signatures below can mention them.
  (MkBackprop callF (forwardF :: Backprop cat x (y, hF)) (backwardF :: Backprop cat (y, hF) x))
  (MkBackprop callG (forwardG :: Backprop cat y (z, hG)) (backwardG :: Backprop cat (z, hG) y)) =
    MkBackprop call_ forward_ backward_
    where
      call_ :: cat x z
      call_ = callF >>> callG

      forward_ :: Backprop cat x (z, (hG, hF))
      forward_ =
        (forwardF `composition'` first forwardG)
          `composition'` (iso :: Backprop cat ((z, hG), hF) (z, (hG, hF)))

      backward_ :: Backprop cat (z, (hG, hF)) x
      backward_ =
        ( (iso :: Backprop cat (z, (hG, hF)) ((z, hG), hF))
            `composition'` first backwardG
        )
          `composition'` backwardF
-- | Lifts an isomorphism between @x@ and @y@ in @cat@ to a 'Backprop'.
--
-- The cache type is @()@. 'call' is 'iso' in @cat@; 'forward' and 'backward'
-- are built by composing 'iso' values taken at type @'Backprop' cat@, which
-- add or drop the unit cache.
iso' ::
  forall cat x y.
  (IsomorphicTo x y, Isomorphism cat, CatBiFunctor (,) cat) =>
  Backprop cat x y
iso' =
  MkBackprop
    call_
    (forward_ :: Backprop cat x (y, ()))
    (backward_ :: Backprop cat (y, ()) x)
  where
    call_ :: cat x y
    call_ = iso

    -- x -> y, then attach the unit cache.
    forward_ :: Backprop cat x (y, ())
    forward_ =
      (iso :: Backprop cat x y)
        `composition'` (iso :: Backprop cat y (y, ()))

    -- Drop the unit cache, then y -> x.
    backward_ :: Backprop cat (y, ()) x
    backward_ =
      (iso :: Backprop cat (y, ()) y)
        `composition'` (iso :: Backprop cat y x)
-- | 'Backprop' over @cat@ forms a category: the identity is the trivial
-- isomorphism 'iso'' and composition is 'composition'' with its arguments
-- swapped (because '.' composes right-to-left).
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Category (Backprop cat)
  where
  -- id :: Backprop cat a a
  id = iso'

  -- (.) :: Backprop cat b c -> Backprop cat a b -> Backprop cat a c
  (.) = flip composition'
-- | Every isomorphism available in @cat@ is also available in
-- @'Backprop' cat@, via 'iso''.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  Isomorphism (Backprop cat)
  where
  -- iso :: Backprop cat a b
  iso = iso'
-- | Product of two 'Backprop' values acting component-wise on pairs.
--
-- * 'call' is the @cat@-level product of both 'call's.
-- * 'forward' runs both 'forward's side by side and then regroups
--   @((y1, h1), (y2, h2))@ into @((y1, y2), (h1, h2))@, so the combined cache
--   is the pair of caches.
-- * 'backward' undoes that regrouping and runs both 'backward's side by side.
instance
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  CatBiFunctor (,) (Backprop cat)
  where
  -- (***) :: Backprop cat a1 b1 -> Backprop cat a2 b2 -> Backprop cat (a1, a2) (b1, b2)
  (***)
    -- The pattern signatures name the hidden cache types (h1, h2) so the
    -- local signatures below can refer to them.
    (MkBackprop call1 (forward1 :: Backprop cat x1 (y1, h1)) (backward1 :: Backprop cat (y1, h1) x1))
    (MkBackprop call2 (forward2 :: Backprop cat x2 (y2, h2)) (backward2 :: Backprop cat (y2, h2) x2)) =
      MkBackprop call12 forward12 backward12
      where
        call12 :: cat (x1, x2) (y1, y2)
        call12 = call1 *** call2

        -- Parentheses make the grouping explicit instead of relying on the
        -- relative fixity of (***) and (>>>).
        forward12 :: Backprop cat (x1, x2) ((y1, y2), (h1, h2))
        forward12 =
          (forward1 *** forward2)
            >>> (iso :: Backprop cat ((y1, h1), (y2, h2)) ((y1, y2), (h1, h2)))

        backward12 :: Backprop cat ((y1, y2), (h1, h2)) (x1, x2)
        backward12 =
          (iso :: Backprop cat ((y1, y2), (h1, h2)) ((y1, h1), (y2, h2)))
            >>> (backward1 *** backward2)
-- | Chains the 'forward' part of a 'Backprop', a transformation @dy@ of the
-- output, and the 'backward' part into a single 'Backprop' from @x@ to @x@.
--
-- @dy@ is applied to the output component only (via 'first'); the cache
-- produced by 'forward' is passed unchanged to 'backward'. The 'call' field
-- of the second argument is not used.
forwardBackward ::
  (Isomorphism cat, CatBiFunctor (,) cat) =>
  -- | Transformation applied to the output between the two passes.
  Backprop cat y y ->
  -- | The 'Backprop' whose 'forward' and 'backward' are chained.
  Backprop cat x y ->
  Backprop cat x x
forwardBackward dy (MkBackprop _ forward_ backward_) =
  forward_ >>> first dy >>> backward_
-- | Types @x@ for which the category @cat@ provides a 'Backprop' from @x@ to
-- @x@ that 'numba' plugs in between the forward and backward passes.
class Distributive x => StartBackprop cat x where
  -- | The 'Backprop' inserted between 'forward' and 'backward' by 'numba'.
  startBackprop :: Backprop cat x x
-- | Turns a 'Backprop' from @x@ to @y@ into one from @x@ to @x@ by running
-- 'forwardBackward' with 'startBackprop' as the transformation applied to
-- the output between the forward and backward passes.
numba ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat y) =>
  Backprop cat x y ->
  Backprop cat x x
numba = forwardBackward startBackprop
-- | Applies 'numba' @n@ times to the given 'Backprop'; @numbaN 0 f@ is @f@.
--
-- The count is converted from 'Natural' to 'P.Int' for list indexing, so it
-- must fit into an 'P.Int'. The indexed list is produced by 'iterate' and is
-- therefore infinite, so the index itself can never be out of range.
numbaN ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat x) =>
  -- | Number of times 'numba' is applied.
  Natural ->
  Backprop cat x x ->
  Backprop cat x x
numbaN n f = iterate numba f !! fromIntegral n
-- | The underlying @cat@ morphism of 'numba' applied to the argument, i.e.
-- the 'call' field of the 'Backprop' that 'numba' builds.
derivative ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat y) =>
  Backprop cat x y ->
  cat x x
derivative = call . numba
-- | The underlying @cat@ morphism of 'numbaN' @n@ applied to the argument,
-- i.e. the 'call' field after 'numba' has been applied @n@ times.
derivativeN ::
  (Isomorphism cat, CatBiFunctor (,) cat, StartBackprop cat x) =>
  -- | Number of times 'numba' is applied before extracting 'call'.
  Natural ->
  Backprop cat x x ->
  cat x x
derivativeN n = call . numbaN n
-- | 'Backprop' specialised to the category of ordinary Haskell functions.
type BackpropFunc = Backprop (->)
-- | For ordinary functions, 'startBackprop' is the constant 'Backprop' that
-- always returns 'one'.
instance forall x. (Distributive x) => StartBackprop (->) x where
  -- startBackprop :: Backprop (->) x x
  startBackprop = const one
-- | A 'BackpropFunc' that ignores its input and always yields @c@.
--
-- * 'call' is the Prelude constant function, referenced as @P.const@ because
--   this definition takes over the unqualified name.
-- * 'forward' is @const c@ followed by the isomorphism that attaches a unit
--   cache.
-- * 'backward' drops the unit cache and then yields 'zero' regardless of its
--   input.
--
-- The definition refers to itself inside 'forward' and 'backward'; this is
-- productive only because those fields are evaluated lazily.
const ::
  forall c x.
  (Additive c, Additive x) =>
  c ->
  BackpropFunc x c
const c = MkBackprop call' forward' backward'
  where
    call' :: x -> c
    call' = P.const c

    forward' :: BackpropFunc x (c, ())
    forward' = const c >>> (iso :: BackpropFunc c (c, ()))

    backward' :: BackpropFunc (c, ()) x
    backward' = (iso :: BackpropFunc (c, ()) c) >>> const zero
-- | Lifts a 'Backprop' over ordinary functions into a 'Backprop' over the
-- Kleisli category of a monad @m@.
--
-- 'call' is wrapped with 'pure'; 'forward' and 'backward' are lifted by
-- recursive calls to @pureBackprop@, so the whole nested structure is
-- converted on demand.
pureBackprop :: forall a b m. Monad m => Backprop (->) a b -> Backprop (Kleisli m) a b
pureBackprop
  -- The pattern signatures name the hidden cache type (c) so the local
  -- signatures below can refer to it.
  ( MkBackprop
      (call'' :: a -> b)
      (forward'' :: Backprop (->) a (b, c))
      (backward'' :: Backprop (->) (b, c) a)
    ) =
    MkBackprop call' forward' backward'
    where
      call' :: Kleisli m a b
      call' = Kleisli (pure . call'')

      forward' :: Backprop (Kleisli m) a (b, c)
      forward' = pureBackprop forward''

      backward' :: Backprop (Kleisli m) (b, c) a
      backward' = pureBackprop backward''
-- | For Kleisli arrows, 'startBackprop' is the function-level
-- 'startBackprop' lifted with 'pureBackprop'.
instance (Distributive x, Monad m) => StartBackprop (Kleisli m) x where
  -- startBackprop :: Backprop (Kleisli m) x x
  -- The inner 'startBackprop' is used at type @Backprop (->) x x@, as
  -- required by the argument type of 'pureBackprop'.
  startBackprop = pureBackprop startBackprop