{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}

module Downhill.Linear.BackGrad
  ( BackGrad (..),
    realNode,
    inlineNode,
    sparseNode,
    castBackGrad,
  )
where

import Data.VectorSpace
  ( AdditiveGroup (..),
    Scalar,
    VectorSpace (..),
  )
import Downhill.Linear.Expr
  ( BasicVector (VecBuilder),
    Expr (ExprSum),
    FullVector (identityBuilder, negateBuilder, scaleBuilder),
    Term (Term), SparseVector (unSparseVector),
  )

-- | Linear expression, made for backpropagation.
-- It is similar to @'Expr' 'BackFun'@, but has a more flexible form.
--
-- A @BackGrad a v@ wraps a function that is polymorphic in @x@: given a way
-- to turn an @x@ into a @'VecBuilder' v@, it produces a @'Term' a x@.
newtype BackGrad a v = BackGrad (forall x. (x -> VecBuilder v) -> Term a x)

-- Module-wide HLint annotation. The suppressed hint would otherwise fire on
-- the lambda in 'realNode' below. It sits before the Haddock comment so that
-- the comment attaches to 'realNode' rather than to this pragma.
{-# ANN module "HLint: ignore Avoid lambda using `infix`" #-}

-- | Creates a @BackGrad@ that is backed by a real node. Gradient of type @v@
-- will be computed and stored in a graph for this node.
realNode :: Expr a v -> BackGrad a v
realNode x = BackGrad (\f -> Term f x)

-- | @inlineNode f x@ applies function @f@ to variable @x@ without creating a
-- node. All of the gradients coming to this expression are forwarded to the
-- parents of @x@. However, if this expression is used more than once, @f@
-- will be evaluated multiple times, too. It is intended to be used for
-- @newtype@ wrappers.
--
-- @inlineNode f x@ also doesn't prevent the compiler from inlining and
-- optimizing @x@.
inlineNode ::
  forall r u v.
  (VecBuilder v -> VecBuilder u) ->
  BackGrad r u ->
  BackGrad r v
inlineNode f (BackGrad parent) = BackGrad forward
  where
    -- Pre-compose the caller's continuation with @f@. Builders of @v@ are
    -- turned into builders of @u@ before being handed to @parent@.
    forward :: forall x. (x -> VecBuilder v) -> Term r x
    forward toV = parent (\y -> f (toV y))

-- | @sparseNode f x@ creates a real node whose gradient type is
-- @'SparseVector' z@ (which wraps a @'VecBuilder' z@) rather than @z@.
-- The node's gradient is unwrapped with 'unSparseVector' and passed through
-- @f@ to the parent @x@. The node is then retyped to @BackGrad r z@ with
-- 'castBackGrad'. This relies on @VecBuilder (SparseVector z)@ being the same
-- type as @VecBuilder z@, since the signature does not supply that equality.
sparseNode ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  BackGrad r z
sparseNode fa (BackGrad parent) = castBackGrad sparse
  where
    -- The explicit signatures pin the node type to @SparseVector z@.
    -- 'VecBuilder' alone does not determine it.
    toParent :: SparseVector z -> VecBuilder a
    toParent sv = fa (unSparseVector sv)

    sparse :: BackGrad r (SparseVector z)
    sparse = realNode (ExprSum [parent toParent])

-- | @BackGrad@ doesn't track the type of the node. The type of a @BackGrad@
-- can be changed freely as long as its @VecBuilder@ stays the same.
castBackGrad ::
  forall r v z.
  VecBuilder z ~ VecBuilder v =>
  BackGrad r v ->
  BackGrad r z
castBackGrad backGrad = case backGrad of
  -- The wrapped function only mentions @VecBuilder v@, which equals
  -- @VecBuilder z@ by the constraint, so it can be rewrapped unchanged.
  BackGrad term -> BackGrad term

-- | Every operation creates a fresh real node (via 'realNode') whose terms
-- forward to the operands.
instance (FullVector v) => AdditiveGroup (BackGrad r v) where
  -- A node with no terms.
  zeroV = realNode $ ExprSum []

  -- A node with one term that reaches the operand through 'negateBuilder'.
  negateV (BackGrad arg) = realNode $ ExprSum [arg negateBuilder]

  -- Both operands are reached through 'identityBuilder'.
  (^+^) (BackGrad lhs) (BackGrad rhs) =
    realNode $ ExprSum [lhs identityBuilder, rhs identityBuilder]

  -- The left operand is reached through 'identityBuilder' and the right
  -- operand through 'negateBuilder'.
  (^-^) (BackGrad lhs) (BackGrad rhs) =
    realNode $ ExprSum [lhs identityBuilder, rhs negateBuilder]

-- | Scaling creates a fresh real node with a single term that reaches the
-- operand through @'scaleBuilder' s@.
instance FullVector v => VectorSpace (BackGrad r v) where
  type Scalar (BackGrad r v) = Scalar v
  (*^) s (BackGrad term) = realNode $ ExprSum [term (scaleBuilder s)]