{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- | Easy backpropagation when all variables have the same type.
--
-- @
-- data MyRecord a = ...
--   deriving (Functor, Foldable, Traversable)
--
-- deriving via (TraversableVar MyRecord a) instance HasGrad a => HasGrad (MyRecord a)
-- @
--
-- = Gradient type
-- One might expect the gradient type to be @type Grad (MyRecord a) = MyRecord (Grad a)@, but this is not
-- the case, because the record could contain members other than @a@s, for example:
--
-- @
-- data MyPoint a = MyPoint
--   { pointLabel :: String
--   , pointX :: a
--   , pointY :: a
--   }
-- @
--
-- and @MyPoint (Grad a)@ can't be made a @VectorSpace@ because of the @String@ field.
-- Instead, the gradient type @Grad (MyRecord a)@ is a newtype wrapper over @IntMap@
-- that is not exported.
module Downhill.BVar.Traversable
  ( -- * Backpropagate
    backpropTraversable,
    backpropTraversable_GradOnly,
    backpropTraversable_ValueAndGrad,

    -- * Split
    splitTraversable,

    -- * TraversableVar
    TraversableVar (..),
  )
where

import Control.Monad.Trans.State.Strict (State, evalState, get, put)
import Data.AdditiveGroup (AdditiveGroup, sumV)
import Data.Foldable (toList)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Kind (Type)
import Data.Maybe (fromMaybe)
import Data.VectorSpace (AdditiveGroup (negateV, zeroV, (^+^), (^-^)), VectorSpace (Scalar, (*^)))
import qualified Data.VectorSpace as VectorSpace
import Downhill.BVar (BVar (BVar, bvarGrad, bvarValue), backprop, var)
import Downhill.Grad
  ( Dual (evalGrad),
    Manifold (Grad, Tang), HasGrad
  )
import Downhill.Linear.BackGrad (BackGrad (BackGrad), castBackGrad, realNode)
import Downhill.Linear.Expr
  ( BasicVector (VecBuilder, identityBuilder, sumBuilder),
    Expr (ExprSum),
    SparseVector (unSparseVector),
    Term,
  )
import Downhill.Linear.Lift (lift1_sparse)
import GHC.Generics (Generic)
import Downhill.Metric (MetricTensor (evalMetric))

-- | Provides HasGrad instance for use in deriving via.
--
-- Wraps a container @f a@ whose variables all have the same type @a@. The
-- gradient and tangent of a @TraversableVar@ are sparse maps keyed by the
-- position of each element in traversal order.
newtype TraversableVar f a = TraversableVar
  { unTraversableVar :: f a
  }
  deriving stock Functor
  deriving stock Foldable
  deriving stock Traversable

-- | Metric for 'TraversableVar': one element metric @g@, which is applied to
-- every entry of a 'TraversableVar' gradient. The parameter @f@ is phantom;
-- it only ties the metric to a particular container type.
newtype TraversableMetric (f :: Type -> Type) g
  = TraversableMetric g
  deriving stock (Generic)

-- | Vector arithmetic on a 'TraversableMetric' is that of the wrapped @g@.
-- No methods are written here, so the class's default implementations are
-- used (presumably the Generic-based defaults, which is why
-- 'TraversableMetric' derives 'Generic').
instance AdditiveGroup g => AdditiveGroup (TraversableMetric f g)

-- | Scalars of a 'TraversableMetric' are the scalars of the wrapped @g@.
-- '(*^)' is not defined here, so the class default is used (presumably the
-- Generic-based default).
instance VectorSpace g => VectorSpace (TraversableMetric f g) where
  type Scalar (TraversableMetric f g) = Scalar g

-- | Applies the element metric to every entry of the sparse gradient
-- independently. Keys absent from the gradient stay absent in the result.
instance MetricTensor p g => MetricTensor (TraversableVar f p) (TraversableMetric f g) where
  evalMetric (TraversableMetric metric) (IntmapVector grads) =
    IntmapVector (evalMetric @p @g metric <$> grads)

-- | Tangents and gradients of a 'TraversableVar' are sparse maps from an
-- element's position (in traversal order) to that element's tangent or
-- gradient.
instance Manifold a => Manifold (TraversableVar f a) where
  type Tang (TraversableVar f a) = IntmapVector f (Tang a)
  type Grad (TraversableVar f a) = IntmapVector f (Grad a)

-- | @IntmapVector@ serves as a gradient of 'TraversableVar'.
--
-- A sparse vector: keys are element positions in traversal order, and a
-- key that is not present stands for a zero entry. The parameter @f@ is
-- phantom and records which container the vector belongs to.
newtype IntmapVector (f :: Type -> Type) v = IntmapVector
  { unIntmapVector :: IntMap v
  }
  deriving stock (Show)

-- | An 'IntmapVector' is itself a manifold whose tangents and gradients
-- are 'IntmapVector's of the entries' tangents and gradients, key by key.
instance Manifold v => Manifold (IntmapVector f v) where
  type Tang (IntmapVector f v) = IntmapVector f (Tang v)
  type Grad (IntmapVector f v) = IntmapVector f (Grad v)

-- | Key-wise arithmetic on sparse vectors. A key missing from one operand is
-- treated as a zero entry of that operand.
instance AdditiveGroup a => AdditiveGroup (IntmapVector f a) where
  zeroV = IntmapVector IntMap.empty
  negateV (IntmapVector entries) = IntmapVector (IntMap.map negateV entries)
  IntmapVector lhs ^+^ IntmapVector rhs =
    IntmapVector (IntMap.unionWith (^+^) lhs rhs)
  IntmapVector lhs ^-^ IntmapVector rhs =
    IntmapVector (IntMap.mergeWithKey both leftOnly rightOnly lhs rhs)
    where
      -- Key in both operands: subtract the entries.
      both _ x y = Just (x ^-^ y)
      -- Key only in lhs: x - 0 = x.
      leftOnly = id
      -- Key only in rhs: 0 - y = -y.
      rightOnly = IntMap.map negateV

-- | Scales every stored entry. Absent (zero) entries stay absent.
instance VectorSpace v => VectorSpace (IntmapVector f v) where
  type Scalar (IntmapVector f v) = VectorSpace.Scalar v
  factor *^ IntmapVector entries = IntmapVector (IntMap.map (factor *^) entries)

-- | Dual pairing of two sparse vectors. A missing key is a zero entry, so
-- only keys present in both maps contribute; the per-key pairings are
-- added together with 'sumV'.
instance Dual dv v => Dual (IntmapVector f dv) (IntmapVector f v) where
  evalGrad (IntmapVector lhs) (IntmapVector rhs) =
    sumV (IntMap.intersectionWith evalGrad lhs rhs)

-- | Key-wise combination: entries that share a key are merged with the
-- entry type's own '(<>)'.
--
-- This must not be derived via 'IntMap': that instance is a left-biased
-- 'IntMap.union', which keeps only one of two entries with the same key.
-- Because @IntmapVector f (VecBuilder v)@ is the gradient builder of
-- 'IntmapVector', a left-biased union would silently drop gradient
-- contributions, e.g. when an element produced by @splitTraversable@ is
-- used more than once.
instance Semigroup v => Semigroup (IntmapVector f v) where
  IntmapVector lhs <> IntmapVector rhs = IntmapVector (IntMap.unionWith (<>) lhs rhs)

-- | The empty map is the identity. 'mappend' and 'mconcat' use the class
-- defaults, so they go through the key-wise '(<>)' above rather than
-- 'IntMap.unions'.
instance Semigroup v => Monoid (IntmapVector f v) where
  mempty = IntmapVector IntMap.empty

-- | Gradient builders are kept per key. 'sumBuilder' collapses each key's
-- builder on its own, and 'identityBuilder' wraps each entry on its own.
instance BasicVector v => BasicVector (IntmapVector f v) where
  type VecBuilder (IntmapVector f v) = IntmapVector f (VecBuilder v)
  sumBuilder (IntmapVector builders) = IntmapVector (IntMap.map sumBuilder builders)
  identityBuilder (IntmapVector entries) = IntmapVector (IntMap.map identityBuilder entries)

-- | Maps over a traversable structure, also passing each element's
-- zero-based position in traversal order. For example,
-- @imap (,) "abc"@ is @[(0,'a'),(1,'b'),(2,'c')]@.
imap ::
  forall t a b.
  Traversable t =>
  (Int -> a -> b) ->
  t a ->
  t b
imap f xs = evalState (traverse step xs) 0
  where
    step :: a -> State Int b
    step x = do
      index <- get
      -- Force the counter. Otherwise each step stores an unevaluated
      -- (index + 1), building a chain of thunks as long as the traversal.
      put $! index + 1
      return (f index x)

-- | Turns a @BVar@ of a container into a container of per-element @BVar@s.
-- The element at position @i@ (in traversal order) sends its gradient back
-- to key @i@ of the container's sparse gradient.
--
-- Note that @splitTraversable@ won't be useful for a top-level @BVar@,
-- because the type @Grad (f a)@ is not exposed.
splitTraversable ::
  forall f r a.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a
  ) =>
  BVar r (f a) ->
  f (BVar r a)
