{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Downhill.BVar.Traversable
(
backpropTraversable,
backpropTraversable_GradOnly,
backpropTraversable_ValueAndGrad,
splitTraversable,
TraversableVar (..),
)
where
import Control.Monad.Trans.State.Strict (State, evalState, get, put)
import Data.AdditiveGroup (AdditiveGroup, sumV)
import Data.Foldable (toList)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Kind (Type)
import Data.Maybe (fromMaybe)
import Data.VectorSpace (AdditiveGroup (negateV, zeroV, (^+^), (^-^)), VectorSpace (Scalar, (*^)))
import qualified Data.VectorSpace as VectorSpace
import Downhill.BVar (BVar (BVar, bvarGrad, bvarValue), backprop, var)
import Downhill.Grad
( Dual (evalGrad),
Manifold (Grad, Tang), HasGrad
)
import Downhill.Linear.BackGrad (BackGrad (BackGrad), castBackGrad, realNode)
import Downhill.Linear.Expr
( BasicVector (VecBuilder, identityBuilder, sumBuilder),
Expr (ExprSum),
SparseVector (unSparseVector),
Term,
)
import Downhill.Linear.Lift (lift1_sparse)
import GHC.Generics (Generic)
import Downhill.Metric (MetricTensor (evalMetric))
-- | Wrapper around a container of values whose tangent and gradient spaces
-- are sparse maps keyed by each element's position in traversal order.
-- Functor, Foldable and Traversable are the container's own, derived
-- structurally.
newtype TraversableVar f a = TraversableVar
  { unTraversableVar :: f a
  }
  deriving stock (Functor, Foldable, Traversable)
-- | Metric for a 'TraversableVar': one metric @g@ for the element type,
-- applied to every element. The phantom parameter @f@ only ties the metric
-- to the container type it is used with.
newtype TraversableMetric (f :: Type -> Type) g
  = TraversableMetric g
  deriving stock (Generic)
-- Neither instance defines any methods, so both rely on the classes' default
-- implementations. NOTE(review): presumably these are vector-space's
-- Generic-based defaults, enabled by the derived 'Generic' instance; confirm
-- against the vector-space version in use.
instance (AdditiveGroup g) => AdditiveGroup (TraversableMetric f g)

instance (VectorSpace g) => VectorSpace (TraversableMetric f g) where
  -- The scalar field is the element metric's.
  type Scalar (TraversableMetric f g) = Scalar g
-- | Applies the element metric to every stored gradient component. The set of
-- indices is unchanged.
instance MetricTensor p g => MetricTensor (TraversableVar f p) (TraversableMetric f g) where
  evalMetric (TraversableMetric elementMetric) (IntmapVector components) =
    IntmapVector (fmap (evalMetric @p @g elementMetric) components)
-- | Gradients and tangents of a container are stored sparsely, one entry per
-- element, keyed by the element's position in traversal order.
instance Manifold a => Manifold (TraversableVar f a) where
  type Grad (TraversableVar f a) = IntmapVector f (Grad a)
  type Tang (TraversableVar f a) = IntmapVector f (Tang a)
-- | Sparse vector indexed by element position. A missing key is treated as a
-- zero component by the instances in this module. The phantom @f@ records
-- which container the indices refer to.
newtype IntmapVector (f :: Type -> Type) v = IntmapVector
  { unIntmapVector :: IntMap v
  }
  deriving stock (Show)
-- | A sparse vector's own gradient and tangent spaces are sparse vectors of the
-- element's gradients and tangents, over the same index set.
instance Manifold v => Manifold (IntmapVector f v) where
  type Grad (IntmapVector f v) = IntmapVector f (Grad v)
  type Tang (IntmapVector f v) = IntmapVector f (Tang v)
-- | Component-wise group operations on sparse vectors. Zero is the empty map.
-- A key present on only one side behaves as if the other side held zero.
instance AdditiveGroup a => AdditiveGroup (IntmapVector f a) where
  zeroV = IntmapVector IntMap.empty

  negateV (IntmapVector xs) = IntmapVector (fmap negateV xs)

  IntmapVector xs ^+^ IntmapVector ys =
    IntmapVector (IntMap.unionWith (^+^) xs ys)

  IntmapVector xs ^-^ IntmapVector ys =
    IntmapVector (IntMap.mergeWithKey inBoth leftOnly rightOnly xs ys)
    where
      -- Key present in both maps: ordinary subtraction.
      inBoth _ x y = Just (x ^-^ y)
      -- Key only in the minuend: x - 0 is kept unchanged.
      leftOnly = id
      -- Key only in the subtrahend: 0 - y becomes negateV y.
      rightOnly = fmap negateV
-- | Scalar multiplication scales each stored component. The set of keys is
-- unchanged.
instance VectorSpace v => VectorSpace (IntmapVector f v) where
  type Scalar (IntmapVector f v) = VectorSpace.Scalar v
  s *^ IntmapVector xs = IntmapVector (IntMap.map (s *^) xs)
-- | Pairs two sparse vectors index by index. Only keys present in both maps
-- contribute, since a missing component is zero. The per-key results are
-- summed.
instance Dual dv v => Dual (IntmapVector f dv) (IntmapVector f v) where
  evalGrad (IntmapVector lhs) (IntmapVector rhs) =
    let perIndex = IntMap.intersectionWith evalGrad lhs rhs
     in sumV perIndex
-- | Combining two sparse vectors merges the entries at each shared index with
-- the element type's '<>'. Indices present on only one side are kept as they
-- are. This instance is what accumulates per-index builders when
-- @IntmapVector f (VecBuilder v)@ is used as a 'VecBuilder'.
--
-- These instances must not be derived via @IntMap v@. The 'IntMap'
-- 'Semigroup' instance is the left-biased 'IntMap.union', which silently drops
-- the right-hand value whenever both sides have the same key. That loses
-- gradient contributions for that element.
instance Semigroup v => Semigroup (IntmapVector f v) where
  IntmapVector xs <> IntmapVector ys = IntmapVector (IntMap.unionWith (<>) xs ys)

-- | The empty map is the identity for the merge above. Only 'Semigroup' is
-- needed on the elements, because 'mempty' never inspects them.
instance Semigroup v => Monoid (IntmapVector f v) where
  mempty = IntmapVector IntMap.empty
-- | Builders are kept per index. Summing a builder, or building from a
-- vector, is done independently at each index and keeps the same set of keys.
instance BasicVector v => BasicVector (IntmapVector f v) where
  type VecBuilder (IntmapVector f v) = IntmapVector f (VecBuilder v)
  sumBuilder = IntmapVector . fmap sumBuilder . unIntmapVector
  identityBuilder = IntmapVector . fmap identityBuilder . unIntmapVector
-- | Map over a traversable structure. Each call also receives the element's
-- zero-based position in traversal order.
--
-- The counter is forced at every step (@put $!@). Otherwise a long traversal
-- leaves a chain of unevaluated @(+ 1)@ thunks in the state.
imap ::
  forall t a b.
  Traversable t =>
  (Int -> a -> b) ->
  t a ->
  t b
imap f xs = evalState (traverse step xs) 0
  where
    step :: a -> State Int b
    step x = do
      index <- get
      put $! index + 1
      return (f index x)
-- | Break a variable that holds a whole container into one variable per
-- element.
--
-- Element number @i@, in traversal order, keeps its original value. Its
-- gradient is routed into the container's gradient as a singleton entry at
-- key @i@.
splitTraversable ::
  forall f r a.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a
  ) =>
  BVar r (f a) ->
  f (BVar r a)
