{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.Backprop.Class (
Backprop(..)
, zeroNum, addNum, oneNum
, zeroVec, addVec, oneVec
, zeroFunctor, addIsList, addAsList, oneFunctor
, genericZero, genericAdd, genericOne
, GZero(..), GAdd(..), GOne(..)
) where
import Data.Complex
import Data.Foldable hiding (toList)
import Data.Functor.Identity
import Data.List.NonEmpty (NonEmpty(..))
import Data.Proxy
import Data.Ratio
import Data.Type.Combinator hiding ((:.:), Comp1)
import Data.Type.Option
import Data.Type.Product hiding (toList)
import Data.Void
import GHC.Exts
import GHC.Generics
import Type.Family.List
import qualified Data.IntMap as IM
import qualified Data.Map as M
import qualified Data.Sequence as Seq
import qualified Data.Vector as V
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Primitive as VP
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import qualified Type.Family.Maybe as M
-- | Types supporting the three operations this module is built around:
-- producing a \"zero\" ('zero'), combining two values ('add'), and producing
-- a \"one\" ('one').
--
-- Each method has a default implementation that goes through "GHC.Generics"
-- ('genericZero', 'genericAdd', 'genericOne').  A default is available
-- whenever the type is 'Generic' and its representation has the matching
-- helper-class instance ('GZero', 'GAdd' or 'GOne').
class Backprop a where
    -- | Produce a \"zero\" from an existing value.
    zero :: a -> a
    -- Generic fallback: delegate to 'genericZero'.
    default zero :: (Generic a, GZero (Rep a)) => a -> a
    zero x = genericZero x
    {-# INLINE zero #-}

    -- | Combine two values into one.
    add :: a -> a -> a
    -- Generic fallback: delegate to 'genericAdd'.
    default add :: (Generic a, GAdd (Rep a)) => a -> a -> a
    add x y = genericAdd x y
    {-# INLINE add #-}

    -- | Produce a \"one\" from an existing value.
    one :: a -> a
    -- Generic fallback: delegate to 'genericOne'.
    default one :: (Generic a, GOne (Rep a)) => a -> a
    one x = genericOne x
    {-# INLINE one #-}
-- | 'zero' implemented through "GHC.Generics": convert the value to its
-- generic representation, apply 'gzero' to it, and convert back.
genericZero :: (Generic a, GZero (Rep a)) => a -> a
genericZero x = to (gzero (from x))
{-# INLINE genericZero #-}
-- | 'add' implemented through "GHC.Generics": convert both values to their
-- generic representations, combine those with 'gadd', and convert back.
genericAdd :: (Generic a, GAdd (Rep a)) => a -> a -> a
genericAdd x y = to (from x `gadd` from y)
{-# INLINE genericAdd #-}
-- | 'one' implemented through "GHC.Generics": convert the value to its
-- generic representation, apply 'gone' to it, and convert back.
genericOne :: (Generic a, GOne (Rep a)) => a -> a
genericOne x = to (gone (from x))
{-# INLINE genericOne #-}
-- | 'zero' for any 'Num' type: ignores its argument (which is never
-- evaluated) and returns the literal @0@.
zeroNum :: Num a => a -> a
zeroNum = const 0
{-# INLINE zeroNum #-}
-- | 'add' for any 'Num' type: plain '+'.
addNum :: Num a => a -> a -> a
addNum x y = x + y
{-# INLINE addNum #-}
-- | 'one' for any 'Num' type: ignores its argument (which is never
-- evaluated) and returns the literal @1@.
oneNum :: Num a => a -> a
oneNum = const 1
{-# INLINE oneNum #-}
-- | 'zero' for any "Data.Vector.Generic" vector: applies 'zero' to every
-- element, so the result has the same length as the input.
zeroVec :: (VG.Vector v a, Backprop a) => v a -> v a
zeroVec v = VG.map zero v
{-# INLINE zeroVec #-}
-- | 'add' for any "Data.Vector.Generic" vector.
--
-- Elements at the same index are combined with 'add'.  When the two vectors
-- have different lengths, the surplus tail of the longer one is appended
-- unchanged, so the result always has length @max (length x) (length y)@.
-- This matches what 'addAsList' does for list-like containers.
addVec :: (VG.Vector v a, Backprop a) => v a -> v a -> v a
addVec x y = case compare lX lY of
    -- y is longer: split it at x's length, so that the prefix lines up with
    -- x element-for-element and the suffix is exactly y's surplus.
    LT -> let (yHead, yTail) = VG.splitAt lX y
          in  VG.zipWith add x yHead VG.++ yTail
    EQ -> VG.zipWith add x y
    -- x is longer: the symmetric case, splitting x at y's length.
    GT -> let (xHead, xTail) = VG.splitAt lY x
          in  VG.zipWith add xHead y VG.++ xTail
  where
    lX = VG.length x
    lY = VG.length y
-- | 'one' for any "Data.Vector.Generic" vector: applies 'one' to every
-- element, so the result has the same length as the input.
oneVec :: (VG.Vector v a, Backprop a) => v a -> v a
oneVec v = VG.map one v
{-# INLINE oneVec #-}
-- | 'zero' for any 'Functor': applies 'zero' to every contained value via
-- 'fmap'.
zeroFunctor :: (Functor f, Backprop a) => f a -> f a
zeroFunctor fa = zero <$> fa
{-# INLINE zeroFunctor #-}
-- | 'add' for any 'IsList' container: 'addAsList' specialised to the
-- container's own 'toList' / 'fromList' conversions.
addIsList :: (IsList a, Backprop (Item a)) => a -> a -> a
addIsList xs ys = addAsList toList fromList xs ys
{-# INLINE addIsList #-}
-- | 'add' for anything that can be converted to and from a list.
--
-- Both arguments are turned into lists, merged pairwise with 'add', and the
-- merged list is converted back.  If one list runs out first, the rest of
-- the other list is kept unchanged.
addAsList
    :: Backprop b
    => (a -> [b])       -- ^ convert the container to a list
    -> ([b] -> a)       -- ^ rebuild the container from a list
    -> a
    -> a
    -> a
addAsList toL fromL x y = fromL (merge (toL x) (toL y))
  where
    -- The first list is inspected before the second; an exhausted side
    -- hands the other side back as-is.
    merge []       qs       = qs
    merge ps       []       = ps
    merge (p : ps) (q : qs) = add p q : merge ps qs
-- | 'one' for any 'Functor': applies 'one' to every contained value via
-- 'fmap'.
oneFunctor :: (Functor f, Backprop a) => f a -> f a
oneFunctor fa = one <$> fa
{-# INLINE oneFunctor #-}
-- | Helper class behind 'genericZero': 'zero' lifted to the representation
-- functors of "GHC.Generics".
class GZero f where
    gzero :: f t -> f t

-- | A field: defer to the field type's own 'zero'.
instance Backprop a => GZero (K1 i a) where
    gzero (K1 field) = K1 (zero field)
    {-# INLINE gzero #-}

-- | A product: zero both halves.
instance (GZero f, GZero g) => GZero (f :*: g) where
    gzero (l :*: r) = gzero l :*: gzero r
    {-# INLINE gzero #-}

-- | A sum: zero whichever alternative is present, keeping the same side.
instance (GZero f, GZero g) => GZero (f :+: g) where
    gzero = \case
        L1 l -> L1 (gzero l)
        R1 r -> R1 (gzero r)
    {-# INLINE gzero #-}

-- | No constructors: nothing to match on.
instance GZero V1 where
    gzero v = case v of {}
    {-# INLINE gzero #-}

-- | A constructor without fields: the argument is ignored.
instance GZero U1 where
    gzero = const U1
    {-# INLINE gzero #-}

-- | Metadata wrapper: look through it.
instance GZero f => GZero (M1 i c f) where
    gzero (M1 inner) = M1 (gzero inner)
    {-# INLINE gzero #-}

-- | Composition: only the outer functor needs a 'GZero' instance.
instance GZero f => GZero (f :.: g) where
    gzero (Comp1 inner) = Comp1 (gzero inner)
    {-# INLINE gzero #-}
-- | Helper class behind 'genericAdd': 'add' lifted to the representation
-- functors of "GHC.Generics".
--
-- Note that the instances below do not include one for sums (':+:').
class GAdd f where
    gadd :: f t -> f t -> f t

-- | A field: defer to the field type's own 'add'.
instance Backprop a => GAdd (K1 i a) where
    gadd (K1 p) (K1 q) = K1 (add p q)
    {-# INLINE gadd #-}

-- | A product: add the halves pairwise.  Both results are forced (left
-- first) before the pair is returned.
instance (GAdd f, GAdd g) => GAdd (f :*: g) where
    gadd (l1 :*: r1) (l2 :*: r2) =
        let !l = gadd l1 l2
            !r = gadd r1 r2
        in  l :*: r
    {-# INLINE gadd #-}

-- | No constructors: nothing to match on.
instance GAdd V1 where
    gadd v = case v of {}
    {-# INLINE gadd #-}

-- | A constructor without fields: both arguments are ignored.
instance GAdd U1 where
    gadd = \_ _ -> U1
    {-# INLINE gadd #-}

-- | Metadata wrapper: look through it.
instance GAdd f => GAdd (M1 i c f) where
    gadd (M1 p) (M1 q) = M1 (gadd p q)
    {-# INLINE gadd #-}

-- | Composition: only the outer functor needs a 'GAdd' instance.
instance GAdd f => GAdd (f :.: g) where
    gadd (Comp1 p) (Comp1 q) = Comp1 (gadd p q)
    {-# INLINE gadd #-}
-- | Helper class behind 'genericOne': 'one' lifted to the representation
-- functors of "GHC.Generics".
class GOne f where
    gone :: f t -> f t

-- | A field: defer to the field type's own 'one'.
instance Backprop a => GOne (K1 i a) where
    gone (K1 field) = K1 (one field)
    {-# INLINE gone #-}

-- | A product: apply 'gone' to both halves.
instance (GOne f, GOne g) => GOne (f :*: g) where
    gone (l :*: r) = gone l :*: gone r
    {-# INLINE gone #-}

-- | A sum: apply 'gone' to whichever alternative is present, keeping the
-- same side.
instance (GOne f, GOne g) => GOne (f :+: g) where
    gone = \case
        L1 l -> L1 (gone l)
        R1 r -> R1 (gone r)
    {-# INLINE gone #-}

-- | No constructors: nothing to match on.
instance GOne V1 where
    gone v = case v of {}
    {-# INLINE gone #-}

-- | A constructor without fields: the argument is ignored.
instance GOne U1 where
    gone = const U1
    {-# INLINE gone #-}

-- | Metadata wrapper: look through it.
instance GOne f => GOne (M1 i c f) where
    gone (M1 inner) = M1 (gone inner)
    {-# INLINE gone #-}

-- | Composition: only the outer functor needs a 'GOne' instance.
instance GOne f => GOne (f :.: g) where
    gone (Comp1 inner) = Comp1 (gone inner)
    {-# INLINE gone #-}
-- Numeric instances.  Each one spells out the 'Num'-based definitions
-- directly: 'zero' is the literal 0, 'add' is '+', 'one' is the literal 1.
-- 'zero' and 'one' never look at their argument.

-- | Via 'Num'.
instance Backprop Int where
    zero _ = 0
    {-# INLINE zero #-}
    add x y = x + y
    {-# INLINE add #-}
    one _ = 1
    {-# INLINE one #-}

-- | Via 'Num'.
instance Backprop Integer where
    zero _ = 0
    {-# INLINE zero #-}
    add x y = x + y
    {-# INLINE add #-}
    one _ = 1
    {-# INLINE one #-}

-- | Via 'Num'.
instance Integral a => Backprop (Ratio a) where
    zero _ = 0
    {-# INLINE zero #-}
    add x y = x + y
    {-# INLINE add #-}
    one _ = 1
    {-# INLINE one #-}

-- | Via 'Num'.
instance RealFloat a => Backprop (Complex a) where
    zero _ = 0
    {-# INLINE zero #-}
    add x y = x + y
    {-# INLINE add #-}
    one _ = 1
    {-# INLINE one #-}

-- | Via 'Num'.
instance Backprop Float where
    zero _ = 0
    {-# INLINE zero #-}
    add x y = x + y
    {-# INLINE add #-}
    one _ = 1
    {-# INLINE one #-}

-- | Via 'Num'.
instance Backprop Double where
    zero _ = 0
    {-# INLINE zero #-}
    add x y = x + y
    {-# INLINE add #-}
    one _ = 1
    {-# INLINE one #-}
-- Vector instances.  'zero' and 'one' map the element-level operation over
-- the vector; 'add' is 'addVec'.

-- | Boxed vectors.
instance Backprop a => Backprop (V.Vector a) where
    zero v = V.map zero v
    {-# INLINE zero #-}
    add v w = addVec v w
    {-# INLINE add #-}
    one v = V.map one v
    {-# INLINE one #-}

-- | Unboxed vectors.
instance (VU.Unbox a, Backprop a) => Backprop (VU.Vector a) where
    zero v = VU.map zero v
    {-# INLINE zero #-}
    add v w = addVec v w
    {-# INLINE add #-}
    one v = VU.map one v
    {-# INLINE one #-}

-- | Storable vectors.
instance (VS.Storable a, Backprop a) => Backprop (VS.Vector a) where
    zero v = VS.map zero v
    {-# INLINE zero #-}
    add v w = addVec v w
    {-# INLINE add #-}
    one v = VS.map one v
    {-# INLINE one #-}

-- | Primitive vectors.
instance (VP.Prim a, Backprop a) => Backprop (VP.Vector a) where
    zero v = VP.map zero v
    {-# INLINE zero #-}
    add v w = addVec v w
    {-# INLINE add #-}
    one v = VP.map one v
    {-# INLINE one #-}
-- List-like instances.  'zero' and 'one' are mapped over the elements with
-- 'fmap'; 'add' goes through the container's 'IsList' conversions using
-- 'addAsList'.

-- | Lists.
instance Backprop a => Backprop [a] where
    zero xs = fmap zero xs
    {-# INLINE zero #-}
    add xs ys = addAsList toList fromList xs ys
    {-# INLINE add #-}
    one xs = fmap one xs
    {-# INLINE one #-}

-- | Non-empty lists.
instance Backprop a => Backprop (NonEmpty a) where
    zero xs = fmap zero xs
    {-# INLINE zero #-}
    add xs ys = addAsList toList fromList xs ys
    {-# INLINE add #-}
    one xs = fmap one xs
    {-# INLINE one #-}

-- | Sequences.
instance Backprop a => Backprop (Seq.Seq a) where
    zero xs = fmap zero xs
    {-# INLINE zero #-}
    add xs ys = addAsList toList fromList xs ys
    {-# INLINE add #-}
    one xs = fmap one xs
    {-# INLINE one #-}
-- | 'zero' and 'one' are mapped over the contents, so 'Nothing' stays
-- 'Nothing'.  'add' combines two 'Just's with the element-level 'add'; if
-- only one side is a 'Just' that side is returned, and two 'Nothing's give
-- 'Nothing'.
instance Backprop a => Backprop (Maybe a) where
    zero mx = fmap zero mx
    {-# INLINE zero #-}
    -- The first argument is inspected before the second.
    add Nothing  my       = my
    add mx       Nothing  = mx
    add (Just x) (Just y) = Just (add x y)
    {-# INLINE add #-}
    one mx = fmap one mx
    {-# INLINE one #-}
-- | The unit type: every operation yields @()@.  'zero' and 'one' ignore
-- their argument; 'add' pattern-matches on both of its arguments.
instance Backprop () where
    zero = const ()
    add u v = case u of
        () -> case v of
            () -> ()
    one = const ()
-- Tuple instances (sizes 2 to 5).  Every operation works component-wise.
-- 'add' forces each combined component, left to right, before the result
-- tuple is returned.

-- | Pairs, component-wise.
instance (Backprop a, Backprop b) => Backprop (a, b) where
    zero (p, q) = (zero p, zero q)
    {-# INLINE zero #-}
    add (p1, q1) (p2, q2) =
        let !p = add p1 p2
            !q = add q1 q2
        in  (p, q)
    {-# INLINE add #-}
    one (p, q) = (one p, one q)
    {-# INLINE one #-}

-- | Triples, component-wise.
instance (Backprop a, Backprop b, Backprop c) => Backprop (a, b, c) where
    zero (p, q, r) = (zero p, zero q, zero r)
    {-# INLINE zero #-}
    add (p1, q1, r1) (p2, q2, r2) =
        let !p = add p1 p2
            !q = add q1 q2
            !r = add r1 r2
        in  (p, q, r)
    {-# INLINE add #-}
    one (p, q, r) = (one p, one q, one r)
    {-# INLINE one #-}

-- | 4-tuples, component-wise.
instance (Backprop a, Backprop b, Backprop c, Backprop d) => Backprop (a, b, c, d) where
    zero (p, q, r, s) = (zero p, zero q, zero r, zero s)
    {-# INLINE zero #-}
    add (p1, q1, r1, s1) (p2, q2, r2, s2) =
        let !p = add p1 p2
            !q = add q1 q2
            !r = add r1 r2
            !s = add s1 s2
        in  (p, q, r, s)
    {-# INLINE add #-}
    one (p, q, r, s) = (one p, one q, one r, one s)
    {-# INLINE one #-}

-- | 5-tuples, component-wise.
instance (Backprop a, Backprop b, Backprop c, Backprop d, Backprop e) => Backprop (a, b, c, d, e) where
    zero (p, q, r, s, t) = (zero p, zero q, zero r, zero s, zero t)
    {-# INLINE zero #-}
    add (p1, q1, r1, s1, t1) (p2, q2, r2, s2, t2) =
        let !p = add p1 p2
            !q = add q1 q2
            !r = add r1 r2
            !s = add s1 s2
            !t = add t1 t2
        in  (p, q, r, s, t)
    {-# INLINE add #-}
    one (p, q, r, s, t) = (one p, one q, one r, one s, one t)
    {-# INLINE one #-}
-- | Every operation is applied to the wrapped value and the result is
-- re-wrapped in 'Identity'.
instance Backprop a => Backprop (Identity a) where
    zero wrapped = fmap zero wrapped
    {-# INLINE zero #-}
    add l r = Identity (add (runIdentity l) (runIdentity r))
    {-# INLINE add #-}
    one wrapped = fmap one wrapped
    {-# INLINE one #-}
-- | Every operation is applied to the value inside the 'I' constructor and
-- the result is wrapped in 'I' again.
instance Backprop a => Backprop (I a) where
    zero (I inner) = I (zero inner)
    {-# INLINE zero #-}
    add (I l) (I r) = I (add l r)
    {-# INLINE add #-}
    one (I inner) = I (one inner)
    {-# INLINE one #-}
-- | 'Proxy' carries no data, so every operation yields 'Proxy'.  'zero' and
-- 'one' ignore their argument; 'add' pattern-matches on both arguments.
instance Backprop (Proxy a) where
    zero = const Proxy
    {-# INLINE zero #-}
    add p q = case p of
        Proxy -> case q of
            Proxy -> Proxy
    {-# INLINE add #-}
    one = const Proxy
    {-# INLINE one #-}
-- | 'Void' has no constructors, so each method is an empty case analysis on
-- its (first) argument.
instance Backprop Void where
    zero v = case v of {}
    {-# INLINE zero #-}
    add v = case v of {}
    {-# INLINE add #-}
    one v = case v of {}
    {-# INLINE one #-}
-- | 'zero' and 'one' are mapped over the stored values, leaving the key set
-- unchanged.  'add' is a union of the two maps: a key present in both has
-- its values combined with the element-level 'add', and a key present in
-- only one map keeps its value.
instance (Backprop a, Ord k) => Backprop (M.Map k a) where
    zero m = fmap zero m
    {-# INLINE zero #-}
    add m1 m2 = M.unionWith add m1 m2
    {-# INLINE add #-}
    one m = fmap one m
    {-# INLINE one #-}
-- | 'zero' and 'one' are mapped over the stored values, leaving the key set
-- unchanged.  'add' is a union of the two maps: a key present in both has
-- its values combined with the element-level 'add', and a key present in
-- only one map keeps its value.
instance (Backprop a) => Backprop (IM.IntMap a) where
    zero m = fmap zero m
    {-# INLINE zero #-}
    add m1 m2 = IM.unionWith add m1 m2
    {-# INLINE add #-}
    one m = fmap one m
    {-# INLINE one #-}
-- | Heterogeneous products, handled one element at a time: the empty
-- product maps to itself, and a cons cell applies the operation to its head
-- and recurses on its tail.  For 'add', matching on the first argument
-- already determines the constructor of the second, so only the two
-- same-shape cases are written out.
instance ListC (Backprop <$> (f <$> as)) => Backprop (Prod f as) where
    zero Ø         = Ø
    zero (x :< xs) = zero x :< zero xs
    {-# INLINE zero #-}
    add Ø         Ø         = Ø
    add (x :< xs) (y :< ys) = add x y :< add xs ys
    {-# INLINE add #-}
    one Ø         = Ø
    one (x :< xs) = one x :< one xs
    {-# INLINE one #-}
-- | Type-indexed optional values: 'Nothing_' maps to itself and 'Just_' has
-- the operation applied to its contents.  For 'add', matching on the first
-- argument already determines the constructor of the second, so only the
-- two same-shape cases are written out.
instance M.MaybeC (Backprop M.<$> (f M.<$> a)) => Backprop (Option f a) where
    zero Nothing_  = Nothing_
    zero (Just_ x) = Just_ (zero x)
    {-# INLINE zero #-}
    add Nothing_  Nothing_  = Nothing_
    add (Just_ x) (Just_ y) = Just_ (add x y)
    {-# INLINE add #-}
    one Nothing_  = Nothing_
    one (Just_ x) = Just_ (one x)
    {-# INLINE one #-}