splitTraversable (BVar container containerGrad) = imap elementVar container
  where
    elementVar :: Int -> a -> BVar r a
    elementVar i x = BVar x (lift1_sparse (toSlot i) containerGrad)

    -- Places one element's gradient builder at key @i@ of the container's builder.
    toSlot :: Int -> VecBuilder (Grad a) -> IntmapVector f (VecBuilder (Grad a))
    toSlot i = IntmapVector . IntMap.singleton i

-- | Counterpart of 'lift1_sparse' that returns the edge ('Term') itself
-- rather than a 'BackGrad', so that several such edges can be gathered
-- into a single node. The incoming sparse builder is unwrapped and passed
-- through @fa@.
lift1_sparseT ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  Term r (SparseVector z)
lift1_sparseT fa (BackGrad mkTerm) = mkTerm (\sparse -> fa (unSparseVector sparse))

-- Not exported, because it is untested and hardly useful.
--
-- Combines a container of @BVar@s into one @BVar@ of the whole container.
-- Every element's gradient is routed through a single sum node. The element
-- at position @i@ (in traversal order) reads its gradient builder from key
-- @i@ of the container's sparse builder.
_joinTraversable ::
  forall f r a.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a
  ) =>
  f (BVar r a) ->
  BVar r (f a)
_joinTraversable x = BVar values (castBackGrad node)
  where
    values :: f a
    values = bvarValue <$> x
    grads :: f (BackGrad r (Grad a))
    grads = bvarGrad <$> x
    terms :: [Term r (SparseVector (IntmapVector f (Grad a)))]
    terms = toList (imap mkTerm grads)
    mkTerm :: Int -> BackGrad r (Grad a) -> Term r (SparseVector (IntmapVector f (Grad a)))
    mkTerm index = lift1_sparseT (lookupIntMap index)
    -- The incoming builder is sparse. A key that is not present means no
    -- gradient reached that element (a zero), so that element gets the
    -- empty builder instead of aborting with an error.
    lookupIntMap :: forall x. Monoid x => Int -> IntmapVector f x -> x
    lookupIntMap key (IntmapVector intmap) = IntMap.findWithDefault mempty key intmap
    node :: BackGrad r (SparseVector (IntmapVector f (Grad a)))
    node = realNode (ExprSum terms)