splitTraversable (BVar values containerGrad) = imap elementVar values
  where
    elementVar :: Int -> a -> BVar r a
    elementVar position value =
      BVar value (lift1_sparse (toSlot position) containerGrad)

    -- Place an element's gradient builder at its own index.
    toSlot :: Int -> VecBuilder (Grad a) -> IntmapVector f (VecBuilder (Grad a))
    toSlot position = IntmapVector . IntMap.singleton position
-- | Build a single 'Term' that feeds a gradient node of type @a@. Incoming
-- builders are unwrapped from 'SparseVector' and mapped through @fa@.
-- It returns a bare 'Term' rather than a node, so several such terms can be
-- combined into one expression.
lift1_sparseT ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  Term r (SparseVector z)
lift1_sparseT fa (BackGrad connect) =
  connect (\sparse -> fa (unSparseVector sparse))
-- | Gather per-element variables back into a single variable that holds the
-- whole container. This is the opposite direction of 'splitTraversable'. It
-- is not exported from this module.
--
-- The container's gradient builder is a sparse map from element position to
-- builder. Element @i@ receives the entry at key @i@, via one term per
-- element summed into a single node.
_joinTraversable ::
  forall f r a.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a
  ) =>
  f (BVar r a) ->
  BVar r (f a)
_joinTraversable x = BVar values (castBackGrad node)
  where
    values :: f a
    values = bvarValue <$> x

    grads :: f (BackGrad r (Grad a))
    grads = bvarGrad <$> x

    terms :: [Term r (SparseVector (IntmapVector f (Grad a)))]
    terms = toList (imap mkTerm grads)

    mkTerm :: Int -> BackGrad r (Grad a) -> Term r (SparseVector (IntmapVector f (Grad a)))
    mkTerm index = lift1_sparseT (lookupIntMap index)

    -- The incoming builder is sparse. An element that received no gradient
    -- contribution simply has no entry. That means an empty ('mempty')
    -- builder, not an internal error: for example, only some elements may
    -- be used downstream.
    lookupIntMap :: Monoid x => Int -> IntmapVector f x -> x
    lookupIntMap key (IntmapVector intmap) =
      fromMaybe mempty (IntMap.lookup key intmap)

    node :: BackGrad r (SparseVector (IntmapVector f (Grad a)))
    node = realNode (ExprSum terms)
