module Grenade.Recurrent.Core.Runner (
trainRecurrent
, runRecurrent
, backPropagateRecurrent
) where
import Data.Singletons.Prelude
import Grenade.Core
import Grenade.Recurrent.Core.Layer
import Grenade.Recurrent.Core.Network
-- | Run a recurrent network over a sequence of examples and backpropagate
-- the per-step output errors through the whole sequence.
--
-- Each example pairs an input with an optional target. A step whose target
-- is 'Nothing' contributes a zero error. A step with @Just t@ contributes
-- @guess - t@, where @guess@ is the network's output at that step.
--
-- Returns the layer gradients together with a 'RecurrentInputs' value
-- produced by the backward pass. 'trainRecurrent' uses that value as the
-- update direction for the initial recurrent inputs.
backPropagateRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers))
                       => RecurrentNetwork layers shapes
                       -> RecurrentInputs layers
                       -> [(S (Head shapes), Maybe (S (Last shapes)))]
                       -> (RecurrentGradients layers, RecurrentInputs layers)
backPropagateRecurrent network recinputs examples =
  let
    -- Forward pass over the whole sequence: keep the tapes for the backward
    -- pass and the per-step outputs; the final recurrent state is unused.
    (tapes, _, guesses) = runRecurrentNetwork network recinputs inputs

    -- Per-step error signals fed into the backward pass.
    backPropagations = zipWith makeError guesses targets

    -- NOTE(review): the 0 is presumably the incoming recurrent gradient
    -- after the final step; confirm against runRecurrentGradient.
    (gradients, input', _) = runRecurrentGradient network tapes 0 backPropagations

  in  (gradients, input')

  where
    inputs  = fst <$> examples
    targets = snd <$> examples

    -- Error for one step: guess minus target, or zero when no target is given.
    makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
    makeError _ Nothing  = 0
    makeError y (Just t) = y - t
-- | Perform one training step on a sequence of examples.
--
-- The function first backpropagates through the sequence. It then applies
-- the resulting layer gradients to the network and updates the initial
-- recurrent inputs with the learning parameters.
--
-- Returns the updated network paired with the updated recurrent inputs.
trainRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers))
               => LearningParameters
               -> RecurrentNetwork layers shapes
               -> RecurrentInputs layers
               -> [(S (Head shapes), Maybe (S (Last shapes)))]
               -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
trainRecurrent rate network recinputs examples =
  ( applyRecurrentUpdate rate network layerGrads
  , updateRecInputs rate recinputs inputGrads
  )
  where
    (layerGrads, inputGrads) = backPropagateRecurrent network recinputs examples
-- | Apply one update step to the recurrent inputs.
--
-- The two 'RecurrentInputs' arguments are walked in lockstep:
--
-- * Feed-forward layers carry a @()@ placeholder and are passed through
--   unchanged.
-- * For each recurrent layer, the current input @x@ is combined with its
--   update direction @y@ (the backpropagated value when called from
--   'trainRecurrent') as
--
-- > (1 - learningRate * learningRegulariser) * x - learningRate * y
--
-- This decays @x@ by the regulariser and then steps against @y@.
updateRecInputs :: LearningParameters
                -> RecurrentInputs sublayers
                -> RecurrentInputs sublayers
                -> RecurrentInputs sublayers
updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
  = () :~~+> updateRecInputs l xs ys
updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
  = (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y)
      :~@+> updateRecInputs l xs ys
updateRecInputs _ RINil RINil
  = RINil
-- | Run the network forward on a single input, threading recurrent state.
--
-- Feed-forward layers are run with 'runForwards'. Their tape is discarded
-- and their @()@ placeholder is kept. Recurrent layers are run with
-- 'runRecurrentForwards', and their new recurrent input replaces the old
-- one.
--
-- Returns the updated recurrent inputs and the network's final output.
-- The input at each layer is forced (bang pattern) before that layer runs.
runRecurrent :: RecurrentNetwork layers shapes
             -> RecurrentInputs layers -> S (Head shapes)
             -> (RecurrentInputs layers, S (Last shapes))
runRecurrent RNil RINil !input
  = (RINil, input)
runRecurrent (layer :~~> rest) (() :~~+> restState) !input
  = (() :~~+> restState', output)
  where
    (_, hidden)           = runForwards layer input
    (restState', output)  = runRecurrent rest restState hidden
runRecurrent (layer :~@> rest) (state :~@+> restState) !input
  = (state' :~@+> restState', output)
  where
    (_, state', hidden)   = runRecurrentForwards layer state input
    (restState', output)  = runRecurrent rest restState hidden