-- | @backpropTraversable one combine fun@
--
-- Differentiates @fun@ with respect to every element of a traversable
-- container, and pairs each element with its gradient.
--
-- @one@ is the value that gets backpropagated from the output. When @p@ is
-- a scalar, pass 1 to obtain the unscaled gradient.
--
-- @combine@ receives each parameter value together with its gradient and
-- builds that element of the result, in the manner of @zipWith@.
--
-- @fun@ is the function being differentiated.
--
-- Elements that receive no gradient are paired with 'zeroV'.
backpropTraversable ::
  forall f a b p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p
  ) =>
  Grad p ->
  (a -> Grad a -> b) ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f b
backpropTraversable one combine fun x = imap withGrad x
  where
    -- Make the whole container the differentiation variable, split it into
    -- per-element BVars and feed those to the function.
    output :: BVar (Grad (f a)) p
    output = fun (splitTraversable (var x))

    -- Sparse gradient of the whole container, keyed by element position.
    gradMap :: IntMap (Grad a)
    gradMap = unIntmapVector (backprop output one)

    withGrad :: Int -> a -> b
    withGrad i value = combine value (IntMap.findWithDefault zeroV i gradMap)

{-# ANN backpropTraversable_GradOnly "HLint: ignore Use camelCase" #-}

-- | Like 'backpropTraversable', but returns the gradients alone, one per
-- element of the input container.
backpropTraversable_GradOnly ::
  forall f a p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p
  ) =>
  Grad p ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f (Grad a)
backpropTraversable_GradOnly one = backpropTraversable one (\_value grad -> grad)

{-# ANN backpropTraversable_ValueAndGrad "HLint: ignore Use camelCase" #-}

-- | 'backpropTraversable' specialized to pair each element's value with
-- its gradient.
backpropTraversable_ValueAndGrad ::
  forall f a p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p
  ) =>
  Grad p ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f (a, Grad a)
backpropTraversable_ValueAndGrad one = backpropTraversable one (\value grad -> (value, grad))