{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.NN.Recurrent.Cell.GRU where

import GHC.Generics
import Torch

-- | Dimensions describing a GRU cell; used as the specification from which
-- a cell is randomly initialised.
data GRUSpec = GRUSpec
  { -- | Number of features in each input vector.
    inputSize :: Int,
    -- | Number of features in the hidden state.
    hiddenSize :: Int
  }
  deriving (Eq, Show)

-- | Trainable parameters of a single GRU cell.
--
-- The shapes quoted below are the ones used when a cell is sampled from a
-- 'GRUSpec'; the type itself does not enforce them.
data GRUCell = GRUCell
  { -- | Input-to-hidden weights, sampled as @[3 * hiddenSize, inputSize]@.
    weightsIH :: Parameter,
    -- | Hidden-to-hidden weights, sampled as @[3 * hiddenSize, hiddenSize]@.
    weightsHH :: Parameter,
    -- | Input-to-hidden bias, sampled as @[3 * hiddenSize]@.
    biasIH :: Parameter,
    -- | Hidden-to-hidden bias, sampled as @[3 * hiddenSize]@.
    biasHH :: Parameter
  }
  deriving (Generic, Show)

-- | Apply one step of a GRU cell.
--
-- Each 'Parameter' of the cell is converted to a plain 'Tensor' with
-- 'toDependent' and handed to 'gruCell'. Note that 'gruCell' is called with
-- the hidden state /before/ the input, which is the reverse of this
-- function's own argument order.
gruCellForward ::
  -- | cell parameters
  GRUCell ->
  -- | input
  Tensor ->
  -- | hidden
  Tensor ->
  -- | output
  Tensor
gruCellForward
  GRUCell
    { weightsIH = wIH,
      weightsHH = wHH,
      biasIH = bIH,
      biasHH = bHH
    }
  input
  hidden =
    gruCell
      (toDependent wIH)
      (toDependent wHH)
      (toDependent bIH)
      (toDependent bHH)
      hidden
      input

-- | No methods are overridden here, so the class defaults apply.
-- NOTE(review): presumably those defaults are driven by the 'Generic'
-- instance derived for 'GRUCell' -- confirm against Torch's 'Parameterized'.
instance Parameterized GRUCell

-- | Randomly initialise a 'GRUCell' from a 'GRUSpec'.
--
-- Every parameter is produced the same way: a tensor @x@ is drawn with
-- 'randIO'' and mapped to @2 * k * x - k@, where
-- @k = sqrt (1 / hiddenSize)@.
-- NOTE(review): presumably 'randIO'' samples uniformly from [0, 1), which
-- would make each parameter uniform on [-k, k] as described in the PyTorch
-- reference linked below -- confirm.
instance Randomizable GRUSpec GRUCell where
  sample GRUSpec {..} = do
    -- https://pytorch.org/docs/stable/generated/torch.nn.GRUCell.html
    -- The draws are kept in this order (weightsIH, weightsHH, biasIH,
    -- biasHH) so the sequence of random samples is unchanged.
    weightsIH' <- initParameter [gateRows, inputSize]
    weightsHH' <- initParameter [gateRows, hiddenSize]
    biasIH' <- initParameter [gateRows]
    biasHH' <- initParameter [gateRows]
    pure
      GRUCell
        { weightsIH = weightsIH',
          weightsHH = weightsHH',
          biasIH = biasIH',
          biasHH = biasHH'
        }
    where
      -- Leading dimension shared by all four parameters.
      gateRows :: Int
      gateRows = 3 * hiddenSize

      -- Bound k of the initialisation range.
      scale :: Float
      scale = Prelude.sqrt (1.0 / fromIntegral hiddenSize)

      -- Maps x to 2 * k * x - k.
      initScale :: Tensor -> Tensor
      initScale = subScalar scale . mulScalar scale . mulScalar (2.0 :: Float)

      -- Draw a tensor of the given shape, rescale it, and wrap it as an
      -- independent (trainable) parameter.
      initParameter :: [Int] -> IO Parameter
      initParameter dims = randIO' dims >>= makeIndependent . initScale