{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE ViewPatterns          #-}
module Grenade.Recurrent.Layers.LSTM (
    LSTM (..)
  , LSTMWeights (..)
  , randomLSTM
  ) where

import           Control.Monad.Random ( MonadRandom, getRandom )

import           Data.Proxy
import           Data.Serialize
import           Data.Singletons.TypeLits

import qualified Numeric.LinearAlgebra as LA
import           Numeric.LinearAlgebra.Static

import           Grenade.Core
import           Grenade.Recurrent.Core
import           Grenade.Layers.Internal.Update
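
-- | Long Short Term Memory recurrent unit.
--
--   This is a peephole formulation: the recurrent state threaded between
--   steps is the cell state itself, and each gate reads that cell state
--   directly. The previous output is not held or fed back.
--
--   The constructor carries two copies of the weights: the live weights
--   and a momentum buffer used by 'runUpdate'.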
data LSTM :: Nat -> Nat -> * where
  LSTM :: ( KnownNat input
          , KnownNat output
          ) => !(LSTMWeights input output)
            -> !(LSTMWeights input output)
            -> LSTM input output
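
-- | The weights of an LSTM cell: a (W, U, b) triple for each of the
--   forget, input and output gates, and (W, b) for the candidate cell
--   update. Note there is no @lstmUc@: in this formulation the candidate
--   depends only on the current input.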
data LSTMWeights :: Nat -> Nat -> * where
  LSTMWeights :: ( KnownNat input
                 , KnownNat output
                 ) => {
      lstmWf :: !(L output input)
    , lstmUf :: !(L output output)
    , lstmBf :: !(R output)
    , lstmWi :: !(L output input)
    , lstmUi :: !(L output output)
    , lstmBi :: !(R output)
    , lstmWo :: !(L output input)
    , lstmUo :: !(L output output)
    , lstmBo :: !(R output)
    , lstmWc :: !(L output input)
    , lstmBc :: !(R output)
    } -> LSTMWeights input output

instance Show (LSTM i o) where
  show LSTM {} = "LSTM"
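
-- 'runUpdate' applies SGD with momentum and regularisation to each of the
-- eleven weight components independently. The helpers @u@ and @v@ use view
-- patterns to apply one record selector to the weights, the momentum
-- buffer and the gradient in a single pattern match.
--
-- A rough usage sketch (@lp@, @lstm@ and @grad@ are hypothetical values
-- from a training loop):
--
-- > let lstm' = runUpdate lp lstm grad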
instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
  type Gradient (LSTM i o) = (LSTMWeights i o)

  runUpdate LearningParameters {..} (LSTM w m) g =
    let (wf, wf') = u lstmWf w m g
        (uf, uf') = u lstmUf w m g
        (bf, bf') = v lstmBf w m g
        (wi, wi') = u lstmWi w m g
        (ui, ui') = u lstmUi w m g
        (bi, bi') = v lstmBi w m g
        (wo, wo') = u lstmWo w m g
        (uo, uo') = u lstmUo w m g
        (bo, bo') = v lstmBo w m g
        (wc, wc') = u lstmWc w m g
        (bc, bc') = v lstmBc w m g
    in LSTM (LSTMWeights wf uf bf wi ui bi wo uo bo wc bc) (LSTMWeights wf' uf' bf' wi' ui' bi' wo' uo' bo' wc' bc')
      where
        u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
        u e (e -> weights) (e -> momentum) (e -> gradient) =
          decendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum

        v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
        v e (e -> weights) (e -> momentum) (e -> gradient) =
          decendVector learningRate learningMomentum learningRegulariser weights gradient momentum

  createRandom = randomLSTM

instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
  type RecurrentShape (LSTM i o) = 'D1 o
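
-- The forward pass computes the usual gated update, with every gate
-- reading the previous cell state directly (peephole style):
--
-- > f_t = sigmoid (lstmBf + lstmWf #> x + lstmUf #> c)   -- forget gate
-- > i_t = sigmoid (lstmBi + lstmWi #> x + lstmUi #> c)   -- input gate
-- > o_t = sigmoid (lstmBo + lstmWo #> x + lstmUo #> c)   -- output gate
-- > c_x = tanh    (lstmBc + lstmWc #> x)                 -- candidate update
-- > c_t = f_t * c + i_t * c_x                            -- new cell state
-- > h_t = o_t * c_t                                      -- output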
instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
  type RecTape (LSTM i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i))

  runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
    let f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
        i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
        o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
        c_x = tanh $ lstmBc + lstmWc #> input
        c_t = f_t * cell + i_t * c_x
        h_t = o_t * c_t
    in ((S1D cell, S1D input), S1D c_t, S1D h_t)
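
  -- The backward pass recomputes the pre-activations and gates from the
  -- taped (cell, input) pair instead of storing them, then pushes the
  -- incoming gradients h_t' and cellGrad back through each gate with the
  -- chain rule, yielding gradients for the weights, the previous cell
  -- state and the input.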
  runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell, S1D input) (S1D cellGrad) (S1D h_t') =
    let f_s = lstmBf + lstmWf #> input + lstmUf #> cell
        f_t = sigmoid f_s
        i_s = lstmBi + lstmWi #> input + lstmUi #> cell
        i_t = sigmoid i_s
        o_s = lstmBo + lstmWo #> input + lstmUo #> cell
        o_t = sigmoid o_s
        c_s = lstmBc + lstmWc #> input
        c_x = tanh c_s
        c_t = f_t * cell + i_t * c_x

        c_t' = h_t' * o_t + cellGrad
        f_t' = c_t' * cell
        f_s' = sigmoid' f_s * f_t'
        o_t' = h_t' * c_t
        o_s' = sigmoid' o_s * o_t'
        i_t' = c_t' * c_x
        i_s' = sigmoid' i_s * i_t'
        c_x' = c_t' * i_t
        c_s' = tanh' c_s * c_x'

        cell'  = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
        input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'

        lstmWf' = f_s' `outer` input
        lstmWi' = i_s' `outer` input
        lstmWo' = o_s' `outer` input
        lstmWc' = c_s' `outer` input
        lstmUf' = f_s' `outer` cell
        lstmUi' = i_s' `outer` cell
        lstmUo' = o_s' `outer` cell
        lstmBf' = f_s'
        lstmBi' = i_s'
        lstmBo' = o_s'
        lstmBc' = c_s'
        gradients = LSTMWeights lstmWf' lstmUf' lstmBf' lstmWi' lstmUi' lstmBi' lstmWo' lstmUo' lstmBo' lstmWc' lstmBc'
    in (gradients, S1D cell', S1D input')
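
-- | Generate an LSTM with weights and biases drawn uniformly from
--   [-1, 1] and a zeroed momentum buffer. The forget-gate bias is
--   initialised to 1, a common trick that discourages the cell from
--   forgetting early in training.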
randomLSTM :: forall m i o. (MonadRandom m, KnownNat i, KnownNat o)
           => m (LSTM i o)
randomLSTM = do
  let w = (\s -> uniformSample s (-1) 1) <$> getRandom
      u = (\s -> uniformSample s (-1) 1) <$> getRandom
      v = (\s -> randomVector s Uniform * 2 - 1) <$> getRandom
      w0 = konst 0
      u0 = konst 0
      v0 = konst 0
  LSTM <$> (LSTMWeights <$> w <*> u <*> pure (konst 1) <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
       <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
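
-- A minimal construction sketch (assuming the @MonadRandom IO@ instance
-- from Control.Monad.Random; the sizes here are arbitrary):
--
-- > lstm <- randomLSTM :: IO (LSTM 40 512)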

-- Standard logistic function.
sigmoid :: Floating a => a -> a
sigmoid x = 1 / (1 + exp (-x))

-- Derivative of the logistic function: sigmoid' x = sigmoid x * (1 - sigmoid x).
sigmoid' :: Floating a => a -> a
sigmoid' x = logix * (1 - logix)
  where
    logix = sigmoid x

-- Derivative of tanh: tanh' t = 1 - tanh t ^ 2.
tanh' :: (Floating a) => a -> a
tanh' t = 1 - s ^ (2 :: Int) where s = tanh t

instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
  put (LSTM LSTMWeights {..} _) = do
    u lstmWf
    u lstmUf
    v lstmBf
    u lstmWi
    u lstmUi
    v lstmBi
    u lstmWo
    u lstmUo
    v lstmBo
    u lstmWc
    v lstmBc
      where
        u :: forall a b. (KnownNat a, KnownNat b) => Putter (L b a)
        u = putListOf put . LA.toList . LA.flatten . extract
        v :: forall a. (KnownNat a) => Putter (R a)
        v = putListOf put . LA.toList . extract

  get = do
    lstmWf <- u
    lstmUf <- u
    lstmBf <- v
    lstmWi <- u
    lstmUi <- u
    lstmBi <- v
    lstmWo <- u
    lstmUo <- u
    lstmBo <- v
    lstmWc <- u
    lstmBc <- v
    return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
      where
        u :: forall a b. (KnownNat a, KnownNat b) => Get (L b a)
        u = let f = fromIntegral $ natVal (Proxy :: Proxy a)
            in maybe (fail "Matrix of incorrect size") return . create . LA.reshape f . LA.fromList =<< getListOf get
        v :: forall a. (KnownNat a) => Get (R a)
        v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get
        w0 = konst 0
        u0 = konst 0
        v0 = konst 0
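
-- The Serialize instance writes only the live weights; the momentum
-- buffer is re-zeroed on 'get'. A hypothetical round trip via
-- Data.Serialize (size arbitrary):
--
-- > decode (encode lstm) :: Either String (LSTM 40 512)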