-- | Reverse-mode differentiation of a function of every element of a
-- traversable container.
--
-- @backpropTraversable one combine fun xs@ works in four steps:
--
-- 1. Split @xs@ into per-element variables.
-- 2. Evaluate @fun@ on them.
-- 3. Back-propagate the seed gradient @one@ from the result.
-- 4. Rebuild a container of the same shape by calling @combine@ on each
--    element's value and gradient.
--
-- An element whose gradient is absent from the result gets 'zeroV'.
backpropTraversable ::
  forall f a b p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p
  ) =>
  Grad p ->
  (a -> Grad a -> b) ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f b
backpropTraversable one combine fun x = imap withGrad x
  where
    output :: BVar (Grad (f a)) p
    output = fun (splitTraversable (var x))

    -- Sparse gradient of the whole container, keyed by element position.
    gradients :: IntMap (Grad a)
    gradients = unIntmapVector (backprop output one)

    withGrad :: Int -> a -> b
    withGrad i value = combine value (IntMap.findWithDefault zeroV i gradients)
{-# ANN backpropTraversable_GradOnly "HLint: ignore Use camelCase" #-}

-- | Gradient of @fun@ with respect to every element of the container. The
-- result is a container of the same shape. Elements that receive no gradient
-- get 'zeroV'.
backpropTraversable_GradOnly ::
  forall f a p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p
  ) =>
  Grad p ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f (Grad a)
backpropTraversable_GradOnly one fun = backpropTraversable one (const id) fun
{-# ANN backpropTraversable_ValueAndGrad "HLint: ignore Use camelCase" #-}

-- | Gradient of @fun@ with respect to every element of the container. Each
-- element's original value is paired with its gradient, in a container of
-- the same shape. Elements that receive no gradient get 'zeroV'.
backpropTraversable_ValueAndGrad ::
  forall f a p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p
  ) =>
  Grad p ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f (a, Grad a)
backpropTraversable_ValueAndGrad one fun xs =
  backpropTraversable one (\value grad -> (value, grad)) fun xs