{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Prelude.Backprop.Explicit (
sum
, product
, length
, minimum
, maximum
, traverse
, fmap
, pure
, liftA2
, liftA3
, coerce
) where
import Numeric.Backprop.Explicit
import Prelude (Num(..), Fractional(..), Eq(..), Ord(..), Functor, Foldable, Traversable, Applicative, (.), ($))
import qualified Control.Applicative as P
import qualified Data.Coerce as C
import qualified Data.Foldable as P
import qualified Prelude as P
-- | Lifted version of 'P.sum' for a container held inside a 'BVar'.
--
-- The forward result is @'P.sum' xs@.  The backward function fills every
-- position of @xs@ with the incoming value, so each element receives the
-- downstream gradient unchanged.
sum :: forall t a s. (Foldable t, Functor t, Num a, Reifies s W)
    => AddFunc (t a)    -- ^ 'AddFunc' for the input container, handed to 'liftOp1'
    -> ZeroFunc a       -- ^ 'ZeroFunc' for the summed result, handed to 'liftOp1'
    -> BVar s (t a)
    -> BVar s a
sum af zf = liftOp1 af zf (op1 sumOp)
  where
    -- Pair of (forward value, backward function) expected by 'op1'.
    sumOp xs = (P.sum xs, \d -> d P.<$ xs)
{-# INLINE sum #-}
-- | Lifted version of 'P.pure': wraps the value inside a 'BVar' into the
-- applicative container @t@.
--
-- The backward function folds the container of incoming values down to a
-- single @a@, combining them with @'runAF' af@ and starting from
-- @'runZF' zfa x@.
pure
    :: forall t a s. (Foldable t, Applicative t, Reifies s W)
    => AddFunc a        -- ^ combines element gradients; also handed to 'liftOp1'
    -> ZeroFunc a       -- ^ produces the starting value of the backward fold
    -> ZeroFunc (t a)   -- ^ 'ZeroFunc' for the resulting container, handed to 'liftOp1'
    -> BVar s a
    -> BVar s (t a)
pure af zfa zf = liftOp1 af zf (op1 pureOp)
  where
    pureOp x =
        let seed = runZF zfa x          -- initial accumulator for the fold
        in  (P.pure x, P.foldl' (runAF af) seed)
{-# INLINE pure #-}
-- | Lifted version of 'P.product' for a container held inside a 'BVar'.
--
-- The forward result is @'P.product' xs@.  For an incoming value @d@ the
-- backward function gives each element @x@ the value @p * d / x@, where
-- @p@ is the full product.  Because this divides by every element, an
-- element equal to zero leads to a division by zero in the backward pass.
product
    :: forall t a s. (Foldable t, Functor t, Fractional a, Reifies s W)
    => AddFunc (t a)    -- ^ 'AddFunc' for the input container, handed to 'liftOp1'
    -> ZeroFunc a       -- ^ 'ZeroFunc' for the product, handed to 'liftOp1'
    -> BVar s (t a)
    -> BVar s a
product af zf = liftOp1 af zf (op1 productOp)
  where
    productOp xs = (total, backward)
      where
        -- Computed once and shared by the forward value and every
        -- backward evaluation.
        total      = P.product xs
        backward d = P.fmap (\x -> total * d / x) xs
{-# INLINE product #-}
-- | Lifted version of 'P.length', returning the element count converted to
-- any 'Num' type via 'P.fromIntegral'.
--
-- The backward function ignores the incoming value and always returns
-- @'runZF' zfa xs@: the length does not depend on the element values.
length
    :: forall t a b s. (Foldable t, Num b, Reifies s W)
    => AddFunc (t a)    -- ^ 'AddFunc' for the input container, handed to 'liftOp1'
    -> ZeroFunc (t a)   -- ^ builds the value returned by the backward function
    -> ZeroFunc b       -- ^ 'ZeroFunc' for the numeric result, handed to 'liftOp1'
    -> BVar s (t a)
    -> BVar s b
length af zfa zf = liftOp1 af zf (op1 lengthOp)
  where
    lengthOp xs = (P.fromIntegral (P.length xs), \_ -> zeroed)
      where
        -- Shared across all backward evaluations.
        zeroed = runZF zfa xs
{-# INLINE length #-}
-- | Lifted version of 'P.minimum' for a container held inside a 'BVar'.
--
-- The forward result is @'P.minimum' xs@.  In the backward function every
-- element equal to that minimum receives the incoming value @d@ (so ties
-- all receive it), and every other element @x@ receives @'runZF' zf x@.
minimum
    :: forall t a s. (Foldable t, Functor t, Ord a, Reifies s W)
    => AddFunc (t a)    -- ^ 'AddFunc' for the input container, handed to 'liftOp1'
    -> ZeroFunc a       -- ^ zeroes non-minimal elements; also handed to 'liftOp1'
    -> BVar s (t a)
    -> BVar s a
minimum af zf = liftOp1 af zf (op1 minimumOp)
  where
    minimumOp xs = (lowest, \d -> P.fmap (route d) xs)
      where
        lowest = P.minimum xs
        -- Decide what a single element receives in the backward pass.
        route d x
          | x == lowest = d
          | P.otherwise = runZF zf x
{-# INLINE minimum #-}
-- | Lifted version of 'P.maximum' for a container held inside a 'BVar'.
--
-- The forward result is @'P.maximum' xs@.  In the backward function every
-- element equal to that maximum receives the incoming value @d@ (so ties
-- all receive it), and every other element @x@ receives @'runZF' zf x@.
maximum
    :: forall t a s. (Foldable t, Functor t, Ord a, Reifies s W)
    => AddFunc (t a)    -- ^ 'AddFunc' for the input container, handed to 'liftOp1'
    -> ZeroFunc a       -- ^ zeroes non-maximal elements; also handed to 'liftOp1'
    -> BVar s (t a)
    -> BVar s a
maximum af zf = liftOp1 af zf (op1 maximumOp)
  where
    maximumOp xs = (highest, \d -> P.fmap (route d) xs)
      where
        highest = P.maximum xs
        -- Decide what a single element receives in the backward pass.
        route d x
          | x == highest = d
          | P.otherwise  = runZF zf x
{-# INLINE maximum #-}
-- | Lifted version of 'P.fmap': applies a 'BVar'-level function to every
-- element of a container held inside a 'BVar'.
--
-- The container is split into a container of 'BVar's with 'sequenceVar',
-- the function is mapped over it, and the results are gathered back into
-- a single 'BVar' with 'collectVar'.
fmap
    :: forall f a b s. (Traversable f, Reifies s W)
    => AddFunc a        -- ^ passed to 'sequenceVar'
    -> AddFunc b        -- ^ passed to 'collectVar'
    -> ZeroFunc a       -- ^ passed to 'sequenceVar'
    -> ZeroFunc b       -- ^ passed to 'collectVar'
    -> ZeroFunc (f b)   -- ^ passed to 'collectVar'
    -> (BVar s a -> BVar s b)
    -> BVar s (f a)
    -> BVar s (f b)
fmap afa afb zfa zfb zfbs f = gather . P.fmap f . scatter
  where
    scatter = sequenceVar afa zfa
    gather  = collectVar afb zfb zfbs
{-# INLINE fmap #-}
-- | Lifted version of 'P.traverse' for a container held inside a 'BVar'.
--
-- The input is split into a container of 'BVar's with 'sequenceVar' and
-- traversed with the supplied function.  The result is then gathered in
-- two steps with 'collectVar': first each inner @t ('BVar' s b)@, then the
-- outer @f ('BVar' s (t b))@.
traverse
    :: forall t f a b s. (Traversable t, Applicative f, Foldable f, Reifies s W)
    => AddFunc a            -- ^ passed to 'sequenceVar'
    -> AddFunc b            -- ^ passed to the inner 'collectVar'
    -> AddFunc (t b)        -- ^ passed to the outer 'collectVar'
    -> ZeroFunc a           -- ^ passed to 'sequenceVar'
    -> ZeroFunc b           -- ^ passed to the inner 'collectVar'
    -> ZeroFunc (t b)       -- ^ passed to both the inner and outer 'collectVar'
    -> ZeroFunc (f (t b))   -- ^ passed to the outer 'collectVar'
    -> (BVar s a -> f (BVar s b))
    -> BVar s (t a)
    -> BVar s (f (t b))
traverse afa afb aftb zfa zfb zftb zfftb f =
    gatherOuter . P.fmap gatherInner . P.traverse f . scatter
  where
    scatter     = sequenceVar afa zfa
    gatherInner = collectVar afb zfb zftb
    gatherOuter = collectVar aftb zftb zfftb
{-# INLINE traverse #-}
-- | Lifted version of 'P.liftA2': combines two containers held inside
-- 'BVar's with a binary 'BVar'-level function.
--
-- Both inputs are split with 'sequenceVar', combined through the
-- 'Applicative' instance of @f@, and gathered back with 'collectVar'.
liftA2
    :: forall f a b c s.
       ( Traversable f
       , Applicative f
       , Reifies s W
       )
    => AddFunc a        -- ^ passed to 'sequenceVar' for the first input
    -> AddFunc b        -- ^ passed to 'sequenceVar' for the second input
    -> AddFunc c        -- ^ passed to 'collectVar'
    -> ZeroFunc a       -- ^ passed to 'sequenceVar' for the first input
    -> ZeroFunc b       -- ^ passed to 'sequenceVar' for the second input
    -> ZeroFunc c       -- ^ passed to 'collectVar'
    -> ZeroFunc (f c)   -- ^ passed to 'collectVar'
    -> (BVar s a -> BVar s b -> BVar s c)
    -> BVar s (f a)
    -> BVar s (f b)
    -> BVar s (f c)
liftA2 afa afb afc zfa zfb zfc zffc f x y =
    collectVar afc zfc zffc (f P.<$> xs P.<*> ys)
  where
    xs = sequenceVar afa zfa x
    ys = sequenceVar afb zfb y
{-# INLINE liftA2 #-}
-- | Lifted version of 'P.liftA3': combines three containers held inside
-- 'BVar's with a ternary 'BVar'-level function.
--
-- All three inputs are split with 'sequenceVar', combined through the
-- 'Applicative' instance of @f@, and gathered back with 'collectVar'.
liftA3
    :: forall f a b c d s.
       ( Traversable f
       , Applicative f
       , Reifies s W
       )
    => AddFunc a        -- ^ passed to 'sequenceVar' for the first input
    -> AddFunc b        -- ^ passed to 'sequenceVar' for the second input
    -> AddFunc c        -- ^ passed to 'sequenceVar' for the third input
    -> AddFunc d        -- ^ passed to 'collectVar'
    -> ZeroFunc a       -- ^ passed to 'sequenceVar' for the first input
    -> ZeroFunc b       -- ^ passed to 'sequenceVar' for the second input
    -> ZeroFunc c       -- ^ passed to 'sequenceVar' for the third input
    -> ZeroFunc d       -- ^ passed to 'collectVar'
    -> ZeroFunc (f d)   -- ^ passed to 'collectVar'
    -> (BVar s a -> BVar s b -> BVar s c -> BVar s d)
    -> BVar s (f a)
    -> BVar s (f b)
    -> BVar s (f c)
    -> BVar s (f d)
liftA3 afa afb afc afd zfa zfb zfc zfd zffd f x y z =
    collectVar afd zfd zffd (f P.<$> xs P.<*> ys P.<*> zs)
  where
    xs = sequenceVar afa zfa x
    ys = sequenceVar afb zfb y
    zs = sequenceVar afc zfc z
{-# INLINE liftA3 #-}
-- | Converts a 'BVar' of one type into a 'BVar' of any type it is
-- 'C.Coercible' to.  This is 'coerceVar' re-exported under the name
-- @coerce@.
coerce :: forall a b s. C.Coercible a b => BVar s a -> BVar s b
coerce = coerceVar
{-# INLINE coerce #-}