{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DefaultSignatures #-} {-# LANGUAGE EmptyCase #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -- | -- Module : Numeric.Backprop.Class -- Copyright : (c) Justin Le 2018 -- License : BSD3 -- -- Maintainer : justin@jle.im -- Stability : experimental -- Portability : non-portable -- -- Provides the 'Backprop' typeclass, a class for values that can be used -- for backpropagation. -- -- This class replaces the old (version 0.1) API relying on 'Num'. -- -- @since 0.2.0.0 module Numeric.Backprop.Class ( -- * Backpropagatable types Backprop(..) -- * Derived methods , zeroNum, addNum, oneNum , zeroVec, addVec, oneVec , zeroFunctor, addIsList, addAsList, oneFunctor , genericZero, genericAdd, genericOne -- * Generics , GZero(..), GAdd(..), GOne(..) ) where import Data.Complex import Data.Foldable hiding (toList) import Data.Functor.Identity import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Data.Ratio import Data.Type.Combinator hiding ((:.:), Comp1) import Data.Type.Option import Data.Type.Product hiding (toList) import Data.Void import GHC.Exts import GHC.Generics import Type.Family.List import qualified Data.IntMap as IM import qualified Data.Map as M import qualified Data.Sequence as Seq import qualified Data.Vector as V import qualified Data.Vector.Generic as VG import qualified Data.Vector.Primitive as VP import qualified Data.Vector.Storable as VS import qualified Data.Vector.Unboxed as VU import qualified Type.Family.Maybe as M -- | Class of values that can be backpropagated in general. -- -- For instances of 'Num', these methods can be given by 'zeroNum', -- 'addNum', and 'oneNum'. There are also generic options given in -- "Numeric.Backprop.Class" for functors, 'IsList' instances, and 'Generic' -- instances. 
--
-- @
-- instance 'Backprop' 'Double' where
--     'zero' = 'zeroNum'
--     'add' = 'addNum'
--     'one' = 'oneNum'
-- @
--
-- If you leave the body of an instance declaration blank, GHC Generics
-- will be used to derive instances if the type has a single constructor
-- and each field is an instance of 'Backprop'.
--
-- To ensure that backpropagation works in a sound way, instances should
-- obey the laws:
--
-- [/identity/]
--
--   * @'add' x ('zero' y) = x@
--
--   * @'add' ('zero' x) y = y@
--
--   Also implies preservation of information, making @'zipWith' ('+')@ an
--   illegal implementation for lists and vectors.
--
--   This is only expected to be true up to potential "extra zeroes" in @x@
--   and @y@ in the result.
--
-- [/commutativity/]
--
--   * @'add' x y = 'add' y x@
--
-- [/associativity/]
--
--   * @'add' x ('add' y z) = 'add' ('add' x y) z@
--
-- [/idempotence/]
--
--   * @'zero' '.' 'zero' = 'zero'@
--
--   * @'one' '.' 'one' = 'one'@
--
-- Note that not all values in the backpropagation process need all of
-- these methods: Only the "final result" needs 'one', for example.  These
-- are all grouped under one typeclass for convenience in defining
-- instances, and also to talk about sensible laws.  For fine-grained
-- control, use the "explicit" versions of library functions (for example,
-- in "Numeric.Backprop.Explicit") instead of 'Backprop' based ones.
--
-- This typeclass replaces the reliance on 'Num' of the previous API
-- (v0.1).  'Num' is strictly more powerful than 'Backprop', and is
-- a stronger constraint on types than is necessary for proper
-- backpropagating.  In particular, 'fromInteger' is a problem for many
-- types, preventing useful backpropagation for lists, variable-length
-- vectors (like "Data.Vector") and variable-size matrices from linear
-- algebra libraries like /hmatrix/ and /accelerate/.
--
-- @since 0.2.0.0
class Backprop a where
    -- | "Zero out" all components of a value.  For scalar values, this
    -- should just be @'const' 0@.  For vectors and matrices, this should
    -- set all components to zero, the additive identity.
    --
    -- Should be idempotent:
    --
    --   * @'zero' '.' 'zero' = 'zero'@
    --
    -- Should be as /lazy/ as possible.  This behavior is observed for
    -- all instances provided by this library.
    --
    -- See 'zeroNum' for a pre-built definition for instances of 'Num' and
    -- 'zeroFunctor' for a definition for instances of 'Functor'.  If left
    -- blank, will automatically be 'genericZero', a pre-built definition
    -- for instances of 'Generic' whose fields are all themselves
    -- instances of 'Backprop'.
    zero :: a -> a
    -- | Add together two values of a type.  To combine contributions of
    -- gradients, so should be information-preserving:
    --
    --   * @'add' x ('zero' y) = x@
    --
    --   * @'add' ('zero' x) y = y@
    --
    -- Should be as /strict/ as possible.  This behavior is observed for
    -- all instances provided by this library.
    --
    -- See 'addNum' for a pre-built definition for instances of 'Num' and
    -- 'addIsList' for a definition for instances of 'IsList'.  If left
    -- blank, will automatically be 'genericAdd', a pre-built definition
    -- for instances of 'Generic' with one constructor whose fields are
    -- all themselves instances of 'Backprop'.
    add :: a -> a -> a
    -- | "One" all components of a value.  For scalar values, this should
    -- just be @'const' 1@.  For vectors and matrices, this should set all
    -- components to one, the multiplicative identity.
    --
    -- Should be idempotent:
    --
    --   * @'one' '.' 'one' = 'one'@
    --
    -- Should be as /lazy/ as possible.  This behavior is observed for
    -- all instances provided by this library.
    --
    -- See 'oneNum' for a pre-built definition for instances of 'Num' and
    -- 'oneFunctor' for a definition for instances of 'Functor'.  If left
    -- blank, will automatically be 'genericOne', a pre-built definition
    -- for instances of 'Generic' whose fields are all themselves
    -- instances of 'Backprop'.
    one :: a -> a

    default zero :: (Generic a, GZero (Rep a)) => a -> a
    zero = genericZero
    {-# INLINE zero #-}
    default add :: (Generic a, GAdd (Rep a)) => a -> a -> a
    add = genericAdd
    {-# INLINE add #-}
    default one :: (Generic a, GOne (Rep a)) => a -> a
    one = genericOne
    {-# INLINE one #-}

-- | 'zero' using GHC Generics; works if all fields are instances of
-- 'Backprop'.
genericZero :: (Generic a, GZero (Rep a)) => a -> a
genericZero = to . gzero . from
{-# INLINE genericZero #-}

-- | 'add' using GHC Generics; works if all fields are instances of
-- 'Backprop', but only for values with single constructors.
genericAdd :: (Generic a, GAdd (Rep a)) => a -> a -> a
genericAdd x y = to $ gadd (from x) (from y)
{-# INLINE genericAdd #-}

-- | 'one' using GHC Generics; works if all fields are instances of
-- 'Backprop'.
genericOne :: (Generic a, GOne (Rep a)) => a -> a
genericOne = to . gone . from
{-# INLINE genericOne #-}

-- | 'zero' for instances of 'Num'.
--
-- Is lazy in its argument.
zeroNum :: Num a => a -> a
zeroNum _ = 0
{-# INLINE zeroNum #-}

-- | 'add' for instances of 'Num'.
addNum :: Num a => a -> a -> a
addNum = (+)
{-# INLINE addNum #-}

-- | 'one' for instances of 'Num'.
--
-- Is lazy in its argument.
oneNum :: Num a => a -> a
oneNum _ = 1
{-# INLINE oneNum #-}

-- | 'zero' for instances of 'VG.Vector'.
zeroVec :: (VG.Vector v a, Backprop a) => v a -> v a
zeroVec = VG.map zero
{-# INLINE zeroVec #-}

-- | 'add' for instances of 'VG.Vector'.  Automatically pads the end of the
-- shorter vector with zeroes: the overlapping prefix is combined
-- element-wise with 'add', and the remaining tail of the longer vector is
-- appended unchanged, so the result has the length of the longer input.
addVec :: (VG.Vector v a, Backprop a) => v a -> v a -> v a
addVec x y = case compare lX lY of
    -- The longer vector must be split at the length of the /shorter/ one
    -- (not at the length difference), so that the prefix lines up
    -- element-for-element with the shorter vector and no component of
    -- either input is dropped.
    LT -> let (y1,y2) = VG.splitAt lX y
          in  VG.zipWith add x y1 VG.++ y2
    EQ -> VG.zipWith add x y
    GT -> let (x1,x2) = VG.splitAt lY x
          in  VG.zipWith add x1 y VG.++ x2
  where
    lX = VG.length x
    lY = VG.length y

-- | 'one' for instances of 'VG.Vector'.
oneVec :: (VG.Vector v a, Backprop a) => v a -> v a
oneVec = VG.map one
{-# INLINE oneVec #-}

-- | 'zero' for 'Functor' instances: maps 'zero' over every element.
zeroFunctor :: (Functor f, Backprop a) => f a -> f a
zeroFunctor = (zero <$>)
{-# INLINE zeroFunctor #-}

-- | 'add' for instances of 'IsList'.  Automatically pads the end of the
-- "shorter" value with zeroes.
--
-- Implemented as 'addAsList' using the 'IsList' conversions.
addIsList :: (IsList a, Backprop (Item a)) => a -> a -> a
addIsList = addAsList toList fromList
{-# INLINE addIsList #-}

-- | 'add' for types that are isomorphic to a list.
-- Automatically pads the end of the "shorter" value with zeroes.
--
-- Both inputs are converted to lists, merged element-wise with 'add', and
-- the merged list is converted back.
addAsList
    :: Backprop b
    => (a -> [b])       -- ^ convert to list (should form isomorphism)
    -> ([b] -> a)       -- ^ convert from list (should form isomorphism)
    -> a
    -> a
    -> a
addAsList f g x y = g (merge (f x) (f y))
  where
    -- Once either side runs out, the remainder of the other side is kept
    -- as-is; while both sides have elements, they are combined pairwise.
    merge []       rest     = rest
    merge rest     []       = rest
    merge (l : ls) (r : rs) = add l r : merge ls rs

-- | 'one' for instances of 'Functor': maps 'one' over every element.
oneFunctor :: (Functor f, Backprop a) => f a -> f a
oneFunctor = (one <$>)
{-# INLINE oneFunctor #-}

-- | Helper class for automatically deriving 'zero' using GHC Generics.
class GZero f where
    gzero :: f t -> f t

-- A single field: defer to the field's own 'Backprop' instance.
instance Backprop a => GZero (K1 i a) where
    gzero (K1 field) = K1 $ zero field
    {-# INLINE gzero #-}

-- A product of fields: zero each side independently.
instance (GZero f, GZero g) => GZero (f :*: g) where
    gzero (l :*: r) = gzero l :*: gzero r
    {-# INLINE gzero #-}

-- A sum of constructors: zero whichever alternative is present.
instance (GZero f, GZero g) => GZero (f :+: g) where
    gzero = \case
      L1 l -> L1 (gzero l)
      R1 r -> R1 (gzero r)
    {-# INLINE gzero #-}

-- No constructors: there is no value to handle.
instance GZero V1 where
    gzero v = case v of {}
    {-# INLINE gzero #-}

-- A constructor without fields.
instance GZero U1 where
    gzero = const U1
    {-# INLINE gzero #-}

-- Metadata wrapper: look through it.
instance GZero f => GZero (M1 i c f) where
    gzero (M1 inner) = M1 $ gzero inner
    {-# INLINE gzero #-}

-- Composition: handled through the outer functor's 'GZero' instance.
instance GZero f => GZero (f :.: g) where
    gzero (Comp1 inner) = Comp1 $ gzero inner
    {-# INLINE gzero #-}

-- | Helper class for automatically deriving 'add' using GHC Generics.
class GAdd f where
    gadd :: f t -> f t -> f t

-- A single field: defer to the field's own 'Backprop' instance.
instance Backprop a => GAdd (K1 i a) where
    gadd (K1 l) (K1 r) = K1 $ add l r
    {-# INLINE gadd #-}

-- A product of fields: both halves are forced before the pair is built.
instance (GAdd f, GAdd g) => GAdd (f :*: g) where
    gadd (l1 :*: r1) (l2 :*: r2) =
        let !l = gadd l1 l2
            !r = gadd r1 r2
        in  l :*: r
    {-# INLINE gadd #-}

-- No constructors: there is no value to handle.
instance GAdd V1 where
    gadd = \case {}
    {-# INLINE gadd #-}

-- A constructor without fields.
instance GAdd U1 where
    gadd _ _ = U1
    {-# INLINE gadd #-}

-- Metadata wrapper: look through it.
instance GAdd f => GAdd (M1 i c f) where
    gadd (M1 l) (M1 r) = M1 $ gadd l r
    {-# INLINE gadd #-}

-- Composition: handled through the outer functor's 'GAdd' instance.
instance GAdd f => GAdd (f :.: g) where
    gadd (Comp1 l) (Comp1 r) = Comp1 $ gadd l r
    {-# INLINE gadd #-}

-- | Helper class for automatically deriving 'one' using GHC Generics.
class GOne f where
    gone :: f t -> f t

-- A single field: defer to the field's own 'Backprop' instance.
instance Backprop a => GOne (K1 i a) where
    gone (K1 field) = K1 $ one field
    {-# INLINE gone #-}

-- A product of fields: handle each side independently.
instance (GOne f, GOne g) => GOne (f :*: g) where
    gone (l :*: r) = gone l :*: gone r
    {-# INLINE gone #-}

-- A sum of constructors: handle whichever alternative is present.
instance (GOne f, GOne g) => GOne (f :+: g) where
    gone = \case
      L1 l -> L1 (gone l)
      R1 r -> R1 (gone r)
    {-# INLINE gone #-}

-- No constructors: there is no value to handle.
instance GOne V1 where
    gone v = case v of {}
    {-# INLINE gone #-}

-- A constructor without fields.
instance GOne U1 where
    gone = const U1
    {-# INLINE gone #-}

-- Metadata wrapper: look through it.
instance GOne f => GOne (M1 i c f) where
    gone (M1 inner) = M1 $ gone inner
    {-# INLINE gone #-}

-- Composition: handled through the outer functor's 'GOne' instance.
instance GOne f => GOne (f :.: g) where
    gone (Comp1 inner) = Comp1 $ gone inner
    {-# INLINE gone #-}

-- Numeric types: all three methods come from the 'Num' instance.

instance Backprop Int where
    zero = zeroNum
    add  = addNum
    one  = oneNum
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance Backprop Integer where
    zero = zeroNum
    add  = addNum
    one  = oneNum
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance Integral a => Backprop (Ratio a) where
    zero = zeroNum
    add  = addNum
    one  = oneNum
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance RealFloat a => Backprop (Complex a) where
    zero = zeroNum
    add  = addNum
    one  = oneNum
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance Backprop Float where
    zero = zeroNum
    add  = addNum
    one  = oneNum
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance Backprop Double where
    zero = zeroNum
    add  = addNum
    one  = oneNum
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- Vector types: all three methods come from the generic vector helpers.

instance Backprop a => Backprop (V.Vector a) where
    zero = zeroVec
    add  = addVec
    one  = oneVec
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance (VU.Unbox a, Backprop a) => Backprop (VU.Vector a) where
    zero = zeroVec
    add  = addVec
    one  = oneVec
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance (VS.Storable a, Backprop a) => Backprop (VS.Vector a) where
    zero = zeroVec
    add  = addVec
    one  = oneVec
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

instance (VP.Prim a, Backprop a) => Backprop (VP.Vector a) where
    zero = zeroVec
    add  = addVec
    one  = oneVec
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' assumes the shorter list has trailing zeroes, and the result has
-- the length of the longest input.
instance Backprop a => Backprop [a] where
    zero = zeroFunctor
    add  = addIsList
    one  = oneFunctor
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' assumes the shorter list has trailing zeroes, and the result has
-- the length of the longest input.
instance Backprop a => Backprop (NonEmpty a) where
    zero = zeroFunctor
    add  = addIsList
    one  = oneFunctor
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' assumes the shorter sequence has trailing zeroes, and the result
-- has the length of the longest input.
instance Backprop a => Backprop (Seq.Seq a) where
    zero = zeroFunctor
    add  = addIsList
    one  = oneFunctor
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'Nothing' is treated the same as @'Just' 0@.  However, 'zero', 'add',
-- and 'one' preserve 'Nothing' if all inputs are also 'Nothing'.
instance Backprop a => Backprop (Maybe a) where
    zero = zeroFunctor
    -- A missing side contributes nothing; two present values are combined.
    add Nothing  my       = my
    add mx       Nothing  = mx
    add (Just x) (Just y) = Just (add x y)
    one = oneFunctor
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' is strict, but 'zero' and 'one' are lazy in their arguments.
instance Backprop () where
    zero _ = ()
    add () () = ()
    one _ = ()

-- | 'add' is strict
instance (Backprop a, Backprop b) => Backprop (a, b) where
    zero (a, b) = (zero a, zero b)
    -- Each component sum is forced before the tuple is built.
    add (a1, b1) (a2, b2) =
        let !a = add a1 a2
            !b = add b1 b2
        in  (a, b)
    one (a, b) = (one a, one b)
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' is strict
instance (Backprop a, Backprop b, Backprop c) => Backprop (a, b, c) where
    zero (a, b, c) = (zero a, zero b, zero c)
    -- Each component sum is forced before the tuple is built.
    add (a1, b1, c1) (a2, b2, c2) =
        let !a = add a1 a2
            !b = add b1 b2
            !c = add c1 c2
        in  (a, b, c)
    one (a, b, c) = (one a, one b, one c)
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' is strict
instance (Backprop a, Backprop b, Backprop c, Backprop d) => Backprop (a, b, c, d) where
    zero (a, b, c, d) = (zero a, zero b, zero c, zero d)
    -- Each component sum is forced before the tuple is built.
    add (a1, b1, c1, d1) (a2, b2, c2, d2) =
        let !a = add a1 a2
            !b = add b1 b2
            !c = add c1 c2
            !d = add d1 d2
        in  (a, b, c, d)
    one (a, b, c, d) = (one a, one b, one c, one d)
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' is strict
instance (Backprop a, Backprop b, Backprop c, Backprop d, Backprop e) => Backprop (a, b, c, d, e) where
    zero (a, b, c, d, e) = (zero a, zero b, zero c, zero d, zero e)
    -- Each component sum is forced before the tuple is built.
    add (a1, b1, c1, d1, e1) (a2, b2, c2, d2, e2) =
        let !a = add a1 a2
            !b = add b1 b2
            !c = add c1 c2
            !d = add d1 d2
            !e = add e1 e2
        in  (a, b, c, d, e)
    one (a, b, c, d, e) = (one a, one b, one c, one d, one e)
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- Wrapper around a single value: every method acts on the wrapped value.
instance Backprop a => Backprop (Identity a) where
    zero (Identity inner) = Identity $ zero inner
    add (Identity l) (Identity r) = Identity $ add l r
    one (Identity inner) = Identity $ one inner
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- Wrapper around a single value: every method acts on the wrapped value.
instance Backprop a => Backprop (I a) where
    zero (I inner) = I $ zero inner
    add (I l) (I r) = I $ add l r
    one (I inner) = I $ one inner
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'add' is strict, but 'zero' and 'one' are lazy in their arguments.
instance Backprop (Proxy a) where
    zero _ = Proxy
    add Proxy Proxy = Proxy
    one _ = Proxy
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- 'Void' has no constructors, so every method is an empty case analysis
-- on its (first) argument.
instance Backprop Void where
    zero = absurd
    add  = absurd
    one  = absurd
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'zero' and 'one' replace all current values, and 'add' merges keys
-- from both maps, adding in the case of double-occurrences.
instance (Backprop a, Ord k) => Backprop (M.Map k a) where
    zero = zeroFunctor
    add  = M.unionWith add
    one  = oneFunctor
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- | 'zero' and 'one' replace all current values, and 'add' merges keys
-- from both maps, adding in the case of double-occurrences.
instance (Backprop a) => Backprop (IM.IntMap a) where
    zero = zeroFunctor
    add  = IM.unionWith add
    one  = oneFunctor
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- Heterogeneous product: every method recurses along the spine, using the
-- head element's 'Backprop' instance at each step.
instance ListC (Backprop <$> (f <$> as)) => Backprop (Prod f as) where
    zero Ø         = Ø
    zero (h :< tl) = zero h :< zero tl
    add = \case
      Ø -> \case
        Ø -> Ø
      h1 :< tl1 -> \case
        h2 :< tl2 -> add h1 h2 :< add tl1 tl2
    one Ø         = Ø
    one (h :< tl) = one h :< one tl
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}

-- Type-level optional value: an absent value stays absent, a present one
-- is handled by its own 'Backprop' instance.
instance M.MaybeC (Backprop M.<$> (f M.<$> a)) => Backprop (Option f a) where
    zero Nothing_  = Nothing_
    zero (Just_ v) = Just_ (zero v)
    add = \case
      Nothing_ -> \case
        Nothing_ -> Nothing_
      Just_ l -> \case
        Just_ r -> Just_ (add l r)
    one Nothing_  = Nothing_
    one (Just_ v) = Just_ (one v)
    {-# INLINE zero #-}
    {-# INLINE add #-}
    {-# INLINE one #-}