{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Easy backpropagation when all variables have the same type.
--
-- @
-- data MyRecord a = ...
--   deriving (Functor, Foldable, Traversable)
--
-- deriving via (TraversableVar MyRecord a) instance HasGrad a => HasGrad (MyRecord a)
-- @
-- 
-- = Gradient type
-- One might expect the gradient type to be @type Grad (MyRecord a) = MyRecord (Grad a)@, but that's not
-- the case, because the record could contain additional members apart from @a@s, for example:
--
-- @
-- data MyPoint a = MyPoint
--   { pointLabel :: String
--   , pointX :: a
--   , pointY :: a
--   }
-- @
--
-- and @MyPoint (Grad a)@ can't be made a @VectorSpace@. The gradient type @Grad (MyRecord a)@
-- is a newtype wrapper over @IntMap@
-- that is not exported.
--

module Downhill.BVar.Traversable
  ( -- * Backpropagate
    backpropTraversable,
    backpropTraversable_GradOnly,
    backpropTraversable_ValueAndGrad,

     -- * Split
    splitTraversable,

    -- * TraversableVar
    TraversableVar (..),
  )
where

import Control.Monad.Trans.State.Strict (State, evalState, get, put)
import Data.AdditiveGroup (AdditiveGroup, sumV)
import Data.Foldable (toList)
import Data.IntMap (IntMap)
import qualified Data.IntMap as IntMap
import Data.Maybe (fromMaybe)
import Data.VectorSpace (AdditiveGroup (negateV, zeroV, (^+^), (^-^)), VectorSpace (Scalar, (*^)))
import qualified Data.VectorSpace as VectorSpace
import Downhill.BVar (BVar (BVar, bvarGrad, bvarValue), backprop, var)
import Downhill.Grad
  ( Dual (evalGrad),
    HasGrad (Grad, MScalar, Metric, Tang),
    MetricTensor
      ( MtCovector,
        MtVector,
        evalMetric
      ),
  )
import Downhill.Linear.BackGrad (BackGrad (BackGrad), castBackGrad, realNode)
import Downhill.Linear.Expr
  ( BasicVector (VecBuilder, sumBuilder),
    Expr (ExprSum),
    FullVector,
    SparseVector (unSparseVector),
    Term,
  )
import Downhill.Linear.Lift (lift1_sparse)
import GHC.Generics (Generic)

-- | Provides HasGrad instance for use in deriving via.
--
-- A thin wrapper around @f a@. Its 'Functor', 'Foldable' and 'Traversable'
-- instances are stock-derived, so they simply delegate to the instances of @f@.
newtype TraversableVar f a = TraversableVar {unTraversableVar :: f a}
  deriving stock (Functor, Foldable, Traversable)

-- | Metric used for 'TraversableVar': a single @Metric a@ shared by all
-- elements. The @f@ parameter is phantom; it does not appear in the
-- representation.
newtype TraversableMetric f a = TraversableMetric (Metric a)
  deriving (Generic)

-- No methods are given, so the class defaults apply.
-- NOTE(review): presumably these are the Generic-based defaults, which would be
-- why 'TraversableMetric' derives 'Generic' — confirm against the class definition.
instance AdditiveGroup (Metric a) => AdditiveGroup (TraversableMetric f a)

-- Only the scalar type is specified (the same as that of the wrapped metric);
-- scaling itself is left to the class default.
-- NOTE(review): presumably a Generic-based default — confirm against the class definition.
instance VectorSpace (Metric a) => VectorSpace (TraversableMetric f a) where
  type Scalar (TraversableMetric f a) = Scalar (Metric a)

-- | The wrapped element metric is applied independently to every entry of the
-- covector, turning an 'IntmapVector' of gradients into one of tangents.
instance
  ( MetricTensor (Metric a),
    MtVector (Metric a) ~ Tang a,
    MtCovector (Metric a) ~ Grad a,
    Dual s (Tang a) (Grad a)
  ) =>
  MetricTensor (TraversableMetric f a)
  where
  type MtVector (TraversableMetric f a) = IntmapVector f (Tang a)
  type MtCovector (TraversableMetric f a) = IntmapVector f (Grad a)
  evalMetric (TraversableMetric m) (IntmapVector da) = IntmapVector (IntMap.map (evalMetric m) da)

-- | 'HasGrad' instance meant to be reused through @deriving via@.
--
-- The scalar type is inherited from the element type @a@; tangents and
-- gradients are 'IntmapVector's of the element's tangent/gradient type, and
-- the metric is 'TraversableMetric'.
instance HasGrad a => HasGrad (TraversableVar f a) where
  type MScalar (TraversableVar f a) = MScalar a
  type Tang (TraversableVar f a) = IntmapVector f (Tang a)
  type Grad (TraversableVar f a) = IntmapVector f (Grad a)
  type Metric (TraversableVar f a) = TraversableMetric f a

-- | @IntmapVector@ serves as a gradient of 'TraversableVar'.
--
-- It is an 'IntMap' tagged with a phantom @f@. Within this module the keys are
-- the element positions handed out by @imap@, and a missing key is treated as
-- a zero entry (see the 'AdditiveGroup' instance).
newtype IntmapVector f v = IntmapVector {unIntmapVector :: IntMap v}
  deriving (Show)

-- | Pointwise group structure in which a missing key stands for zero.
instance AdditiveGroup a => AdditiveGroup (IntmapVector f a) where
  -- The empty map is the zero vector: no key carries a value.
  zeroV = IntmapVector IntMap.empty
  negateV (IntmapVector v) = IntmapVector (negateV <$> v)
  -- Keys present on one side only are kept unchanged.
  IntmapVector u ^+^ IntmapVector v = IntmapVector (IntMap.unionWith (^+^) u v)
  -- Shared keys are subtracted; keys only in @u@ are kept as they are and keys
  -- only in @v@ are negated, i.e. the missing operand is treated as zero.
  IntmapVector u ^-^ IntmapVector v = IntmapVector (IntMap.mergeWithKey combine only1 only2 u v)
    where
      combine _key x y = Just (x ^-^ y)
      only1 = id
      only2 = fmap negateV

-- | Scalar multiplication scales every stored entry; the scalar type is that
-- of the element vector space.
instance VectorSpace v => VectorSpace (IntmapVector f v) where
  type Scalar (IntmapVector f v) = VectorSpace.Scalar v
  a *^ (IntmapVector v) = IntmapVector (fmap (a *^) v)

-- | Pairs two 'IntmapVector's entry by entry and sums the results. Only keys
-- present in both maps contribute; a key missing on either side adds nothing.
instance Dual s dv v => Dual s (IntmapVector f dv) (IntmapVector f v) where
  evalGrad (IntmapVector dv) (IntmapVector v) = sumV $ IntMap.intersectionWith evalGrad dv v

-- | Entries stored under the same key are combined with the element's own
-- '<>'.
--
-- This is written out by hand rather than derived via 'IntMap', because
-- 'IntMap''s '<>' is a left-biased union: it would silently drop the right-hand
-- entry on a key collision and would disagree with '^+^' on summed builders.
instance Semigroup v => Semigroup (IntmapVector f v) where
  IntmapVector u <> IntmapVector v = IntmapVector (IntMap.unionWith (<>) u v)

-- | The empty map is the identity of the key-wise '<>' above.
instance Monoid v => Monoid (IntmapVector f v) where
  mempty = IntmapVector IntMap.empty

-- | The builder of an 'IntmapVector' is an 'IntmapVector' of element builders;
-- summing it sums the builder stored under each key separately.
instance BasicVector v => BasicVector (IntmapVector f v) where
  type VecBuilder (IntmapVector f v) = IntmapVector f (VecBuilder v)
  sumBuilder (IntmapVector v) = IntmapVector (fmap sumBuilder v)

-- | Map over a traversable structure, also passing each element its position.
--
-- Positions are handed out in traversal order, starting at 0 and increasing by
-- one per element. The shape of the structure is preserved.
imap ::
  forall t a b.
  Traversable t =>
  (Int -> a -> b) ->
  t a ->
  t b
imap f xs = evalState (traverse step xs) 0
  where
    -- Read the running counter, advance it, and apply @f@ at the old position.
    step :: a -> State Int b
    step x = do
      index <- get
      put (index + 1)
      return (f index x)

-- | Split a variable holding a whole container into a container of variables,
-- one per element.
--
-- Each resulting 'BVar' carries the element's value, and routes its gradient
-- into the parent's gradient as a singleton 'IntmapVector' keyed by the
-- element's position in traversal order.
--
-- Note that @splitTraversable@ won't be useful
-- for top level @BVar@, because the type @Grad (f a)@ is not exposed.
splitTraversable ::
  forall f r a.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a
  ) =>
  BVar r (f a) ->
  f (BVar r a)
