{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

module Torch.Typed.NN.Recurrent.GRU where

import Data.Kind
import Data.Proxy (Proxy (..))
import Foreign.ForeignPtr
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.Environment
import System.IO.Unsafe
import qualified Torch.Autograd as A
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.NN as A
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (sqrt)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Recurrent.Auxiliary
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (tanh)

-- | Specification value from which a 'GRULayer' is sampled: it is the spec
-- argument of this module's 'A.Randomizable' instances.
--
-- The value carries no runtime data. The input size, hidden size,
-- directionality, dtype and device are all fixed by the type parameters,
-- so the single nullary constructor is enough.
data
  GRULayerSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = GRULayerSpec
  deriving (Show, Eq)

-- | A single GRU layer, indexed at the type level by its input size,
-- hidden size, directionality, dtype and device.
--
-- The directionality index selects the constructor:
--
-- * @'Unidirectional@ layers hold one group of four parameters.
-- * @'Bidirectional@ layers hold two such groups, laid out one after
--   the other. The second group is presumably the reverse direction;
--   confirm against the forward pass.
data GRULayer
       (inputSize :: Nat)
       (hiddenSize :: Nat)
       (directionality :: RNNDirectionality)
       (dtype :: D.DType)
       (device :: (D.DeviceType, Nat)) where
  -- | Layer with a single parameter group.
  GRUUnidirectionalLayer
    :: Parameter device dtype (GRUWIShape hiddenSize inputSize)
    -- ^ weight matrix of shape 'GRUWIShape' (presumably input-to-hidden)
    -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
    -- ^ weight matrix of shape 'GRUWHShape' (presumably hidden-to-hidden)
    -> Parameter device dtype (GRUBIShape hiddenSize inputSize)
    -- ^ bias vector of shape 'GRUBIShape'
    -> Parameter device dtype (GRUBHShape hiddenSize inputSize)
    -- ^ bias vector of shape 'GRUBHShape'
    -> GRULayer inputSize hiddenSize 'Unidirectional dtype device
  -- | Layer with two parameter groups, each shaped like the
  -- unidirectional one.
  GRUBidirectionalLayer
    :: Parameter device dtype (GRUWIShape hiddenSize inputSize)
    -- ^ first group: 'GRUWIShape' weights
    -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
    -- ^ first group: 'GRUWHShape' weights
    -> Parameter device dtype (GRUBIShape hiddenSize inputSize)
    -- ^ first group: 'GRUBIShape' bias
    -> Parameter device dtype (GRUBHShape hiddenSize inputSize)
    -- ^ first group: 'GRUBHShape' bias
    -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
    -- ^ second group: 'GRUWIShape' weights
    -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
    -- ^ second group: 'GRUWHShape' weights
    -> Parameter device dtype (GRUBIShape hiddenSize inputSize)
    -- ^ second group: 'GRUBIShape' bias
    -> Parameter device dtype (GRUBHShape hiddenSize inputSize)
    -- ^ second group: 'GRUBHShape' bias
    -> GRULayer inputSize hiddenSize 'Bidirectional dtype device

-- A standalone declaration is used because the constructors refine the
-- directionality index, which an attached @deriving@ clause cannot handle.
deriving instance Show (GRULayer inputSize hiddenSize directionality dtype device)

-- | Exposes the four parameters of a unidirectional layer as a typed
-- 'HList', in constructor order: weights 'GRUWIShape', weights
-- 'GRUWHShape', bias 'GRUBIShape', bias 'GRUBHShape'.
instance Parameterized (GRULayer inputSize hiddenSize 'Unidirectional dtype device) where
  type
    Parameters (GRULayer inputSize hiddenSize 'Unidirectional dtype device) =
      '[ Parameter device dtype (GRUWIShape hiddenSize inputSize),
         Parameter device dtype (GRUWHShape hiddenSize inputSize),
         Parameter device dtype (GRUBIShape hiddenSize inputSize),
         Parameter device dtype (GRUBHShape hiddenSize inputSize)
       ]

  flattenParameters (GRUUnidirectionalLayer wi wh bi bh) =
    wi :. wh :. bi :. bh :. HNil

  -- The existing layer is ignored: every field is taken from the list,
  -- whose order must mirror 'flattenParameters'.
  replaceParameters _ (wi :. wh :. bi :. bh :. HNil) =
    GRUUnidirectionalLayer wi wh bi bh

-- | Exposes the eight parameters of a bidirectional layer as a typed
-- 'HList'. The order is constructor order: the first group (unprimed) and
-- then the second group (primed). Each group is weights 'GRUWIShape',
-- weights 'GRUWHShape', bias 'GRUBIShape', bias 'GRUBHShape'.
instance Parameterized (GRULayer inputSize hiddenSize 'Bidirectional dtype device) where
  type
    Parameters (GRULayer inputSize hiddenSize 'Bidirectional dtype device) =
      '[ Parameter device dtype (GRUWIShape hiddenSize inputSize),
         Parameter device dtype (GRUWHShape hiddenSize inputSize),
         Parameter device dtype (GRUBIShape hiddenSize inputSize),
         Parameter device dtype (GRUBHShape hiddenSize inputSize),
         Parameter device dtype (GRUWIShape hiddenSize inputSize),
         Parameter device dtype (GRUWHShape hiddenSize inputSize),
         Parameter device dtype (GRUBIShape hiddenSize inputSize),
         Parameter device dtype (GRUBHShape hiddenSize inputSize)
       ]

  flattenParameters (GRUBidirectionalLayer wi wh bi bh wi' wh' bi' bh') =
    wi :. wh :. bi :. bh :. wi' :. wh' :. bi' :. bh' :. HNil

  -- The existing layer is ignored: every field is taken from the list,
  -- whose order must mirror 'flattenParameters'.
  replaceParameters _ (wi :. wh :. bi :. bh :. wi' :. wh' :. bi' :. bh' :. HNil) =
    GRUBidirectionalLayer wi wh bi bh wi' wh' bi' bh'

-- | Samples a fresh unidirectional GRU layer.
--
-- Both weight matrices ('GRUWIShape' and 'GRUWHShape') are drawn with
-- 'xavierUniformGRU'. Both bias vectors start as 'zeros'. Every tensor is
-- wrapped with 'makeIndependent' so it becomes an independent 'Parameter'.
-- The spec value itself is ignored because all configuration lives in its
-- type.
instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (GRULayerSpec inputSize hiddenSize 'Unidirectional dtype device)
    (GRULayer inputSize hiddenSize 'Unidirectional dtype device)
  where
  sample _ =
    GRUUnidirectionalLayer
      <$> (makeIndependent =<< xavierUniformGRU)
      <*> (makeIndependent =<< xavierUniformGRU)
      -- @makeIndependent =<< pure zeros@ is just @makeIndependent zeros@
      -- (monad left identity), so the detour through 'pure' is dropped.
      <*> makeIndependent zeros
      <*> makeIndependent zeros

instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (GRULayerSpec inputSize hiddenSize 'Bidirectional dtype device)
    (GRULayer inputSize hiddenSize 'Bidirectional dtype device)
  where
  sample :: GRULayerSpec inputSize hiddenSize 'Bidirectional dtype device
-> IO (GRULayer inputSize hiddenSize 'Bidirectional dtype device)
sample GRULayerSpec inputSize hiddenSize 'Bidirectional dtype device
_ =
    Parameter device dtype (GRUWIShape hiddenSize inputSize)
-> Parameter device dtype (GRUWHShape hiddenSize inputSize)
-> Parameter device dtype '[3 * hiddenSize]
-> Parameter device dtype '[3 * hiddenSize]
-> Parameter device dtype (GRUWIShape hiddenSize inputSize)
-> Parameter device dtype (GRUWHShape hiddenSize inputSize)
-> Parameter device dtype '[3 * hiddenSize]
-> Parameter device dtype '[3 * hiddenSize]
-> GRULayer inputSize hiddenSize 'Bidirectional dtype device
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (inputSize :: Natural).
Parameter device dtype (GRUWIShape hiddenSize inputSize)
-> Parameter device dtype (GRUWHShape hiddenSize inputSize)
-> Parameter device dtype (GRUBIShape hiddenSize inputSize)
-> Parameter device dtype (GRUBIShape hiddenSize inputSize)
-> Parameter device dtype (GRUWIShape hiddenSize inputSize)
-> Parameter device dtype (GRUWHShape hiddenSize inputSize)
-> Parameter device dtype (GRUBIShape hiddenSize inputSize)
-> Parameter device dtype (GRUBIShape hiddenSize inputSize)
-> GRULayer inputSize hiddenSize 'Bidirectional dtype device
GRUBidirectionalLayer
      (Parameter device dtype (GRUWIShape hiddenSize inputSize)
 -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
 -> Parameter device dtype '[3 * hiddenSize]
 -> Parameter device dtype '[3 * hiddenSize]
 -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
 -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
 -> Parameter device dtype '[3 * hiddenSize]
 -> Parameter device dtype '[3 * hiddenSize]
 -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize))
-> IO
     (Parameter device dtype (GRUWHShape hiddenSize inputSize)
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
      -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> (Tensor device dtype (GRUWIShape hiddenSize inputSize)
-> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (GRUWIShape hiddenSize inputSize)
 -> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize)))
-> IO (Tensor device dtype (GRUWIShape hiddenSize inputSize))
-> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (GRUWIShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[3 * hiddenSize, featureSize])
xavierUniformGRU)
      IO
  (Parameter device dtype (GRUWHShape hiddenSize inputSize)
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
   -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize))
-> IO
     (Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
      -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype (GRUWHShape hiddenSize inputSize)
-> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (GRUWHShape hiddenSize inputSize)
 -> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize)))
-> IO (Tensor device dtype (GRUWHShape hiddenSize inputSize))
-> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (GRUWHShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[3 * hiddenSize, featureSize])
xavierUniformGRU)
      IO
  (Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
   -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[3 * hiddenSize])
-> IO
     (Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
      -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[3 * hiddenSize]
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[3 * hiddenSize]
 -> IO (Parameter device dtype '[3 * hiddenSize]))
-> IO (Tensor device dtype '[3 * hiddenSize])
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[3 * hiddenSize]
-> IO (Tensor device dtype '[3 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[3 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)
      IO
  (Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype (GRUWIShape hiddenSize inputSize)
   -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[3 * hiddenSize])
-> IO
     (Parameter device dtype (GRUWIShape hiddenSize inputSize)
      -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[3 * hiddenSize]
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[3 * hiddenSize]
 -> IO (Parameter device dtype '[3 * hiddenSize]))
-> IO (Tensor device dtype '[3 * hiddenSize])
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[3 * hiddenSize]
-> IO (Tensor device dtype '[3 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[3 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)
      IO
  (Parameter device dtype (GRUWIShape hiddenSize inputSize)
   -> Parameter device dtype (GRUWHShape hiddenSize inputSize)
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize))
-> IO
     (Parameter device dtype (GRUWHShape hiddenSize inputSize)
      -> Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype (GRUWIShape hiddenSize inputSize)
-> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (GRUWIShape hiddenSize inputSize)
 -> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize)))
-> IO (Tensor device dtype (GRUWIShape hiddenSize inputSize))
-> IO (Parameter device dtype (GRUWIShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (GRUWIShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[3 * hiddenSize, featureSize])
xavierUniformGRU)
      IO
  (Parameter device dtype (GRUWHShape hiddenSize inputSize)
   -> Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize))
-> IO
     (Parameter device dtype '[3 * hiddenSize]
      -> Parameter device dtype '[3 * hiddenSize]
      -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype (GRUWHShape hiddenSize inputSize)
-> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (GRUWHShape hiddenSize inputSize)
 -> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize)))
-> IO (Tensor device dtype (GRUWHShape hiddenSize inputSize))
-> IO (Parameter device dtype (GRUWHShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (GRUWHShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[3 * hiddenSize, featureSize])
xavierUniformGRU)
      IO
  (Parameter device dtype '[3 * hiddenSize]
   -> Parameter device dtype '[3 * hiddenSize]
   -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[3 * hiddenSize])
-> IO
     (Parameter device dtype '[3 * hiddenSize]
      -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[3 * hiddenSize]
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[3 * hiddenSize]
 -> IO (Parameter device dtype '[3 * hiddenSize]))
-> IO (Tensor device dtype '[3 * hiddenSize])
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[3 * hiddenSize]
-> IO (Tensor device dtype '[3 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[3 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)
      IO
  (Parameter device dtype '[3 * hiddenSize]
   -> GRULayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[3 * hiddenSize])
-> IO (GRULayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[3 * hiddenSize]
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[3 * hiddenSize]
 -> IO (Parameter device dtype '[3 * hiddenSize]))
-> IO (Tensor device dtype '[3 * hiddenSize])
-> IO (Parameter device dtype '[3 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[3 * hiddenSize]
-> IO (Tensor device dtype '[3 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[3 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)

-- | Specification consumed by 'gruLayerStackSample' to build a 'GRULayerStack'.
--
-- The type carries no runtime data: input size, hidden size, number of layers,
-- directionality, dtype and device are all type-level indices. The single
-- nullary constructor is therefore all a caller has to supply.
data
  GRULayerStackSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = GRULayerStackSpec
  deriving (Show, Eq)

-- | Input-to-hidden, hidden-to-hidden, and bias parameters for a multi-layered
-- (and optionally bidirectional) GRU.
--
-- The stack is a length-indexed, non-empty sequence of 'GRULayer's.
-- 'GRULayer1' holds a single layer that reads @inputSize@ features.
-- Each 'GRULayerK' adds one layer whose input width is
-- @hiddenSize * NumberOfDirections directionality@ rather than @inputSize@,
-- and increments the @numLayers@ index by one.
data
  GRULayerStack
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | A stack consisting of exactly one layer fed by the raw input features.
  GRULayer1 ::
    GRULayer inputSize hiddenSize directionality dtype device ->
    GRULayerStack inputSize hiddenSize 1 directionality dtype device
  -- | Add one more layer to an existing stack; the new layer's input width is
  -- @hiddenSize * NumberOfDirections directionality@.
  GRULayerK ::
    GRULayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device ->
    GRULayerStack inputSize hiddenSize numLayers directionality dtype device ->
    GRULayerStack inputSize hiddenSize (numLayers + 1) directionality dtype device

deriving instance Show (GRULayerStack inputSize hiddenSize numLayers directionality dtype device)

-- | Helper class behind the 'Parameterized' instance of 'GRULayerStack'.
--
-- It dispatches on the type-level @flag@. The 'False instance handles the
-- one-layer ('GRULayer1') case. The 'True instance handles the multi-layer
-- ('GRULayerK') case.
class
  GRULayerStackParameterized
    (flag :: Bool)
    inputSize
    hiddenSize
    numLayers
    directionality
    dtype
    device
  where
  -- | Type-level list of every parameter held by the stack.
  type
    GRULayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device ::
      [Type]

  -- | Collect all parameters of the stack into a single 'HList'.
  gruLayerStackFlattenParameters
    :: Proxy flag
    -> GRULayerStack inputSize hiddenSize numLayers directionality dtype device
    -> HList (GRULayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device)

  -- | Rebuild the stack, taking its parameters from the supplied 'HList'.
  gruLayerStackReplaceParameters
    :: Proxy flag
    -> GRULayerStack inputSize hiddenSize numLayers directionality dtype device
    -> HList (GRULayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device)
    -> GRULayerStack inputSize hiddenSize numLayers directionality dtype device

-- | Base case: a one-layer stack ('GRULayer1') exposes exactly the parameters
-- of the layer it wraps, so both methods forward to that 'GRULayer'.
instance
  Parameterized (GRULayer inputSize hiddenSize directionality dtype device) =>
  GRULayerStackParameterized 'False inputSize hiddenSize 1 directionality dtype device
  where
  type
    GRULayerStackParameters 'False inputSize hiddenSize 1 directionality dtype device =
      Parameters (GRULayer inputSize hiddenSize directionality dtype device)
  gruLayerStackFlattenParameters _ (GRULayer1 gruLayer) =
    flattenParameters gruLayer
  gruLayerStackReplaceParameters _ (GRULayer1 gruLayer) parameters =
    GRULayer1 $ replaceParameters gruLayer parameters

-- | Inductive case (selected when @2 <=? numLayers@): the parameters of the
-- inner @numLayers - 1@ stack come first, followed by the parameters of the
-- layer added by 'GRULayerK'.
instance
  ( Parameterized
      ( GRULayer
          (hiddenSize * NumberOfDirections directionality)
          hiddenSize
          directionality
          dtype
          device
      ),
    Parameterized (GRULayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device),
    HAppendFD
      (Parameters (GRULayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device))
      (Parameters (GRULayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device))
      (Parameters (GRULayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device) ++ Parameters (GRULayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)),
    1 <= numLayers,
    numLayersM1 ~ numLayers - 1,
    0 <= numLayersM1
  ) =>
  GRULayerStackParameterized 'True inputSize hiddenSize numLayers directionality dtype device
  where
  type
    GRULayerStackParameters 'True inputSize hiddenSize numLayers directionality dtype device =
      Parameters (GRULayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
        ++ Parameters (GRULayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
  gruLayerStackFlattenParameters _ (GRULayerK gruLayer gruLayerStack) =
    let -- Parameters of the layer added by 'GRULayerK'.
        parameters = flattenParameters gruLayer
        -- Parameters of the inner stack; the type application pins the
        -- recursive 'Parameterized' instance to @numLayersM1 = numLayers - 1@.
        parameters' = flattenParameters @(GRULayerStack inputSize hiddenSize numLayersM1 directionality dtype device) gruLayerStack
     in parameters' `happendFD` parameters
  gruLayerStackReplaceParameters _ (GRULayerK gruLayer gruLayerStack) parameters'' =
    let -- Split the flat list at the boundary between inner stack and outer layer,
        -- mirroring the order used by 'gruLayerStackFlattenParameters'.
        (parameters', parameters) = hunappendFD parameters''
        gruLayer' = replaceParameters gruLayer parameters
        gruLayerStack' = replaceParameters @(GRULayerStack inputSize hiddenSize numLayersM1 directionality dtype device) gruLayerStack parameters'
     in GRULayerK gruLayer' gruLayerStack'

-- | 'Parameterized' instance for 'GRULayerStack'.
--
-- The type-level @flag@ (@2 <=? numLayers@) selects the helper instance.
-- 'False gives the single-layer instance and 'True the multi-layer instance
-- of 'GRULayerStackParameterized'.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    GRULayerStackParameterized flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Parameterized (GRULayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Parameters (GRULayerStack inputSize hiddenSize numLayers directionality dtype device) =
      GRULayerStackParameters (2 <=? numLayers) inputSize hiddenSize numLayers directionality dtype device
  flattenParameters = gruLayerStackFlattenParameters (Proxy :: Proxy flag)
  replaceParameters = gruLayerStackReplaceParameters (Proxy :: Proxy flag)

-- | Random sampling of a 'GRULayerStack' from its 'GRULayerStackSpec'.
-- The type-level @flag@ selects the instance:
--
-- * @'False@ covers a stack of exactly one layer;
-- * @'True@ builds a stack of @numLayers@ layers recursively from a stack of
--   @numLayers - 1@ layers.
class
  GRULayerStackRandomizable
    (flag :: Bool)
    inputSize
    hiddenSize
    numLayers
    directionality
    dtype
    device
  where
  -- | Draw a random layer stack. The 'Proxy' argument only fixes @flag@.
  gruLayerStackSample ::
    Proxy flag ->
    GRULayerStackSpec inputSize hiddenSize numLayers directionality dtype device ->
    IO (GRULayerStack inputSize hiddenSize numLayers directionality dtype device)

-- | Base case: a single-layer stack whose only layer consumes the original
-- @inputSize@ features.
instance
  ( A.Randomizable
      (GRULayerSpec inputSize hiddenSize directionality dtype device)
      (GRULayer inputSize hiddenSize directionality dtype device)
  ) =>
  GRULayerStackRandomizable 'False inputSize hiddenSize 1 directionality dtype device
  where
  gruLayerStackSample _flagProxy _stackSpec = do
    -- The layer spec is a nullary constructor; every size is fixed by the
    -- type applications.
    layer <- sample (GRULayerSpec @inputSize @hiddenSize @directionality @dtype @device)
    pure (GRULayer1 layer)

-- | Recursive case. This instance:
--
-- * samples the last layer of the stack, whose input width is
--   @hiddenSize * NumberOfDirections directionality@;
-- * then samples the remaining @numLayers - 1@ layers, which start from
--   @inputSize@;
-- * joins the two with 'GRULayerK'.
instance
  ( 1 <= numLayers,
    A.Randomizable
      (GRULayerSpec (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
      (GRULayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device),
    A.Randomizable
      (GRULayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
      (GRULayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
  ) =>
  GRULayerStackRandomizable 'True inputSize hiddenSize numLayers directionality dtype device
  where
  gruLayerStackSample _flagProxy _stackSpec = do
    -- Sampling order (outer layer first, then the shorter stack) is kept
    -- so the random draws happen in the same sequence.
    outerLayer <-
      sample $
        GRULayerSpec
          @(hiddenSize * NumberOfDirections directionality)
          @hiddenSize
          @directionality
          @dtype
          @device
    innerStack <-
      sample
        @(GRULayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
        @(GRULayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
        GRULayerStackSpec
    pure (GRULayerK outerLayer innerStack)

-- | Sample a 'GRULayerStack' from its spec. This works by computing
-- @flag = 2 <=? numLayers@ at the type level and dispatching to the matching
-- 'GRULayerStackRandomizable' instance.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    RandDTypeIsValid device dtype,
    KnownDType dtype,
    KnownDevice device,
    GRULayerStackRandomizable flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Randomizable
    (GRULayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
    (GRULayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  sample = gruLayerStackSample (Proxy @flag)

-- | Specification for a 'GRU'. All sizes live in the type parameters. The
-- only runtime field is the 'DropoutSpec' from which the model's 'Dropout'
-- module is sampled.
newtype
  GRUSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = GRUSpec DropoutSpec
  deriving (Generic, Show)

-- | A multi-layer gated recurrent unit. It is a 'GRULayerStack' of
-- @numLayers@ layers (the constructor requires at least one) paired with a
-- 'Dropout' module.
data
  GRU
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- Fields: the stacked recurrent layers, and the dropout module.
  GRU ::
    (1 <= numLayers) =>
    { gru_layer_stack :: GRULayerStack inputSize hiddenSize numLayers directionality dtype device,
      gru_dropout :: Dropout
    } ->
    GRU inputSize hiddenSize numLayers directionality dtype device

deriving instance Show (GRU inputSize hiddenSize numLayers directionality dtype device)

-- | Hand-written 'Generic' instance. It is presumably needed because stock
-- deriving cannot handle the constrained GADT constructor. The
-- representation is the layer stack followed by the dropout module.
instance
  (1 <= numLayers) =>
  Generic (GRU inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Rep (GRU inputSize hiddenSize numLayers directionality dtype device) =
      Rec0 (GRULayerStack inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 Dropout
  from (GRU stack dropoutLayer) = K1 stack :*: K1 dropoutLayer
  to (K1 stack :*: K1 dropoutLayer) = GRU stack dropoutLayer

-- | 'Parameterized' instance with no method bodies, so the class's default
-- methods are used. The 'HAppendFD' constraint is presumably what those
-- defaults need to concatenate the layer-stack parameters with the dropout
-- parameters.
instance
  ( Parameterized (GRULayerStack inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (GRULayerStack inputSize hiddenSize numLayers directionality dtype device))
      (Parameters Dropout)
      ( Parameters (GRULayerStack inputSize hiddenSize numLayers directionality dtype device)
          ++ Parameters Dropout
      ),
    1 <= numLayers
  ) =>
  Parameterized (GRU inputSize hiddenSize numLayers directionality dtype device)

-- TODO: when we have canonical initializers, do this correctly:
-- https://github.com/pytorch/pytorch/issues/9221
-- https://discuss.pytorch.org/t/initializing-rnn-gru-and-lstm-correctly/23605

-- | Helper that draws a GRU weight matrix of shape
-- @[3 * hiddenSize, featureSize]@ with Xavier-uniform initialization, using
-- a gain of @5 / 3@. Orthogonal initialization of the gates is not
-- implemented yet.
xavierUniformGRU ::
  forall device dtype hiddenSize featureSize.
  ( KnownDType dtype,
    KnownNat hiddenSize,
    KnownNat featureSize,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  IO (Tensor device dtype '[3 * hiddenSize, featureSize])
xavierUniformGRU = do
  gaussian <- randn @'[3 * hiddenSize, featureSize] @dtype @device
  let dims = shape @device @dtype @'[3 * hiddenSize, featureSize] gaussian
      gain = 5.0 / 3 :: Float
  raw <- xavierUniformFIXME (toDynamic gaussian) gain dims
  pure (UnsafeMkTensor raw)

-- | Sample a 'GRU' from its spec in two steps:
--
-- * first the layer stack, which is fully determined by the type parameters;
-- * then the dropout module, from the spec's 'DropoutSpec'.
instance
  ( KnownDType dtype,
    KnownDevice device,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownNat (NumberOfDirections directionality),
    RandDTypeIsValid device dtype,
    A.Randomizable
      (GRULayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
      (GRULayerStack inputSize hiddenSize numLayers directionality dtype device),
    1 <= numLayers
  ) =>
  A.Randomizable
    (GRUSpec inputSize hiddenSize numLayers directionality dtype device)
    (GRU inputSize hiddenSize numLayers directionality dtype device)
  where
  sample (GRUSpec dropoutSpec) = do
    layerStack <-
      A.sample
        (GRULayerStackSpec @inputSize @hiddenSize @numLayers @directionality @dtype @device)
    dropoutLayer <- A.sample dropoutSpec
    pure (GRU layerStack dropoutLayer)

-- | A specification for a gated recurrent unit (GRU), together with how its
-- initial hidden state is obtained. Every constructor wraps a 'GRUSpec'. The
-- initialization strategy is recorded in the @initialization@ index.
--
-- A GRU carries only a hidden state (there is no memory cell). Any supplied
-- initial state therefore has the hidden-state shape
-- @[numLayers * NumberOfDirections directionality, hiddenSize]@.
--
-- NOTE(review): earlier documentation said weights are drawn from
-- Xavier-Uniform with zero-initialized biases. Confirm this against the
-- sampling code before relying on it.
data
  GRUWithInitSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Constant initial hidden state with no user-supplied tensor.
  --   Going by the constructor name, it is presumably all zeros.
  GRUWithZerosInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    GRUSpec inputSize hiddenSize numLayers directionality dtype device ->
    GRUWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | Constant, user-provided initial hidden state.
  GRUWithConstInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    GRUSpec inputSize hiddenSize numLayers directionality dtype device ->
    -- | The initial values of the hidden state
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    GRUWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | Learnable initial hidden state; the supplied tensor gives its
  --   starting values.
  GRUWithLearnedInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    GRUSpec inputSize hiddenSize numLayers directionality dtype device ->
    -- | The initial (learnable)
    -- values of the hidden state
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    GRUWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device

deriving instance Show (GRUWithInitSpec inputSize hiddenSize numLayers directionality initialization dtype device)

-- | A gated recurrent unit (GRU) paired with its initial hidden state. The
-- state is one of two kinds:
--
-- * a fixed 'Tensor' ('GRUWithConstInit', @'ConstantInitialization@);
-- * a trainable 'Parameter' ('GRUWithLearnedInit', @'LearnedInitialization@).
--
-- A GRU has no memory cell, so the hidden state is its only initial state.
data
  GRUWithInit
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | A GRU with a fixed (non-trainable) initial hidden state.
  GRUWithConstInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { gruWithConstInit_gru :: GRU inputSize hiddenSize numLayers directionality dtype device,
      gruWithConstInit_h :: Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    GRUWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'ConstantInitialization
      dtype
      device
  -- | A GRU whose initial hidden state is a trainable 'Parameter'.
  GRUWithLearnedInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { gruWithLearnedInit_gru :: GRU inputSize hiddenSize numLayers directionality dtype device,
      gruWithLearnedInit_h :: Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    GRUWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'LearnedInitialization
      dtype
      device

deriving instance Show (GRUWithInit inputSize hiddenSize numLayers directionality initialization dtype device)

-- | Hand-written 'Generic' instance for the constant-initialisation variant.
--
-- 'Generic' cannot be derived for the 'GRUWithInit' GADT, so the
-- representation is spelled out per initialisation index. It is a product of
-- two fields, in constructor order:
--
-- 1. the wrapped 'GRU';
-- 2. the fixed initial hidden state, a plain 'Tensor' of shape
--    @[numLayers * NumberOfDirections directionality, hiddenSize]@.
instance Generic (GRUWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) where
  type
    Rep (GRUWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) =
      Rec0 (GRU inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])

  -- Split the constructor into its two fields. The names avoid shadowing the
  -- imported 'gru' function.
  from (GRUWithConstInit model h0) = K1 model :*: K1 h0

  -- Inverse of 'from': rebuild the constructor from the two fields.
  to (K1 model :*: K1 h0) = GRUWithConstInit model h0

-- | Hand-written 'Generic' instance for the learned-initialisation variant.
--
-- 'Generic' cannot be derived for the 'GRUWithInit' GADT, so the
-- representation is spelled out per initialisation index. It is a product of
-- two fields, in constructor order:
--
-- 1. the wrapped 'GRU';
-- 2. the initial hidden state, held as a 'Parameter' of shape
--    @[numLayers * NumberOfDirections directionality, hiddenSize]@.
instance Generic (GRUWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) where
  type
    Rep (GRUWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) =
      Rec0 (GRU inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])

  -- Split the constructor into its two fields. The names avoid shadowing the
  -- imported 'gru' function.
  from (GRUWithLearnedInit model h0) = K1 model :*: K1 h0

  -- Inverse of 'from': rebuild the constructor from the two fields.
  to (K1 model :*: K1 h0) = GRUWithLearnedInit model h0

-- | Parameters of a constant-initialised 'GRUWithInit'.
--
-- Only the wrapped 'GRU' contributes. The initial hidden state is a plain
-- 'Tensor' rather than a 'Parameter', which is why the GRU's parameter list is
-- appended to the empty list @'[]@ in the context below.
--
-- All methods come from the class defaults, since no @where@ body is given.
instance
  ( HAppendFD
      (Parameters (GRU inputSize hiddenSize numLayers directionality dtype device))
      '[]
      (Parameters (GRU inputSize hiddenSize numLayers directionality dtype device) ++ '[]),
    Parameterized (GRU inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  Parameterized
    (GRUWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)

-- | Parameters of a learned-initialisation 'GRUWithInit'.
--
-- The wrapped 'GRU''s parameters come first. They are followed by a single
-- extra 'Parameter' holding the initial hidden state, with shape
-- @[numLayers * NumberOfDirections directionality, hiddenSize]@.
--
-- All methods come from the class defaults, since no @where@ body is given.
instance
  ( HAppendFD
      (Parameters (GRU inputSize hiddenSize numLayers directionality dtype device))
      '[Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]]
      ( Parameters (GRU inputSize hiddenSize numLayers directionality dtype device)
          ++ '[Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]]
      ),
    Parameterized (GRU inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  Parameterized
    (GRUWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)

-- | Sample a constant-initialised 'GRUWithInit' from its spec.
--
-- The wrapped 'GRU' is sampled from the inner 'GRUSpec'. The initial hidden
-- state is not random:
--
-- * 'GRUWithZerosInitSpec' uses an all-zeros tensor ('zeros'), whose shape is
--   fixed by the constructor field type;
-- * 'GRUWithConstInitSpec' uses the tensor carried by the spec unchanged.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (GRUSpec inputSize hiddenSize numLayers directionality dtype device)
      (GRU inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (GRUWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
    (GRUWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
  where
  sample (GRUWithZerosInitSpec gruSpec) =
    GRUWithConstInit <$> A.sample gruSpec <*> pure zeros
  sample (GRUWithConstInitSpec gruSpec h) =
    GRUWithConstInit <$> A.sample gruSpec <*> pure h

-- | Sample a learned-initialisation 'GRUWithInit' from its spec.
--
-- The wrapped 'GRU' is sampled from the inner 'GRUSpec'. The initial hidden
-- state tensor carried by the spec is wrapped as a 'Parameter' via
-- 'makeIndependent'.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (GRUSpec inputSize hiddenSize numLayers directionality dtype device)
      (GRU inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (GRUWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
    (GRUWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
  where
  -- The original's @s\@@ as-pattern was never used, and
  -- @makeIndependent =<< pure h@ is just @makeIndependent h@.
  sample (GRUWithLearnedInitSpec gruSpec h) =
    GRUWithLearnedInit <$> A.sample gruSpec <*> makeIndependent h

gruForward ::
  forall
    shapeOrder
    batchSize
    seqLen
    directionality
    initialization
    numLayers
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hcShape
    parameters
    tensorParameters
    dtype
    device.
  ( KnownNat (NumberOfDirections directionality),
    KnownNat numLayers,
    KnownNat batchSize,
    KnownNat hiddenSize,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hcShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    parameters ~ Parameters (GRU inputSize hiddenSize numLayers directionality dtype device),
    Parameterized (GRU inputSize hiddenSize numLayers directionality dtype device),
    tensorParameters ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor],
    HMap' ToDependent parameters tensorParameters
  ) =>
  Bool ->
  GRUWithInit
    inputSize
    hiddenSize
    numLayers
    directionality
    initialization
    dtype
    device ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hcShape
  )
gruForward :: forall (shapeOrder :: RNNShapeOrder) (batchSize :: Natural)
       (seqLen :: Natural) (directionality :: RNNDirectionality)
       (initialization :: RNNInitialization) (numLayers :: Natural)
       (inputSize :: Natural) (outputSize :: Natural)
       (hiddenSize :: Natural) (inputShape :: [Natural])
       (outputShape :: [Natural]) (hcShape :: [Natural])
       (parameters :: [Type]) (tensorParameters :: [Type])
       (dtype :: DType) (device :: (DeviceType, Natural)).
(KnownNat (NumberOfDirections directionality), KnownNat numLayers,
 KnownNat batchSize, KnownNat hiddenSize,
 KnownRNNShapeOrder shapeOrder,
 KnownRNNDirectionality directionality,
 outputSize ~ (hiddenSize * NumberOfDirections directionality),
 inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
 outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
 hcShape
 ~ '[numLayers * NumberOfDirections directionality, batchSize,
     hiddenSize],
 parameters
 ~ Parameters
     (GRU inputSize hiddenSize numLayers directionality dtype device),
 Parameterized
   (GRU inputSize hiddenSize numLayers directionality dtype device),
 tensorParameters
 ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
 Castable (HList tensorParameters) [ATenTensor],
 HMap' ToDependent parameters tensorParameters) =>
Bool
-> GRUWithInit
     inputSize
     hiddenSize
     numLayers
     directionality
     initialization
     dtype
     device
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hcShape)
gruForward Bool
dropoutOn (GRUWithConstInit gruModel :: GRU inputSize hiddenSize numLayers directionality dtype device
gruModel@(GRU GRULayerStack
  inputSize hiddenSize numLayers directionality dtype device
_ (Dropout Double
dropoutProb)) Tensor
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
hc) Tensor device dtype inputShape
input =
  forall {k} (shapeOrder :: RNNShapeOrder)
       (directionality :: RNNDirectionality) (numLayers :: Natural)
       (seqLen :: Natural) (batchSize :: Natural) (inputSize :: Natural)
       (outputSize :: Natural) (hiddenSize :: Natural)
       (inputShape :: [Natural]) (outputShape :: [Natural])
       (hcShape :: [Natural]) (tensorParameters :: [k]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
 KnownRNNDirectionality directionality,
 outputSize ~ (hiddenSize * NumberOfDirections directionality),
 inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
 outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
 hcShape
 ~ '[numLayers * NumberOfDirections directionality, batchSize,
     hiddenSize],
 tensorParameters
 ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
 Castable (HList tensorParameters) [ATenTensor]) =>
HList tensorParameters
-> Double
-> Bool
-> Tensor device dtype hcShape
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hcShape)
forall (shapeOrder :: RNNShapeOrder)
       (directionality :: RNNDirectionality) (numLayers :: Natural)
       (seqLen :: Natural) (batchSize :: Natural) (inputSize :: Natural)
       (outputSize :: Natural) (hiddenSize :: Natural)
       (inputShape :: [Natural]) (outputShape :: [Natural])
       (hcShape :: [Natural]) (tensorParameters :: [Type])
       (dtype :: DType) (device :: (DeviceType, Natural)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
 KnownRNNDirectionality directionality,
 outputSize ~ (hiddenSize * NumberOfDirections directionality),
 inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
 outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
 hcShape
 ~ '[numLayers * NumberOfDirections directionality, batchSize,
     hiddenSize],
 tensorParameters
 ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
 Castable (HList tensorParameters) [ATenTensor]) =>
HList tensorParameters
-> Double
-> Bool
-> Tensor device dtype hcShape
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hcShape)
gru
    @shapeOrder
    @directionality
    @numLayers
    @seqLen
    @batchSize
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hcShape
    @tensorParameters
    @dtype
    @device
    (ToDependent -> HList parameters -> HList tensorParameters
forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
hmap' ToDependent
ToDependent (HList parameters -> HList tensorParameters)
-> (GRU inputSize hiddenSize numLayers directionality dtype device
    -> HList parameters)
-> GRU inputSize hiddenSize numLayers directionality dtype device
-> HList tensorParameters
forall b c a. (b -> c) -> (a -> b) -> a -> c
. GRU inputSize hiddenSize numLayers directionality dtype device
-> HList parameters
GRU inputSize hiddenSize numLayers directionality dtype device
-> HList
     (Parameters
        (GRU inputSize hiddenSize numLayers directionality dtype device))
forall f. Parameterized f => f -> HList (Parameters f)
flattenParameters (GRU inputSize hiddenSize numLayers directionality dtype device
 -> HList tensorParameters)
-> GRU inputSize hiddenSize numLayers directionality dtype device
-> HList tensorParameters
forall a b. (a -> b) -> a -> b
$ GRU inputSize hiddenSize numLayers directionality dtype device
gruModel)
    Double
dropoutProb
    Bool
dropoutOn
    Tensor device dtype hcShape
hc'
    Tensor device dtype inputShape
input
  where
    hc' :: Tensor device dtype hcShape
hc' =
      forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @hcShape
        (Tensor
   device
   dtype
   '[batchSize, numLayers * NumberOfDirections directionality,
     hiddenSize]
 -> Tensor device dtype hcShape)
-> (Tensor
      device
      dtype
      '[numLayers * NumberOfDirections directionality, hiddenSize]
    -> Tensor
         device
         dtype
         '[batchSize, numLayers * NumberOfDirections directionality,
           hiddenSize])
-> Tensor
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
-> Tensor device dtype hcShape
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
(KnownShape shape', shape' ~ Broadcast shape shape') =>
Bool -> Tensor device dtype shape -> Tensor device dtype shape'
expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          Bool
False -- TODO: What does the bool do?
        (Tensor
   device
   dtype
   '[numLayers * NumberOfDirections directionality, hiddenSize]
 -> Tensor device dtype hcShape)
-> Tensor
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
-> Tensor device dtype hcShape
forall a b. (a -> b) -> a -> b
$ Tensor
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
hc
gruForward Bool
dropoutOn (GRUWithLearnedInit gruModel :: GRU inputSize hiddenSize numLayers directionality dtype device
gruModel@(GRU GRULayerStack
  inputSize hiddenSize numLayers directionality dtype device
_ (Dropout Double
dropoutProb)) Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
hc) Tensor device dtype inputShape
input =
  forall {k} (shapeOrder :: RNNShapeOrder)
       (directionality :: RNNDirectionality) (numLayers :: Natural)
       (seqLen :: Natural) (batchSize :: Natural) (inputSize :: Natural)
       (outputSize :: Natural) (hiddenSize :: Natural)
       (inputShape :: [Natural]) (outputShape :: [Natural])
       (hcShape :: [Natural]) (tensorParameters :: [k]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
 KnownRNNDirectionality directionality,
 outputSize ~ (hiddenSize * NumberOfDirections directionality),
 inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
 outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
 hcShape
 ~ '[numLayers * NumberOfDirections directionality, batchSize,
     hiddenSize],
 tensorParameters
 ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
 Castable (HList tensorParameters) [ATenTensor]) =>
HList tensorParameters
-> Double
-> Bool
-> Tensor device dtype hcShape
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hcShape)
forall (shapeOrder :: RNNShapeOrder)
       (directionality :: RNNDirectionality) (numLayers :: Natural)
       (seqLen :: Natural) (batchSize :: Natural) (inputSize :: Natural)
       (outputSize :: Natural) (hiddenSize :: Natural)
       (inputShape :: [Natural]) (outputShape :: [Natural])
       (hcShape :: [Natural]) (tensorParameters :: [Type])
       (dtype :: DType) (device :: (DeviceType, Natural)).
(KnownNat numLayers, KnownRNNShapeOrder shapeOrder,
 KnownRNNDirectionality directionality,
 outputSize ~ (hiddenSize * NumberOfDirections directionality),
 inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
 outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
 hcShape
 ~ '[numLayers * NumberOfDirections directionality, batchSize,
     hiddenSize],
 tensorParameters
 ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
 Castable (HList tensorParameters) [ATenTensor]) =>
HList tensorParameters
-> Double
-> Bool
-> Tensor device dtype hcShape
-> Tensor device dtype inputShape
-> (Tensor device dtype outputShape, Tensor device dtype hcShape)
gru
    @shapeOrder
    @directionality
    @numLayers
    @seqLen
    @batchSize
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hcShape
    @tensorParameters
    @dtype
    @device
    (ToDependent -> HList parameters -> HList tensorParameters
forall k f (xs :: [k]) (ys :: [k]).
HMap' f xs ys =>
f -> HList xs -> HList ys
hmap' ToDependent
ToDependent (HList parameters -> HList tensorParameters)
-> (GRU inputSize hiddenSize numLayers directionality dtype device
    -> HList parameters)
-> GRU inputSize hiddenSize numLayers directionality dtype device
-> HList tensorParameters
forall b c a. (b -> c) -> (a -> b) -> a -> c
. GRU inputSize hiddenSize numLayers directionality dtype device
-> HList parameters
GRU inputSize hiddenSize numLayers directionality dtype device
-> HList
     (Parameters
        (GRU inputSize hiddenSize numLayers directionality dtype device))
forall f. Parameterized f => f -> HList (Parameters f)
flattenParameters (GRU inputSize hiddenSize numLayers directionality dtype device
 -> HList tensorParameters)
-> GRU inputSize hiddenSize numLayers directionality dtype device
-> HList tensorParameters
forall a b. (a -> b) -> a -> b
$ GRU inputSize hiddenSize numLayers directionality dtype device
gruModel)
    Double
dropoutProb
    Bool
dropoutOn
    Tensor device dtype hcShape
hc'
    Tensor device dtype inputShape
input
  where
    hc' :: Tensor device dtype hcShape
hc' =
      forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @hcShape
        (Tensor
   device
   dtype
   '[batchSize, numLayers * NumberOfDirections directionality,
     hiddenSize]
 -> Tensor device dtype hcShape)
-> (Parameter
      device
      dtype
      '[numLayers * NumberOfDirections directionality, hiddenSize]
    -> Tensor
         device
         dtype
         '[batchSize, numLayers * NumberOfDirections directionality,
           hiddenSize])
-> Parameter
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
-> Tensor device dtype hcShape
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Natural]) (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
(KnownShape shape', shape' ~ Broadcast shape shape') =>
Bool -> Tensor device dtype shape -> Tensor device dtype shape'
expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          Bool
False -- TODO: What does the bool do?
        (Tensor
   device
   dtype
   '[numLayers * NumberOfDirections directionality, hiddenSize]
 -> Tensor
      device
      dtype
      '[batchSize, numLayers * NumberOfDirections directionality,
        hiddenSize])
-> (Parameter
      device
      dtype
      '[numLayers * NumberOfDirections directionality, hiddenSize]
    -> Tensor
         device
         dtype
         '[numLayers * NumberOfDirections directionality, hiddenSize])
-> Parameter
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
-> Tensor
     device
     dtype
     '[batchSize, numLayers * NumberOfDirections directionality,
       hiddenSize]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
-> Tensor
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Parameter device dtype shape -> Tensor device dtype shape
toDependent
        (Parameter
   device
   dtype
   '[numLayers * NumberOfDirections directionality, hiddenSize]
 -> Tensor device dtype hcShape)
-> Parameter
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
-> Tensor device dtype hcShape
forall a b. (a -> b) -> a -> b
$ Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
hc

-- | Forward propagate the `GRU` module.
--
-- The two functions share this signature and differ only in the 'Bool' flag
-- they pass as the first argument of 'gruForward':
--
-- * 'gruForwardWithDropout' passes 'True' and applies dropout on the outputs
--   of each layer.
--
-- * 'gruForwardWithoutDropout' passes 'False' and does not apply dropout on
--   the outputs of each layer.
--
-- With dropout:
--
-- >>> input :: CPUTensor 'D.Float '[5,16,10] <- randn
-- >>> spec = GRUWithZerosInitSpec @10 @30 @3 @'Bidirectional @'D.Float @'( 'D.CPU, 0) (GRUSpec (DropoutSpec 0.5))
-- >>> model <- A.sample spec
-- >>> :t gruForwardWithDropout @'BatchFirst model input
-- gruForwardWithDropout @'BatchFirst model input
--   :: (Tensor '(D.CPU, 0) 'D.Float [5, 16, 60],
--       Tensor '(D.CPU, 0) 'D.Float [6, 5, 30])
-- >>> (a,b) = gruForwardWithDropout @'BatchFirst model input
-- >>> ((dtype a, shape a), (dtype b, shape b))
-- ((Float,[5,16,60]),(Float,[6,5,30]))
--
-- Without dropout:
--
-- >>> input :: CPUTensor 'D.Float '[5,16,10] <- randn
-- >>> spec = GRUWithZerosInitSpec @10 @30 @3 @'Bidirectional @'D.Float @'( 'D.CPU, 0) (GRUSpec (DropoutSpec 0.5))
-- >>> model <- A.sample spec
-- >>> :t gruForwardWithoutDropout @'BatchFirst model input
-- gruForwardWithoutDropout @'BatchFirst model input
--   :: (Tensor '(D.CPU, 0) 'D.Float [5, 16, 60],
--       Tensor '(D.CPU, 0) 'D.Float [6, 5, 30])
-- >>> (a,b) = gruForwardWithoutDropout @'BatchFirst model input
-- >>> ((dtype a, shape a), (dtype b, shape b))
-- ((Float,[5,16,60]),(Float,[6,5,30]))
gruForwardWithDropout,
  gruForwardWithoutDropout ::
    forall
      shapeOrder
      batchSize
      seqLen
      directionality
      initialization
      numLayers
      inputSize
      outputSize
      hiddenSize
      inputShape
      outputShape
      hcShape
      parameters
      tensorParameters
      dtype
      device.
    ( KnownNat (NumberOfDirections directionality),
      KnownNat numLayers,
      KnownNat batchSize,
      KnownNat hiddenSize,
      KnownRNNShapeOrder shapeOrder,
      KnownRNNDirectionality directionality,
      outputSize ~ (hiddenSize * NumberOfDirections directionality),
      inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
      outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
      hcShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
      parameters ~ Parameters (GRU inputSize hiddenSize numLayers directionality dtype device),
      Parameterized (GRU inputSize hiddenSize numLayers directionality dtype device),
      tensorParameters ~ GRUR inputSize hiddenSize numLayers directionality dtype device,
      ATen.Castable (HList tensorParameters) [D.ATenTensor],
      HMap' ToDependent parameters tensorParameters
    ) =>
    GRUWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      initialization
      dtype
      device ->
    Tensor device dtype inputShape ->
    ( Tensor device dtype outputShape,
      Tensor device dtype hcShape
    )
-- Both definitions are partial applications of 'gruForward'. Every type
-- variable of the signature above is forwarded explicitly via a type
-- application, in the same order as the @forall@ of 'gruForward'.
gruForwardWithDropout =
  gruForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hcShape
    @parameters
    @tensorParameters
    @dtype
    @device
    True

gruForwardWithoutDropout =
  gruForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hcShape
    @parameters
    @tensorParameters
    @dtype
    @device
    False