{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ExistentialQuantification #-}
module Synapse.NN.Layers.Layer
(
AbstractLayer (inputSize, outputSize, nParameters, getParameters, updateParameters, symbolicForward)
, forward
, Layer (Layer)
, LayerConfiguration
) where
import Synapse.Tensors (DType)
import Synapse.Tensors.Mat (Mat)
import Synapse.Autograd (Symbolic, SymbolIdentifier, Symbol(unSymbol), SymbolMat, constSymbol)
-- | Interface implemented by every neural-network layer.
--
-- A layer is a type constructor @l@ applied to its element type @a@. An
-- instance reports its sizes, exposes and replaces its parameters, and can
-- be evaluated symbolically over 'SymbolMat' values.
class AbstractLayer l where
    -- | Input and output sizes of the layer, in that order.
    --
    -- NOTE(review): the 'Maybe' allows a layer to report no size; presumably
    -- 'Nothing' means the size is not fixed by the layer — confirm per
    -- instance.
    inputSize, outputSize :: l a -> Maybe Int

    -- | Number of parameters held by the layer.
    --
    -- NOTE(review): not visible here whether this counts parameter matrices
    -- or individual scalar entries — confirm against instances.
    nParameters :: l a -> Int

    -- | The layer's parameters as symbolic matrices. The 'SymbolIdentifier'
    -- argument is a prefix (presumably used to name the resulting symbols).
    getParameters :: SymbolIdentifier -> l a -> [SymbolMat a]

    -- | Builds a layer from the given parameter matrices.
    --
    -- NOTE(review): presumably the list is expected in the same order that
    -- 'getParameters' produces — confirm.
    updateParameters :: l a -> [Mat a] -> l a

    -- | Symbolic evaluation of the layer on a symbolic input.
    --
    -- Arguments are an identifier prefix, the input, and the layer. The
    -- first component of the result is the layer's output ('forward' keeps
    -- only that one). NOTE(review): the meaning of the second component is
    -- not visible in this module — confirm.
    symbolicForward
        :: (Symbolic a, Floating a, Ord a)
        => SymbolIdentifier
        -> SymbolMat a
        -> l a
        -> (SymbolMat a, SymbolMat a)
-- | Evaluates a layer on a concrete matrix.
--
-- The input is lifted with 'constSymbol', run through 'symbolicForward'
-- with an empty ('mempty') identifier prefix, and the first component of
-- the resulting pair is unwrapped with 'unSymbol'. The second component is
-- discarded.
forward :: (AbstractLayer l, Symbolic a, Floating a, Ord a) => Mat a -> l a -> Mat a
forward input layer = unSymbol output
  where
    -- Lazy pattern binding: the pair is only demanded when 'output' is,
    -- exactly like taking 'fst' of it.
    (output, _) = symbolicForward mempty (constSymbol input) layer
-- | Existential wrapper around any 'AbstractLayer'.
--
-- The concrete layer type @l@ is hidden, so layers of different types that
-- share the element type @a@ can be handled uniformly (e.g. kept in one
-- list).
data Layer a = forall l . AbstractLayer l => Layer (l a)

-- | The 'DType' of a 'Layer' is its type parameter @a@.
type instance DType (Layer a) = a
-- | Every method unwraps the existential and delegates to the wrapped layer.
-- 'updateParameters' wraps its result again so it stays a 'Layer'.
instance AbstractLayer Layer where
    inputSize layer = case layer of
        Layer inner -> inputSize inner

    outputSize layer = case layer of
        Layer inner -> outputSize inner

    nParameters layer = case layer of
        Layer inner -> nParameters inner

    getParameters prefix layer = case layer of
        Layer inner -> getParameters prefix inner

    -- Kept at arity one (returning a function), matching the original
    -- point-free definition, so the wrapper is matched as soon as the
    -- partial application is forced.
    updateParameters layer = case layer of
        Layer inner -> \newParameters -> Layer (updateParameters inner newParameters)

    symbolicForward prefix input layer = case layer of
        Layer inner -> symbolicForward prefix input inner
-- | A function that produces a layer @l@ from an 'Int'.
--
-- NOTE(review): the meaning of the 'Int' is not visible in this module;
-- presumably it is a size supplied when the layer is built — confirm against
-- the code that uses this alias.
type LayerConfiguration l = Int -> l