splitTraversable (BVar xs dxs) = vars
  where
    vars :: f (BVar r a)
    vars = imap mkBVar xs
    mkBVar :: Int -> a -> BVar r a
    mkBVar index x =
      let -- Place the element's gradient builder under its own key.
          mkBuilder :: VecBuilder (Grad a) -> IntmapVector f (VecBuilder (Grad a))
          mkBuilder dx = IntmapVector (IntMap.singleton index dx)
       in BVar x (lift1_sparse mkBuilder dxs)

-- | Build a single 'Term' that feeds a sparse vector @z@ into the node behind
-- the given 'BackGrad', converting the unwrapped builder of @z@ with the
-- supplied function.
lift1_sparseT ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  Term r (SparseVector z)
lift1_sparseT fa (BackGrad f) = f (fa . unSparseVector)

-- Not exported, because it is untested and hardly useful.
--
-- Inverse direction of 'splitTraversable': collects a container of variables
-- into one variable holding the whole container. The joint gradient node is a
-- sum of one term per element; each term extracts that element's entry (by its
-- traversal position) from the incoming 'IntmapVector' builder.
_joinTraversable ::
  forall f r a.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    FullVector (Grad a)
  ) =>
  f (BVar r a) ->
  BVar r (f a)
_joinTraversable x = BVar values (castBackGrad node)
  where
    values :: f a
    values = bvarValue <$> x
    grads :: f (BackGrad r (Grad a))
    grads = bvarGrad <$> x
    terms :: [Term r (SparseVector (IntmapVector f (Grad a)))]
    terms = toList (imap mkTerm grads)
    mkTerm :: Int -> BackGrad r (Grad a) -> Term r (SparseVector (IntmapVector f (Grad a)))
    mkTerm index = lift1_sparseT (lookupIntMap index)
    -- A missing key is reported as an internal error rather than treated as zero.
    lookupIntMap :: Int -> IntmapVector f x -> x
    lookupIntMap key (IntmapVector intmap) = case IntMap.lookup key intmap of
      Nothing -> error "Downhill BUG: Bad index in joinTraversable"
      Just value -> value
    node :: BackGrad r (SparseVector (IntmapVector f (Grad a)))
    node = realNode (ExprSum terms)

