{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeOperators #-}
module Downhill.Linear.BackGrad
( BackGrad (..),
realNode,
inlineNode,
sparseNode,
castBackGrad,
)
where
import Data.VectorSpace
( AdditiveGroup (..),
Scalar,
VectorSpace (..),
)
import Downhill.Linear.Expr
( BasicVector (VecBuilder, identityBuilder),
Expr (ExprSum),
Term (Term), SparseVector (unSparseVector),
)
-- | A node represented by a rank-2 continuation.
--
-- For any adapter @x -> 'VecBuilder' v@ chosen by a consumer, the wrapped
-- function yields a @'Term' a x@. Since the function is polymorphic in @x@,
-- each consumer picks its own @x@ and adapter. (The name suggests a
-- backward-gradient node; see 'realNode', 'inlineNode', 'sparseNode'.)
newtype BackGrad a v = BackGrad (forall s. (s -> VecBuilder v) -> Term a s)
{-# ANN module "HLint: ignore Avoid lambda using `infix`" #-}
-- | Wrap an existing expression as a 'BackGrad' node.
--
-- Whatever adapter @f :: x -> 'VecBuilder' v@ a consumer supplies, the result
-- is @'Term' f expr@. Every consumer's term therefore refers to the very same
-- expression @expr@.
realNode :: Expr a v -> BackGrad a v
realNode expr = BackGrad (\f -> Term f expr)
-- | Apply a builder transformation to a node without constructing a new
-- expression.
--
-- The transformation is precomposed onto every adapter that is later handed
-- to the parent. The parent's continuation is called directly, and no 'Expr'
-- is built here.
inlineNode ::
  forall r u v.
  -- | Maps builders of the new node's type @v@ to builders of the parent's
  -- type @u@.
  (VecBuilder v -> VecBuilder u) ->
  -- | Parent node.
  BackGrad r u ->
  BackGrad r v
inlineNode f (BackGrad parent) = BackGrad go
  where
    -- The explicit signature generalises @go@ over @x@, as the rank-2
    -- 'BackGrad' field requires. @r@ and @v@ are the outer, scoped variables.
    go :: forall x. (x -> VecBuilder v) -> Term r x
    go h = parent (f . h)
-- | Apply a builder transformation through a new expression node.
--
-- The new node's expression is an 'ExprSum' at type @'SparseVector' z@ with a
-- single term. That term gives the parent an adapter which first unwraps the
-- sparse vector with 'unSparseVector' and then applies the transformation.
-- The resulting @'BackGrad' r ('SparseVector' z)@ is reinterpreted as
-- @'BackGrad' r z@ by 'castBackGrad'. That cast requires
-- @'VecBuilder' ('SparseVector' z)@ and @'VecBuilder' z@ to coincide.
sparseNode ::
  forall r a z.
  BasicVector z =>
  -- | Maps builders of the new node's type @z@ to builders of the parent's
  -- type @a@.
  (VecBuilder z -> VecBuilder a) ->
  -- | Parent node.
  BackGrad r a ->
  BackGrad r z
sparseNode fa (BackGrad parent) = castBackGrad (realNode node)
  where
    -- Adapter handed to the parent: unwrap the sparse vector, then transform.
    fa' :: SparseVector z -> VecBuilder a
    fa' = fa . unSparseVector

    node :: Expr r (SparseVector z)
    node = ExprSum [parent fa']
-- | Reinterpret a node at another vector type that has the same builder type.
--
-- No new node is created; the wrapped continuation is reused unchanged. This
-- typechecks because, under the equality constraint, an adapter into
-- @'VecBuilder' z@ is also an adapter into @'VecBuilder' v@.
castBackGrad ::
  forall r v z.
  VecBuilder z ~ VecBuilder v =>
  BackGrad r v ->
  BackGrad r z
castBackGrad (BackGrad g) = BackGrad g
-- | Additive structure on nodes.
--
-- Every operation wraps a fresh 'ExprSum' with 'realNode'. Each term of that
-- sum hands an operand an adapter that maps a value of type @v@ into a
-- 'VecBuilder' via 'identityBuilder', negating the value first where noted.
instance (BasicVector v, AdditiveGroup v) => AdditiveGroup (BackGrad r v) where
  -- The empty sum: a node with no terms.
  zeroV = realNode (ExprSum [])

  -- One term whose adapter negates the value before building it.
  negateV (BackGrad x) = realNode (ExprSum [x (identityBuilder . negateV)])

  -- Two terms; both operands receive the value unchanged.
  BackGrad x ^+^ BackGrad y =
    realNode (ExprSum [x identityBuilder, y identityBuilder])

  -- Two terms; the right operand receives the negated value.
  BackGrad x ^-^ BackGrad y =
    realNode (ExprSum [x identityBuilder, y (identityBuilder . negateV)])
-- | Scalar multiplication on nodes.
--
-- The scalar type is that of @v@. Scaling creates a fresh 'realNode' holding
-- a single-term 'ExprSum'. That term's adapter scales the value of type @v@
-- by the scalar before building it with 'identityBuilder'.
instance (BasicVector v, VectorSpace v) => VectorSpace (BackGrad r v) where
  type Scalar (BackGrad r v) = Scalar v
  a *^ BackGrad g = realNode (ExprSum [g (identityBuilder . (a *^))])