{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Lifted counterparts of familiar "Prelude", "Control.Applicative" and
-- "Data.Coerce" functions that work on 'BVar's instead of plain values.
--
-- Every export deliberately reuses the name of the function it lifts, so
-- this module is normally imported qualified (or with the clashing
-- "Prelude" names hidden).
module Prelude.Backprop (
sum
, product
, length
, minimum
, maximum
, traverse
, fmap
, (<$>)
, pure
, liftA2
, liftA3
, coerce
) where
import Numeric.Backprop
import Prelude (Num(..), Fractional(..), Eq(..), Ord(..), Functor, Foldable, Traversable, Applicative, (.), ($))
import qualified Control.Applicative as P
import qualified Data.Coerce as C
import qualified Data.Foldable as P
import qualified Prelude as P
-- | Lifted 'P.sum': adds up every element held by the container variable.
--
-- On the backward pass the gradient of the total is written into every
-- slot of the original container (via 'P.<$'), i.e. each element receives
-- the result's gradient unchanged.
sum :: forall t a s. (Foldable t, Functor t, Backprop (t a), Backprop a, Num a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
sum = liftOp1 (op1 sumOp)
  where
    -- Forward result paired with the gradient function.
    sumOp xs = (P.sum xs, \dTotal -> dTotal P.<$ xs)
{-# INLINE sum #-}
-- | Lifted 'P.pure': wraps a single variable in the applicative @t@.
--
-- On the backward pass the gradients of all elements of the resulting
-- container are accumulated with 'add', starting from @'zero' x@ where
-- @x@ is the original input value.
pure
    :: forall t a s. (Foldable t, Applicative t, Backprop (t a), Backprop a, Reifies s W)
    => BVar s a
    -> BVar s (t a)
pure = liftOp1 $ op1 pureOp
  where
    pureOp x = (P.pure x, collapse)
      where
        -- Strict left fold that sums the per-element gradients.
        collapse = P.foldl' add (zero x)
{-# INLINE pure #-}
-- | Lifted 'P.product': multiplies together every element held by the
-- container variable.
--
-- The gradient of each element @x@ is computed as @p * d / x@, where @p@
-- is the full product and @d@ is the gradient of the result; this is why a
-- 'Fractional' instance is required.
--
-- NOTE(review): because of the division by @x@, an element equal to zero
-- presumably yields a non-finite gradient for floating-point types.
product
    :: forall t a s. (Foldable t, Functor t, Backprop (t a), Backprop a, Fractional a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
product = liftOp1 (op1 productOp)
  where
    productOp xs = (total, backward)
      where
        total           = P.product xs
        backward dTotal = P.fmap (\x -> total * dTotal / x) xs
{-# INLINE product #-}
-- | Lifted 'P.length': the number of elements in the container variable,
-- converted with 'P.fromIntegral' to any 'Num' result type @b@.
--
-- The incoming gradient is ignored: the backward pass always returns
-- @'zero' xs@ for the input container @xs@.
length
    :: forall t a b s. (Foldable t, Backprop (t a), Backprop b, Num b, Reifies s W)
    => BVar s (t a)
    -> BVar s b
length = liftOp1 (op1 lengthOp)
  where
    lengthOp xs = (count, \_ -> noGradient)
      where
        count      = P.fromIntegral (P.length xs)
        -- Built once and handed back for any incoming gradient.
        noGradient = zero xs
{-# INLINE length #-}
-- | Lifted 'P.minimum': the smallest element of the container variable.
--
-- On the backward pass every element that compares equal to the minimum
-- receives the full gradient of the result; all other elements receive
-- 'zero'.  Behaviour on an empty container is inherited from 'P.minimum'.
minimum
    :: forall t a s. (Foldable t, Functor t, Backprop a, Ord a, Backprop (t a), Reifies s W)
    => BVar s (t a)
    -> BVar s a
minimum = liftOp1 (op1 minimumOp)
  where
    minimumOp xs = (least, \dLeast -> P.fmap (route dLeast) xs)
      where
        least = P.minimum xs
        -- Pass the gradient through only at positions holding the minimum.
        route dLeast x
          | x == least  = dLeast
          | P.otherwise = zero x
{-# INLINE minimum #-}
-- | Lifted 'P.maximum': the largest element of the container variable.
--
-- On the backward pass every element that compares equal to the maximum
-- receives the full gradient of the result; all other elements receive
-- 'zero'.  Behaviour on an empty container is inherited from 'P.maximum'.
maximum
    :: forall t a s. (Foldable t, Functor t, Backprop a, Ord a, Backprop (t a), Reifies s W)
    => BVar s (t a)
    -> BVar s a
maximum = liftOp1 (op1 maximumOp)
  where
    maximumOp xs = (greatest, \dGreatest -> P.fmap (route dGreatest) xs)
      where
        greatest = P.maximum xs
        -- Pass the gradient through only at positions holding the maximum.
        route dGreatest x
          | x == greatest = dGreatest
          | P.otherwise   = zero x
{-# INLINE maximum #-}
-- | Lifted 'P.fmap': applies a function on variables to every element of
-- a container variable.
--
-- The container is first taken apart with 'sequenceVar', the function is
-- mapped over the resulting per-element variables, and the mapped elements
-- are put back together with 'collectVar'.
fmap
    :: forall f a b s. (Traversable f, Backprop a, Backprop b, Backprop (f b), Reifies s W)
    => (BVar s a -> BVar s b)
    -> BVar s (f a)
    -> BVar s (f b)
fmap f = \xs -> collectVar (f P.<$> sequenceVar xs)
{-# INLINE fmap #-}
-- | Infix alias for 'fmap'.
--
-- Declared @infixl 4@ so that it has the same precedence and associativity
-- as the "Prelude" operator it stands in for.  Without a fixity declaration
-- Haskell defaults an operator to @infixl 9@, which would make expressions
-- such as @f . g \<$\> xs@ fail to parse and would group
-- @f \<$\> x + y@ differently from the ordinary '<$>'.
(<$>)
    :: forall f a b s. (Traversable f, Backprop a, Backprop b, Backprop (f b), Reifies s W)
    => (BVar s a -> BVar s b)
    -> BVar s (f a)
    -> BVar s (f b)
(<$>) = fmap
infixl 4 <$>
{-# INLINE (<$>) #-}
-- | Lifted 'P.traverse': runs an effectful function on variables over
-- every element of a container variable.
--
-- The input is split into per-element variables with 'sequenceVar' and
-- traversed with the ordinary 'P.traverse'.  Each inner @t@ of variables
-- is then gathered with 'collectVar', and finally the outer @f@ is
-- gathered with 'collectVar' as well, giving a single variable holding
-- @f (t b)@.
traverse
    :: forall t f a b s. (Traversable t, Applicative f, Foldable f, Backprop a, Backprop b, Backprop (f (t b)), Backprop (t b), Reifies s W)
    => (BVar s a -> f (BVar s b))
    -> BVar s (t a)
    -> BVar s (f (t b))
traverse f = \xs ->
    let traversed = P.traverse f (sequenceVar xs)
    in  collectVar (collectVar P.<$> traversed)
{-# INLINE traverse #-}
-- | Lifted 'P.liftA2': combines two container variables element-wise
-- through the 'Applicative' instance of @f@.
--
-- Both arguments are split into per-element variables with 'sequenceVar',
-- combined with the binary function using 'P.fmap' and 'P.<*>', and the
-- result is gathered back into one variable with 'collectVar'.
liftA2
    :: forall f a b c s.
       ( Traversable f
       , Applicative f
       , Backprop a, Backprop b, Backprop c, Backprop (f c)
       , Reifies s W
       )
    => (BVar s a -> BVar s b -> BVar s c)
    -> BVar s (f a)
    -> BVar s (f b)
    -> BVar s (f c)
liftA2 f x y = collectVar (P.fmap f xs P.<*> ys)
  where
    xs = sequenceVar x
    ys = sequenceVar y
{-# INLINE liftA2 #-}
-- | Lifted 'P.liftA3': combines three container variables element-wise
-- through the 'Applicative' instance of @f@.
--
-- Each argument is split into per-element variables with 'sequenceVar',
-- the ternary function is applied using 'P.fmap' and 'P.<*>', and the
-- result is gathered back into one variable with 'collectVar'.
liftA3
    :: forall f a b c d s.
       ( Traversable f
       , Applicative f
       , Backprop a, Backprop b, Backprop c, Backprop d, Backprop (f d)
       , Reifies s W
       )
    => (BVar s a -> BVar s b -> BVar s c -> BVar s d)
    -> BVar s (f a)
    -> BVar s (f b)
    -> BVar s (f c)
    -> BVar s (f d)
liftA3 f x y z = collectVar (P.fmap f xs P.<*> ys P.<*> zs)
  where
    xs = sequenceVar x
    ys = sequenceVar y
    zs = sequenceVar z
{-# INLINE liftA3 #-}
-- | Lifted 'C.coerce': converts a variable of type @a@ into a variable of
-- any type @b@ for which a 'C.Coercible' instance relates the two.
--
-- This is a direct alias for 'coerceVar'.
coerce :: forall a b s. C.Coercible a b => BVar s a -> BVar s b
coerce = coerceVar
{-# INLINE coerce #-}