-- | @backpropTraversable one combine fun@
--
-- @one@ is a value to be backpropagated. In case of @p@ being scalar, set @one@
-- to 1 to compute unscaled gradient.
--
-- @combine@ is given value of a parameter and its gradient to construct result,
-- just like @zipWith@.
--
-- @fun@ is the function to be differentiated.
--
-- The result has the same shape as the input container. An element that
-- receives no gradient entry is combined with 'zeroV'.
backpropTraversable ::
  forall f a b p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p,
    FullVector (Grad p)
  ) =>
  Grad p ->
  (a -> Grad a -> b) ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f b
backpropTraversable one combine fun x = imap makeResult x
  where
    -- One variable per element, all feeding into the gradient of the container.
    splitX :: f (BVar (Grad (f a)) a)
    splitX = splitTraversable (var x)

    y :: BVar (Grad (f a)) p
    y = fun splitX

    -- Gradients keyed by element position in traversal order.
    grad :: IntMap (Grad a)
    IntmapVector grad = backprop y one

    -- Elements without an entry in the map have zero gradient.
    lookupGrad :: Int -> Grad a
    lookupGrad i = fromMaybe zeroV (IntMap.lookup i grad)

    makeResult :: Int -> a -> b
    makeResult i x' = combine x' (lookupGrad i)

{-# ANN backpropTraversable_GradOnly "HLint: ignore Use camelCase" #-}

-- | Like 'backpropTraversable', but returns gradient only.
--
-- Each element of the result is the gradient with respect to the element at
-- the same position of the input; the element's value is discarded.
backpropTraversable_GradOnly ::
  forall f a p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p,
    FullVector (Grad p)
  ) =>
  Grad p ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f (Grad a)
backpropTraversable_GradOnly one = backpropTraversable one gradOnly
  where
    gradOnly _value grad = grad

{-# ANN backpropTraversable_ValueAndGrad "HLint: ignore Use camelCase" #-}

-- | 'backpropTraversable' specialized to return a pair of value and gradient.
--
-- Each element of the result pairs the input element with the gradient
-- computed for it.
backpropTraversable_ValueAndGrad ::
  forall f a p.
  ( Traversable f,
    Grad (f a) ~ Grad (TraversableVar f a),
    HasGrad a,
    HasGrad p,
    FullVector (Grad p)
  ) =>
  Grad p ->
  (forall r. f (BVar r a) -> BVar r p) ->
  f a ->
  f (a, Grad a)
backpropTraversable_ValueAndGrad one = backpropTraversable one (,)