{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module Downhill.Linear.Backprop
  ( -- * Backpropagation
    backprop,

    -- * Graph
    buildGraph,
    --abstractBackprop,
  )
where

import Downhill.Internal.Graph.Graph
  ( SomeGraph (..),
    evalGraph,
    transposeGraph,
  )
import qualified Downhill.Internal.Graph.Graph as Graph
import Downhill.Internal.Graph.OpenGraph (recoverSharing)
import Downhill.Internal.Graph.Types (BackFun, flipBackFun)
import Downhill.Linear.BackGrad (BackGrad (..), castBackGrad)
import Downhill.Linear.Expr
  ( BasicVector (VecBuilder, identityBuilder),
    SparseVector (SparseVector, unSparseVector),
    Term,
  )
import GHC.IO.Unsafe (unsafePerformIO)

-- | Build a computational graph from a list of linear terms.
--
-- Runs 'recoverSharing' over the terms and converts the resulting open
-- graph with 'Graph.unsafeFromOpenGraph'. The result is in 'IO' because
-- 'recoverSharing' is.
--
-- NOTE(review): 'Graph.unsafeFromOpenGraph' carries a 'HasCallStack'
-- constraint, presumably because it can fail on a malformed open graph —
-- TODO confirm.
buildGraph ::
  forall a v.
  (BasicVector a, BasicVector v) =>
  [Term a v] ->
  IO (SomeGraph BackFun a v)
buildGraph terms = Graph.unsafeFromOpenGraph <$> recoverSharing terms

-- | Reverse-mode evaluation shared by 'backprop' and '_backprop'.
--
-- The evaluation proceeds in four steps:
--
-- 1. Apply the term constructor stored in the 'BackGrad' to @builder@,
--    which yields a single 'Term'.
-- 2. Turn that term into a graph with 'buildGraph'.
-- 3. Transpose the graph, flipping every edge with 'flipBackFun'.
-- 4. Evaluate the transposed graph at @x@.
--
-- NOTE(review): 'buildGraph' runs in 'IO', and its result is escaped here
-- with 'unsafePerformIO'. This is only sound if the computed value does not
-- depend on the particular graph that 'recoverSharing' produces, i.e. if the
-- arithmetic and 'Term' linearity laws hold.
abstractBackprop ::
  forall a u v.
  (BasicVector a, BasicVector u, BasicVector v) =>
  BackGrad a u ->
  (v -> VecBuilder u) ->
  v ->
  a
abstractBackprop (BackGrad mkTerm) builder x =
  case unsafePerformIO (buildGraph [mkTerm builder]) of
    -- 'SomeGraph' hides the graph's node-set type, so it has to be
    -- unpacked with 'case'; a 'let' pattern binding cannot open it.
    SomeGraph g -> evalGraph (transposeGraph flipBackFun g) x

-- | Variant of 'backprop' whose seed is given as a 'VecBuilder' rather than
-- as a vector. It is not in this module's export list.
--
-- How it works:
--
-- * The gradient is reinterpreted at type @SparseVector v@ with
--   'castBackGrad', which requires @VecBuilder (SparseVector v) ~ VecBuilder v@.
-- * The seed is wrapped in 'SparseVector', so the builder passed to
--   'abstractBackprop' is simply 'unSparseVector'.
_backprop ::
  forall a v.
  (BasicVector a, BasicVector v) =>
  BackGrad a v ->
  VecBuilder v ->
  a
_backprop dvar seed =
  abstractBackprop @a @(SparseVector v) @(SparseVector v)
    sparseDVar
    unSparseVector
    (SparseVector seed)
  where
    sparseDVar :: BackGrad a (SparseVector v)
    sparseDVar = castBackGrad dvar

-- | Back-propagate a seed value of type @v@ through @dvar@, producing a
-- value of type @a@. This is 'abstractBackprop' with 'identityBuilder' as
-- the builder.
--
-- Purity of this function depends on the laws of arithmetic and on the
-- linearity law of 'Term':
--
-- * The graph is built under 'unsafePerformIO'.
-- * The result is therefore independent of that graph's shape only if
--   those laws hold.
-- * If addition is only approximately associative (e.g. floating point),
--   this function is only approximately pure.
backprop ::
  forall a v.
  (BasicVector a, BasicVector v) =>
  BackGrad a v ->
  v ->
  a
backprop dvar = abstractBackprop dvar identityBuilder