{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeOperators #-}

module Downhill.Linear.BackGrad
  ( BackGrad (..),
    realNode,
    inlineNode,
    sparseNode,
    castBackGrad,
  )
where

import Data.VectorSpace
  ( AdditiveGroup (..),
    Scalar,
    VectorSpace (..),
  )
import Downhill.Linear.Expr
  ( BasicVector (VecBuilder, identityBuilder),
    Expr (ExprSum),
    Term (Term), SparseVector (unSparseVector),
  )

-- | A linear expression specialised for backpropagation.
--
-- It is similar to @'Expr' 'BackFun'@, but has a more flexible form: instead
-- of being a node itself, it is a function that, given any way of turning
-- an @x@ into a @'VecBuilder' v@, yields a 'Term' over @x@.
newtype BackGrad a v
  = BackGrad (forall x. (x -> VecBuilder v) -> Term a x)

{-# ANN module "HLint: ignore Avoid lambda using `infix`" #-}

-- | Creates a @BackGrad@ that is backed by a real node. Gradient of type @v@
-- will be computed and stored in a graph for this node.
--
-- The resulting 'BackGrad' simply pairs whatever builder function it is given
-- with the supplied expression in a 'Term'.
realNode :: Expr a v -> BackGrad a v
realNode x = BackGrad (\f -> Term f x)

-- | @inlineNode f x@ applies the function @f@ to the variable @x@ without
-- creating a node. All gradients arriving at this expression are forwarded to
-- the parents of @x@. However, if this expression is used more than once,
-- @f@ will be evaluated multiple times, too. It is intended to be used for
-- @newtype@ wrappers.
--
-- @inlineNode f x@ also does not prevent the compiler from inlining and
-- optimizing @x@.
inlineNode ::
  forall r u v.
  (VecBuilder v -> VecBuilder u) ->
  BackGrad r u ->
  BackGrad r v
inlineNode f (BackGrad g) = BackGrad go
  where
    -- Pre-compose the caller's builder with @f@ before handing it to the
    -- wrapped 'BackGrad'; no new node is introduced.
    go :: forall x. (x -> VecBuilder v) -> Term r x
    go h = g (f . h)

-- | @sparseNode fa x@ creates a real node whose stored gradient has type
-- @'SparseVector' z@, and exposes it as a @'BackGrad' r z@.
--
-- The node is an 'ExprSum' with a single term: the wrapped 'BackGrad' applied
-- to @fa@ composed with 'unSparseVector'. The node is then wrapped with
-- 'realNode' and retyped with 'castBackGrad', which type checks only because
-- @'VecBuilder' ('SparseVector' z)@ and @'VecBuilder' z@ are the same type.
sparseNode ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  BackGrad r z
sparseNode fa (BackGrad x) = castBackGrad (realNode node)
  where
    -- Unwrap the sparse gradient, then convert it for the parent.
    fa' :: SparseVector z -> VecBuilder a
    fa' = fa . unSparseVector
    -- The intermediate node that accumulates the sparse gradient.
    node :: Expr r (SparseVector z)
    node = ExprSum [x fa']

-- | @BackGrad@ doesn't track the type of the node. The type of a @BackGrad@
-- can be changed freely as long as its 'VecBuilder' stays the same.
--
-- The wrapped function is reused unchanged; only the phantom-like vector
-- type index is replaced.
castBackGrad ::
  forall r v z.
  VecBuilder z ~ VecBuilder v =>
  BackGrad r v ->
  BackGrad r z
castBackGrad (BackGrad g) = BackGrad g

-- | Additive group structure on 'BackGrad'. Every operation builds a fresh
-- real node (via 'realNode') whose 'ExprSum' lists the operands as terms.
instance (BasicVector v, AdditiveGroup v) => AdditiveGroup (BackGrad r v) where
  -- A node with no terms.
  zeroV = realNode (ExprSum [])

  -- A single term whose incoming gradient is negated before being built.
  negateV (BackGrad x) = realNode (ExprSum [x (identityBuilder . negateV)])

  -- Both operands receive the incoming gradient unchanged.
  BackGrad x ^+^ BackGrad y =
    realNode (ExprSum [x identityBuilder, y identityBuilder])

  -- The left operand receives the gradient unchanged, the right one negated.
  BackGrad x ^-^ BackGrad y =
    realNode (ExprSum [x identityBuilder, y (identityBuilder . negateV)])

-- | Vector space structure on 'BackGrad'; scalars are those of the underlying
-- vector type @v@.
instance (BasicVector v, VectorSpace v) => VectorSpace (BackGrad r v) where
  type Scalar (BackGrad r v) = Scalar v

  -- Builds a real node with a single term; the incoming gradient is scaled
  -- by @a@ before being passed to 'identityBuilder'.
  a *^ BackGrad v = realNode (ExprSum [v (identityBuilder . (a *^))])