{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
module Downhill.Linear.Backprop
(
backprop,
buildGraph,
)
where
import Downhill.Internal.Graph.Graph
( SomeGraph (..),
evalGraph,
transposeGraph,
)
import qualified Downhill.Internal.Graph.Graph as Graph
import Downhill.Internal.Graph.OpenGraph (recoverSharing)
import Downhill.Internal.Graph.Types (BackFun, flipBackFun)
import Downhill.Linear.BackGrad (BackGrad (..), castBackGrad)
import Downhill.Linear.Expr
( BasicVector (VecBuilder, identityBuilder),
SparseVector (SparseVector, unSparseVector),
Term,
)
import GHC.IO.Unsafe (unsafePerformIO)
-- | Turn a list of linear terms into a graph.
--
-- Sharing among the terms is recovered with 'recoverSharing' (an 'IO'
-- action). The resulting open graph is then converted with
-- 'Graph.unsafeFromOpenGraph'.
--
-- NOTE(review): 'Graph.unsafeFromOpenGraph' carries a 'HasCallStack'
-- constraint, so it can presumably fail on a malformed open graph.
-- TODO: confirm which inputs it rejects.
buildGraph ::
  forall a v.
  (BasicVector a, BasicVector v) =>
  [Term a v] ->
  IO (SomeGraph BackFun a v)
buildGraph terms = Graph.unsafeFromOpenGraph <$> recoverSharing terms
-- | Shared engine behind 'backprop' and '_backprop'.
--
-- The steps are:
--
-- * build a graph from the single term @f builder@;
-- * transpose it, flipping every edge with 'flipBackFun';
-- * evaluate the transposed graph at the point supplied as the last
--   argument.
--
-- The definition deliberately binds only two arguments. The graph
-- construction does not depend on the final argument, so it happens once,
-- when the function returned by @abstractBackprop dvar builder@ is first
-- forced. A partial application such as @backprop dvar@ can then reuse one
-- graph across many calls instead of rebuilding it every time.
--
-- NOTE(review): this sharing relies on GHC not eta-expanding the
-- definition back over the (expensive) graph construction.
abstractBackprop ::
  forall a u v.
  (BasicVector a, BasicVector u, BasicVector v) =>
  BackGrad a u ->
  (v -> VecBuilder u) ->
  v ->
  a
abstractBackprop (BackGrad f) builder =
  -- NOTE(review): 'buildGraph' lives in 'IO' because 'recoverSharing' does.
  -- Using 'unsafePerformIO' here presumes that this IO only observes
  -- sharing, so the evaluated result is the same however often it runs.
  -- TODO: confirm against 'recoverSharing'.
  case unsafePerformIO (buildGraph [f builder]) of
    -- The graph's node-set type is existential, but the evaluated
    -- function @v -> a@ does not mention it, so it may escape the match.
    SomeGraph g -> evalGraph (transposeGraph flipBackFun g)
-- | Variant of 'backprop' whose seed is already a 'VecBuilder'.
--
-- The seed is wrapped in 'SparseVector' and the gradient is cast with
-- 'castBackGrad' so it is indexed by @'SparseVector' v@. Both are then
-- handed to 'abstractBackprop', with 'unSparseVector' used to turn the
-- wrapped seed back into a builder. This function is not exported.
_backprop :: forall a v. (BasicVector a, BasicVector v) => BackGrad a v -> VecBuilder v -> a
_backprop dvar x =
  abstractBackprop
    @a
    @(SparseVector v)
    @(SparseVector v)
    (castBackGrad dvar :: BackGrad a (SparseVector v))
    unSparseVector
    (SparseVector x)
-- | Run the backward pass of @dvar@, seeded with a vector of type @v@.
--
-- The seed is turned into a builder with 'identityBuilder' and passed to
-- 'abstractBackprop'.
--
-- The definition is intentionally a partial application in @dvar@. Keeping
-- it that way means @backprop dvar@ stays a single value that is reused
-- across calls.
backprop :: forall a v. (BasicVector a, BasicVector v) => BackGrad a v -> v -> a
backprop dvar =
  abstractBackprop @a @v @v dvar identityBuilder