{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Torch.NN.Recurrent.Cell.Elman where
import GHC.Generics
import Torch
-- | Hyperparameters from which an 'ElmanCell' is sampled.
data ElmanSpec = ElmanSpec
  { inputSize  :: Int
    -- ^ Input feature dimension; used as the column count of
    -- 'weightsIH' when a cell is sampled.
  , hiddenSize :: Int
    -- ^ Hidden-state dimension; sizes the rows of 'weightsIH', both
    -- dimensions of 'weightsHH', and the length of both biases when a cell
    -- is sampled.
  }
  deriving (Eq, Show)
-- | Learnable parameters of a single Elman recurrent cell.
--
-- The shapes noted on each field are the ones produced by this module's
-- 'Randomizable' instance; a cell assembled by hand may use other shapes.
data ElmanCell = ElmanCell
  { weightsIH :: Parameter
    -- ^ Input-to-hidden weights, @[hiddenSize, inputSize]@ when sampled.
  , weightsHH :: Parameter
    -- ^ Hidden-to-hidden weights, @[hiddenSize, hiddenSize]@ when sampled.
  , biasIH    :: Parameter
    -- ^ Input-to-hidden bias, @[hiddenSize]@ when sampled.
  , biasHH    :: Parameter
    -- ^ Hidden-to-hidden bias, @[hiddenSize]@ when sampled.
  }
  deriving (Generic, Show)
-- | Advance an Elman cell by one time step, returning the new hidden state.
--
-- The work is delegated to hasktorch's 'rnnReluCell'. Its documentation
-- describes it as an RNN cell with a ReLU nonlinearity. The cell's four
-- parameters are unwrapped with 'toDependent' and passed in the order
-- input-hidden weights, hidden-hidden weights, input-hidden bias,
-- hidden-hidden bias, followed by the previous hidden state and the input.
--
-- Each bias is read from its own field. A previous version passed 'biasIH'
-- in both bias slots, so 'biasHH' never influenced the output.
elmanCellForward ::
  -- | Cell whose parameters are applied.
  ElmanCell ->
  -- | Input for this step. Its trailing dimension presumably equals the
  -- cell's input size (e.g. @[batch, inputSize]@); TODO confirm against
  -- callers.
  Tensor ->
  -- | Previous hidden state. Presumably @[batch, hiddenSize]@; TODO confirm.
  Tensor ->
  -- | Next hidden state.
  Tensor
elmanCellForward ElmanCell {..} input hidden =
  rnnReluCell weightsIH' weightsHH' biasIH' biasHH' hidden input
  where
    weightsIH' :: Tensor
    weightsIH' = toDependent weightsIH
    weightsHH' :: Tensor
    weightsHH' = toDependent weightsHH
    biasIH' :: Tensor
    biasIH' = toDependent biasIH
    biasHH' :: Tensor
    -- Must come from 'biasHH', not 'biasIH'.
    biasHH' = toDependent biasHH
-- | No methods are overridden here. The class's default methods are used,
-- and they presumably traverse 'ElmanCell' through its derived 'Generic'
-- instance.
instance Parameterized ElmanCell

-- | Draw a fresh 'ElmanCell' whose shapes follow the given 'ElmanSpec'.
--
-- Every parameter is filled by 'randnIO'' (standard-normal per the hasktorch
-- docs) and wrapped with 'makeIndependent'. The tensors are drawn in field
-- order: weightsIH, weightsHH, biasIH, biasHH. This holds because 'IO''s
-- applicative sequences effects left to right.
--
-- NOTE(review): the N(0,1) initialisation is unscaled. PyTorch's RNNCell uses
-- U(-1/sqrt hiddenSize, 1/sqrt hiddenSize); confirm the difference is
-- intended.
instance Randomizable ElmanSpec ElmanCell where
  sample ElmanSpec {..} =
    ElmanCell
      <$> freshParam [hiddenSize, inputSize]
      <*> freshParam [hiddenSize, hiddenSize]
      <*> freshParam [hiddenSize]
      <*> freshParam [hiddenSize]
    where
      -- Sample a tensor of the requested shape and wrap it as a Parameter.
      freshParam shape = makeIndependent =<< randnIO' shape