{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.NN.Recurrent.Cell.Elman where

import GHC.Generics
import Torch

-- | Sizes needed to build an 'ElmanCell' via its 'Randomizable' instance.
data ElmanSpec = ElmanSpec
  { -- | Input feature size; the second dimension of the sampled 'weightsIH'.
    inputSize :: Int
  , -- | Hidden state size; the leading dimension of every sampled parameter.
    hiddenSize :: Int
  }
  deriving (Show, Eq)

-- | Learnable parameters of a single Elman recurrent cell, consumed by
-- 'elmanCellForward'. The 'Generic' instance lets the 'Parameterized'
-- instance below use default methods.
data ElmanCell = ElmanCell
  { -- | Input-to-hidden weights; sampled with shape
    -- @[hiddenSize, inputSize]@ by the 'Randomizable' instance.
    weightsIH :: Parameter
  , -- | Hidden-to-hidden weights; sampled with shape
    -- @[hiddenSize, hiddenSize]@.
    weightsHH :: Parameter
  , -- | Input-to-hidden bias; sampled with shape @[hiddenSize]@.
    biasIH :: Parameter
  , -- | Hidden-to-hidden bias; sampled with shape @[hiddenSize]@.
    biasHH :: Parameter
  }
  deriving (Show, Generic)

-- | Run one step of an Elman recurrent cell.
--
-- Unwraps each of the cell's four 'Parameter's with 'toDependent' and passes
-- them to 'rnnReluCell' in the order: input-hidden weights, hidden-hidden
-- weights, input-hidden bias, hidden-hidden bias. These are followed by the
-- previous hidden state and the current input. The result is the new hidden
-- state.
elmanCellForward ::
  -- | cell parameters
  ElmanCell ->
  -- | input
  Tensor ->
  -- | hidden
  Tensor ->
  -- | output
  Tensor
elmanCellForward ElmanCell {..} input hidden =
  rnnReluCell weightsIH' weightsHH' biasIH' biasHH' hidden input
  where
    weightsIH' = toDependent weightsIH
    weightsHH' = toDependent weightsHH
    biasIH' = toDependent biasIH
    -- Must unwrap 'biasHH'. An earlier version unwrapped 'biasIH' here a
    -- second time, so the hidden-hidden bias was never used in the forward pass.
    biasHH' = toDependent biasHH

-- | No methods are defined here. The class defaults are used, presumably
-- derived through 'ElmanCell''s 'Generic' instance (confirm against the
-- 'Parameterized' class definition).
instance Parameterized ElmanCell where

-- | Sample a fresh 'ElmanCell' whose parameters are drawn with 'randnIO''.
--
-- Shapes: 'weightsIH' is @[hiddenSize, inputSize]@, 'weightsHH' is
-- @[hiddenSize, hiddenSize]@, and both biases are @[hiddenSize]@.
-- Parameters are sampled in field order, left to right.
instance Randomizable ElmanSpec ElmanCell where
  sample ElmanSpec {..} =
    ElmanCell
      <$> freshParameter [hiddenSize, inputSize]
      <*> freshParameter [hiddenSize, hiddenSize]
      <*> freshParameter [hiddenSize]
      <*> freshParameter [hiddenSize]
    where
      -- Draw a tensor of the given shape and wrap it as an independent
      -- (trainable) parameter.
      freshParameter shape = makeIndependent =<< randnIO' shape