{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
module Downhill.Linear.BackGrad
( BackGrad (..),
realNode,
inlineNode,
sparseNode,
castBackGrad,
)
where
import Data.VectorSpace
( AdditiveGroup (..),
Scalar,
VectorSpace (..),
)
import Downhill.Linear.Expr
( BasicVector (VecBuilder),
Expr (ExprSum),
FullVector (identityBuilder, negateBuilder, scaleBuilder),
Term (Term), SparseVector (unSparseVector),
)
-- | A rank-2 wrapper around a function that produces 'Term's.
--
-- For any type @x@ chosen by the consumer, the wrapped function takes a
-- mapping @x -> 'VecBuilder' v@ and returns a @'Term' a x@.
newtype BackGrad a v
  = BackGrad
      ( forall x.
        (x -> VecBuilder v) ->
        Term a x
      )
-- Tell HLint to ignore its "Avoid lambda using `infix`" hint for this module.
-- NOTE(review): presumably aimed at the explicit lambda in 'realNode' -- confirm.
{-# ANN module "HLint: ignore Avoid lambda using `infix`" #-}
-- | Wrap an 'Expr' as a 'BackGrad'.
--
-- When the resulting 'BackGrad' is given a builder function @f@, it yields
-- @'Term' f x@, i.e. a term that refers to the supplied expression @x@.
realNode :: Expr a v -> BackGrad a v
realNode x = BackGrad (\f -> Term f x)
-- | Re-index a 'BackGrad' from @u@ to @v@ by pre-composing a builder
-- conversion.
--
-- The first argument converts a @'VecBuilder' v@ into a @'VecBuilder' u@.
-- Any builder function later handed to the result is composed with that
-- conversion and then forwarded to the wrapped function; no 'Expr' is
-- constructed here.
inlineNode ::
  forall r u v.
  (VecBuilder v -> VecBuilder u) ->
  BackGrad r u ->
  BackGrad r v
inlineNode f (BackGrad g) = BackGrad go
  where
    -- The explicit @forall x.@ keeps @go@ polymorphic in @x@ while reusing
    -- the scoped type variables @r@ and @v@ from the outer signature.
    go :: forall x. (x -> VecBuilder v) -> Term r x
    go h = g (f . h)
-- | Like 'inlineNode' in signature, but materialises an intermediate
-- expression node.
--
-- The builder conversion is composed with 'unSparseVector', the wrapped
-- function is applied to it to obtain a single 'Term' over
-- @'SparseVector' z@, and that term becomes the only element of a fresh
-- 'ExprSum'. The node is wrapped with 'realNode' and finally re-indexed from
-- @'SparseVector' z@ to @z@ via 'castBackGrad'.
sparseNode ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  BackGrad r z
sparseNode fa (BackGrad x) = castBackGrad (realNode node)
  where
    -- Builder conversion lifted to take a 'SparseVector' as its input.
    fa' :: SparseVector z -> VecBuilder a
    fa' = fa . unSparseVector

    -- Sum node containing exactly one term.
    node :: Expr r (SparseVector z)
    node = ExprSum [x fa']
-- | Change the vector index of a 'BackGrad' without touching the wrapped
-- function.
--
-- This type-checks because the wrapped function mentions the index only
-- through 'VecBuilder', and the constraint states that both indices share
-- the same 'VecBuilder'.
castBackGrad ::
  forall r v z.
  VecBuilder z ~ VecBuilder v =>
  BackGrad r v ->
  BackGrad r z
castBackGrad (BackGrad g) = BackGrad g
-- | Additive group structure on 'BackGrad'.
--
-- Every operation builds a new 'ExprSum' node and wraps it with 'realNode'.
-- The operands contribute one 'Term' each, obtained by applying their wrapped
-- function to 'identityBuilder' or 'negateBuilder'.
instance (FullVector v) => AdditiveGroup (BackGrad r v) where
  -- A sum node with no terms.
  zeroV = realNode (ExprSum [])

  -- A single term built with 'negateBuilder'.
  negateV (BackGrad x) = realNode (ExprSum [x negateBuilder])

  -- Both operands enter through 'identityBuilder'.
  BackGrad x ^+^ BackGrad y =
    realNode (ExprSum [x identityBuilder, y identityBuilder])

  -- Left operand through 'identityBuilder', right through 'negateBuilder'.
  BackGrad x ^-^ BackGrad y =
    realNode (ExprSum [x identityBuilder, y negateBuilder])
-- | Vector space structure on 'BackGrad'; the scalar type is that of @v@.
--
-- Scaling builds a new 'ExprSum' node with a single 'Term', obtained by
-- applying the wrapped function to @'scaleBuilder' a@, and wraps it with
-- 'realNode'.
instance FullVector v => VectorSpace (BackGrad r v) where
  type Scalar (BackGrad r v) = Scalar v
  a *^ BackGrad v = realNode (ExprSum [v (scaleBuilder a)])