{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fno-warn-partial-type-signatures #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

module Torch.Typed.NN.Recurrent.LSTM where

import Data.Kind
import Data.Proxy (Proxy (..))
import Foreign.ForeignPtr
import GHC.Generics
import GHC.TypeLits
import GHC.TypeLits.Extra
import System.Environment
import System.IO.Unsafe
import qualified Torch.Autograd as A
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.Functional as D
import Torch.HList
import qualified Torch.Internal.Cast as ATen
import qualified Torch.Internal.Class as ATen
import qualified Torch.Internal.Managed.Type.Tensor as ATen
import qualified Torch.Internal.Type as ATen
import qualified Torch.NN as A
import qualified Torch.Tensor as D
import qualified Torch.TensorFactories as D
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (sqrt)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Recurrent.Auxiliary
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (tanh)

-- | Specification of a single LSTM layer.
--
-- The value carries no data: input size, hidden size, directionality,
-- dtype and device all live in the type parameters, so one nullary
-- constructor is enough. It is the spec argument of the 'A.Randomizable'
-- instances in this module that build an 'LSTMLayer'.
data
  LSTMLayerSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LSTMLayerSpec
  deriving (Show, Eq)

-- | Trainable parameters of a single LSTM layer.
--
-- The @directionality@ index selects the constructor: a
-- @'Unidirectional@ layer holds one group of four parameters, a
-- @'Bidirectional@ layer holds two such groups with identical shapes.
-- Going by the shape family names, each group is presumably
-- input-to-hidden weights, hidden-to-hidden weights and their two bias
-- vectors -- confirm against "Torch.Typed.NN.Recurrent.Auxiliary".
data
  LSTMLayer
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Layer with a single group of four parameters.
  LSTMUnidirectionalLayer ::
    Parameter device dtype (LSTMWIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBHShape hiddenSize inputSize) ->
    LSTMLayer inputSize hiddenSize 'Unidirectional dtype device
  -- | Layer with two groups of four parameters. The first four fields
  -- have the same shapes as the last four; which group serves which
  -- direction is not visible here (presumably forward first -- TODO
  -- confirm against the code that consumes this constructor).
  LSTMBidirectionalLayer ::
    Parameter device dtype (LSTMWIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMWHShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBIShape hiddenSize inputSize) ->
    Parameter device dtype (LSTMBHShape hiddenSize inputSize) ->
    LSTMLayer inputSize hiddenSize 'Bidirectional dtype device

-- 'LSTMLayer' is declared in GADT syntax with constructors that refine
-- the @directionality@ index, so 'Show' is derived standalone rather
-- than through a @deriving@ clause on the declaration.
deriving instance Show (LSTMLayer inputSize hiddenSize directionality dtype device)

-- | Exposes the four parameters of a unidirectional LSTM layer as a
-- heterogeneous list, in constructor-field order.
instance Parameterized (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device) where
  type
    Parameters (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device) =
      '[ Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize)
       ]

  -- Unpack the constructor into an HList; the order must match the
  -- 'Parameters' type-level list above.
  flattenParameters (LSTMUnidirectionalLayer wi wh bi bh) =
    wi :. wh :. bi :. bh :. HNil

  -- Rebuild the layer purely from the supplied HList; the old layer
  -- value is ignored because every field is replaced.
  replaceParameters _ (wi :. wh :. bi :. bh :. HNil) =
    LSTMUnidirectionalLayer wi wh bi bh

-- | Exposes the eight parameters of a bidirectional LSTM layer as a
-- heterogeneous list, in constructor-field order: the first group of
-- four followed by the second (primed) group of four.
instance Parameterized (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device) where
  type
    Parameters (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device) =
      '[ Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize),
         Parameter device dtype (LSTMWIShape hiddenSize inputSize),
         Parameter device dtype (LSTMWHShape hiddenSize inputSize),
         Parameter device dtype (LSTMBIShape hiddenSize inputSize),
         Parameter device dtype (LSTMBHShape hiddenSize inputSize)
       ]

  -- Unpack the constructor into an HList; the order must match the
  -- 'Parameters' type-level list above.
  flattenParameters (LSTMBidirectionalLayer wi wh bi bh wi' wh' bi' bh') =
    wi :. wh :. bi :. bh :. wi' :. wh' :. bi' :. bh' :. HNil

  -- Rebuild the layer purely from the supplied HList; the old layer
  -- value is ignored because every field is replaced.
  replaceParameters _ (wi :. wh :. bi :. bh :. wi' :. wh' :. bi' :. bh' :. HNil) =
    LSTMBidirectionalLayer wi wh bi bh wi' wh' bi' bh'

-- | Randomly initialises a unidirectional LSTM layer from its spec.
--
-- Both weight tensors are drawn with 'xavierUniformLSTM' and both bias
-- tensors start as 'zeros'; every tensor is wrapped with
-- 'makeIndependent' to become a 'Parameter'. The spec value itself is
-- not inspected, since all sizes come from the type parameters.
instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (LSTMLayerSpec inputSize hiddenSize 'Unidirectional dtype device)
    (LSTMLayer inputSize hiddenSize 'Unidirectional dtype device)
  where
  sample _ =
    LSTMUnidirectionalLayer
      <$> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< xavierUniformLSTM)
      <*> (makeIndependent =<< pure zeros)
      <*> (makeIndependent =<< pure zeros)

instance
  ( RandDTypeIsValid device dtype,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownDType dtype,
    KnownDevice device
  ) =>
  A.Randomizable
    (LSTMLayerSpec inputSize hiddenSize 'Bidirectional dtype device)
    (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
  where
  sample :: LSTMLayerSpec inputSize hiddenSize 'Bidirectional dtype device
-> IO (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
sample LSTMLayerSpec inputSize hiddenSize 'Bidirectional dtype device
_ =
    Parameter device dtype (LSTMWIShape hiddenSize inputSize)
-> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
-> Parameter device dtype '[4 * hiddenSize]
-> Parameter device dtype '[4 * hiddenSize]
-> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
-> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
-> Parameter device dtype '[4 * hiddenSize]
-> Parameter device dtype '[4 * hiddenSize]
-> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (inputSize :: Natural).
Parameter device dtype (LSTMWIShape hiddenSize inputSize)
-> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
-> Parameter device dtype (LSTMBIShape hiddenSize inputSize)
-> Parameter device dtype (LSTMBIShape hiddenSize inputSize)
-> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
-> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
-> Parameter device dtype (LSTMBIShape hiddenSize inputSize)
-> Parameter device dtype (LSTMBIShape hiddenSize inputSize)
-> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device
LSTMBidirectionalLayer
      (Parameter device dtype (LSTMWIShape hiddenSize inputSize)
 -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
 -> Parameter device dtype '[4 * hiddenSize]
 -> Parameter device dtype '[4 * hiddenSize]
 -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
 -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
 -> Parameter device dtype '[4 * hiddenSize]
 -> Parameter device dtype '[4 * hiddenSize]
 -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize))
-> IO
     (Parameter device dtype (LSTMWHShape hiddenSize inputSize)
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
      -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
<$> (Tensor device dtype (LSTMWIShape hiddenSize inputSize)
-> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (LSTMWIShape hiddenSize inputSize)
 -> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize)))
-> IO (Tensor device dtype (LSTMWIShape hiddenSize inputSize))
-> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (LSTMWIShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[4 * hiddenSize, featureSize])
xavierUniformLSTM)
      IO
  (Parameter device dtype (LSTMWHShape hiddenSize inputSize)
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
   -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize))
-> IO
     (Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
      -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype (LSTMWHShape hiddenSize inputSize)
-> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (LSTMWHShape hiddenSize inputSize)
 -> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize)))
-> IO (Tensor device dtype (LSTMWHShape hiddenSize inputSize))
-> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (LSTMWHShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[4 * hiddenSize, featureSize])
xavierUniformLSTM)
      IO
  (Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
   -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[4 * hiddenSize])
-> IO
     (Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
      -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[4 * hiddenSize]
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[4 * hiddenSize]
 -> IO (Parameter device dtype '[4 * hiddenSize]))
-> IO (Tensor device dtype '[4 * hiddenSize])
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[4 * hiddenSize]
-> IO (Tensor device dtype '[4 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[4 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)
      IO
  (Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype (LSTMWIShape hiddenSize inputSize)
   -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[4 * hiddenSize])
-> IO
     (Parameter device dtype (LSTMWIShape hiddenSize inputSize)
      -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[4 * hiddenSize]
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[4 * hiddenSize]
 -> IO (Parameter device dtype '[4 * hiddenSize]))
-> IO (Tensor device dtype '[4 * hiddenSize])
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[4 * hiddenSize]
-> IO (Tensor device dtype '[4 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[4 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)
      IO
  (Parameter device dtype (LSTMWIShape hiddenSize inputSize)
   -> Parameter device dtype (LSTMWHShape hiddenSize inputSize)
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize))
-> IO
     (Parameter device dtype (LSTMWHShape hiddenSize inputSize)
      -> Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype (LSTMWIShape hiddenSize inputSize)
-> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (LSTMWIShape hiddenSize inputSize)
 -> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize)))
-> IO (Tensor device dtype (LSTMWIShape hiddenSize inputSize))
-> IO (Parameter device dtype (LSTMWIShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (LSTMWIShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[4 * hiddenSize, featureSize])
xavierUniformLSTM)
      IO
  (Parameter device dtype (LSTMWHShape hiddenSize inputSize)
   -> Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize))
-> IO
     (Parameter device dtype '[4 * hiddenSize]
      -> Parameter device dtype '[4 * hiddenSize]
      -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype (LSTMWHShape hiddenSize inputSize)
-> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize))
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype (LSTMWHShape hiddenSize inputSize)
 -> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize)))
-> IO (Tensor device dtype (LSTMWHShape hiddenSize inputSize))
-> IO (Parameter device dtype (LSTMWHShape hiddenSize inputSize))
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< IO (Tensor device dtype (LSTMWHShape hiddenSize inputSize))
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (hiddenSize :: Natural) (featureSize :: Natural).
(KnownDType dtype, KnownNat hiddenSize, KnownNat featureSize,
 KnownDevice device, RandDTypeIsValid device dtype) =>
IO (Tensor device dtype '[4 * hiddenSize, featureSize])
xavierUniformLSTM)
      IO
  (Parameter device dtype '[4 * hiddenSize]
   -> Parameter device dtype '[4 * hiddenSize]
   -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[4 * hiddenSize])
-> IO
     (Parameter device dtype '[4 * hiddenSize]
      -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[4 * hiddenSize]
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[4 * hiddenSize]
 -> IO (Parameter device dtype '[4 * hiddenSize]))
-> IO (Tensor device dtype '[4 * hiddenSize])
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[4 * hiddenSize]
-> IO (Tensor device dtype '[4 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[4 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)
      IO
  (Parameter device dtype '[4 * hiddenSize]
   -> LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
-> IO (Parameter device dtype '[4 * hiddenSize])
-> IO (LSTMLayer inputSize hiddenSize 'Bidirectional dtype device)
forall a b. IO (a -> b) -> IO a -> IO b
forall (f :: Type -> Type) a b.
Applicative f =>
f (a -> b) -> f a -> f b
<*> (Tensor device dtype '[4 * hiddenSize]
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
Tensor device dtype shape -> IO (Parameter device dtype shape)
makeIndependent (Tensor device dtype '[4 * hiddenSize]
 -> IO (Parameter device dtype '[4 * hiddenSize]))
-> IO (Tensor device dtype '[4 * hiddenSize])
-> IO (Parameter device dtype '[4 * hiddenSize])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[4 * hiddenSize]
-> IO (Tensor device dtype '[4 * hiddenSize])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[4 * hiddenSize]
forall (shape :: [Natural]) (dtype :: DType)
       (device :: (DeviceType, Natural)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros)

-- | Untyped parameter access for a single LSTM layer.
--
-- 'A.flattenParameters' lists the layer's parameters in constructor-field
-- order: @wi@, @wh@, @bi@, @bh@ for a unidirectional layer, followed by the
-- primed counterparts @wi'@, @wh'@, @bi'@, @bh'@ for a bidirectional layer.
-- 'A._replaceParameters' pulls the same number of parameters from the stream,
-- in that same order, and rebuilds the layer with the matching constructor.
instance A.Parameterized (LSTMLayer inputSize hiddenSize directionality dtype device) where
  flattenParameters (LSTMUnidirectionalLayer wi wh bi bh) =
    [ untypeParam wi,
      untypeParam wh,
      untypeParam bi,
      untypeParam bh
    ]
  flattenParameters (LSTMBidirectionalLayer wi wh bi bh wi' wh' bi' bh') =
    [ untypeParam wi,
      untypeParam wh,
      untypeParam bi,
      untypeParam bh,
      untypeParam wi',
      untypeParam wh',
      untypeParam bi',
      untypeParam bh'
    ]

  -- The existing fields are ignored; only the constructor decides how many
  -- parameters are taken from the stream.
  _replaceParameters (LSTMUnidirectionalLayer _wi _wh _bi _bh) = do
    wi <- A.nextParameter
    wh <- A.nextParameter
    bi <- A.nextParameter
    bh <- A.nextParameter
    -- 'UnsafeMkParameter' re-attaches the type-level device/dtype/shape
    -- indices; no runtime check is visible here, so the stream is trusted to
    -- deliver parameters in the order produced by 'A.flattenParameters'.
    return
      ( LSTMUnidirectionalLayer
          (UnsafeMkParameter wi)
          (UnsafeMkParameter wh)
          (UnsafeMkParameter bi)
          (UnsafeMkParameter bh)
      )
  _replaceParameters (LSTMBidirectionalLayer _wi _wh _bi _bh _wi' _wh' _bi' _bh') = do
    wi <- A.nextParameter
    wh <- A.nextParameter
    bi <- A.nextParameter
    bh <- A.nextParameter
    wi' <- A.nextParameter
    wh' <- A.nextParameter
    bi' <- A.nextParameter
    bh' <- A.nextParameter
    -- Same trust assumption as the unidirectional case: the wrapping is
    -- unchecked and relies on stream order.
    return
      ( LSTMBidirectionalLayer
          (UnsafeMkParameter wi)
          (UnsafeMkParameter wh)
          (UnsafeMkParameter bi)
          (UnsafeMkParameter bh)
          (UnsafeMkParameter wi')
          (UnsafeMkParameter wh')
          (UnsafeMkParameter bi')
          (UnsafeMkParameter bh')
      )

-- | Specification of a stack of LSTM layers.
--
-- The single nullary constructor carries no runtime data; the whole
-- configuration (input size, hidden size, number of layers, directionality,
-- dtype and device) lives in the type parameters.
data
  LSTMLayerStackSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LSTMLayerStackSpec
  deriving (Show, Eq)

-- | Input-to-hidden, hidden-to-hidden, and bias parameters for a multilayered
-- (and optionally bidirectional) LSTM.
--
-- The stack is indexed by its depth @numLayers@: 'LSTMLayer1' is a stack of
-- exactly one layer, and 'LSTMLayerK' adds one more layer to an existing
-- stack, incrementing the depth index by one.
data
  LSTMLayerStack
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | A stack of depth 1, holding a single layer whose input size is the
  -- stack's @inputSize@.
  LSTMLayer1 ::
    LSTMLayer inputSize hiddenSize directionality dtype device ->
    LSTMLayerStack inputSize hiddenSize 1 directionality dtype device
  -- | Extend a stack of depth @numLayers@ by one layer. The added layer's
  -- input size is @hiddenSize * NumberOfDirections directionality@ rather
  -- than the stack's @inputSize@.
  LSTMLayerK ::
    LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    LSTMLayerStack inputSize hiddenSize (numLayers + 1) directionality dtype device

-- | 'Show' for layer stacks, derived standalone with an empty instance
-- context.
deriving instance Show (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)

-- | Helper class that implements typed parameter flattening and replacement
-- for 'LSTMLayerStack'.
--
-- The type-level @flag@ only selects an instance: this module provides a
-- @'False@ instance for stacks with @numLayers ~ 1@ and a @'True@ instance
-- that recurses into the nested stack. The 'Parameterized' instance for
-- 'LSTMLayerStack' chooses the flag as @2 <=? numLayers@.
class LSTMLayerStackParameterized (flag :: Bool) inputSize hiddenSize numLayers directionality dtype device where
  -- | Type-level list of the parameters held by the stack.
  type LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device :: [Type]
  -- | Collects all parameters of the stack into a heterogeneous list.
  -- The 'Proxy' argument carries no data; it only fixes @flag@.
  lstmLayerStackFlattenParameters ::
    Proxy flag ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    HList (LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device)
  -- | Rebuilds a stack of the same shape from the given stack and a
  -- heterogeneous list of replacement parameters.
  lstmLayerStackReplaceParameters ::
    Proxy flag ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device ->
    HList (LSTMLayerStackParameters flag inputSize hiddenSize numLayers directionality dtype device) ->
    LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device

-- | Base case (@flag ~ 'False@, @numLayers ~ 1@): the parameters of a
-- one-layer stack are exactly the parameters of the wrapped 'LSTMLayer'.
instance
  Parameterized (LSTMLayer inputSize hiddenSize directionality dtype device) =>
  LSTMLayerStackParameterized 'False inputSize hiddenSize 1 directionality dtype device
  where
  type
    LSTMLayerStackParameters 'False inputSize hiddenSize 1 directionality dtype device =
      Parameters (LSTMLayer inputSize hiddenSize directionality dtype device)

  -- Delegate to the single layer; the proxy is only used for instance
  -- selection and is ignored at runtime.
  lstmLayerStackFlattenParameters _ (LSTMLayer1 lstmLayer) = flattenParameters lstmLayer

  -- Replace the parameters of the single layer and re-wrap it.
  lstmLayerStackReplaceParameters _ (LSTMLayer1 lstmLayer) parameters =
    LSTMLayer1 $ replaceParameters lstmLayer parameters

-- | Recursive case (@flag ~ 'True@): a stack built with 'LSTMLayerK'.
--
-- The parameter list is the parameter list of the nested stack (which has
-- @numLayers - 1@ layers) followed by the parameters of the layer added by
-- 'LSTMLayerK'. Splitting on replacement uses the same order.
instance
  ( Parameterized
      ( LSTMLayer
          (hiddenSize * NumberOfDirections directionality)
          hiddenSize
          directionality
          dtype
          device
      ),
    Parameterized (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device),
    HAppendFD
      (Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device))
      (Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device))
      ( Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
          ++ Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
      ),
    1 <= numLayers,
    numLayersM1 ~ numLayers - 1,
    0 <= numLayersM1
  ) =>
  LSTMLayerStackParameterized 'True inputSize hiddenSize numLayers directionality dtype device
  where
  type
    LSTMLayerStackParameters 'True inputSize hiddenSize numLayers directionality dtype device =
      Parameters (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
        ++ Parameters (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)

  -- Flatten the added layer and the nested stack separately, then append
  -- with the nested stack's parameters first. The type application pins the
  -- nested stack's layer count to @numLayersM1@ from the instance context.
  lstmLayerStackFlattenParameters _ (LSTMLayerK lstmLayer lstmLayerStack) =
    let parameters = flattenParameters lstmLayer
        parameters' = flattenParameters @(LSTMLayerStack inputSize hiddenSize numLayersM1 directionality dtype device) lstmLayerStack
     in parameters' `happendFD` parameters

  -- Split the incoming list back into (nested stack, added layer) parts,
  -- replace each part independently, and rebuild with 'LSTMLayerK'.
  lstmLayerStackReplaceParameters _ (LSTMLayerK lstmLayer lstmLayerStack) parameters'' =
    let (parameters', parameters) = hunappendFD parameters''
        lstmLayer' = replaceParameters lstmLayer parameters
        lstmLayerStack' = replaceParameters @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device) lstmLayerStack parameters'
     in LSTMLayerK lstmLayer' lstmLayerStack'

-- | Typed 'Parameterized' instance for 'LSTMLayerStack'.
--
-- The work is delegated to 'LSTMLayerStackParameterized'; the instance to
-- use is selected by @flag@, which the context fixes to @2 <=? numLayers@.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    LSTMLayerStackParameterized flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device) =
      LSTMLayerStackParameters (2 <=? numLayers) inputSize hiddenSize numLayers directionality dtype device

  -- The proxy only carries the type-level flag to the helper class.
  flattenParameters = lstmLayerStackFlattenParameters (Proxy :: Proxy flag)
  replaceParameters = lstmLayerStackReplaceParameters (Proxy :: Proxy flag)

-- | Helper class that implements random initialisation of an
-- 'LSTMLayerStack' from an 'LSTMLayerStackSpec'.
--
-- As with 'LSTMLayerStackParameterized', the type-level @flag@ only selects
-- an instance: @'False@ for @numLayers ~ 1@ and @'True@ for the recursive
-- case. The 'Randomizable' instance in this module sets it to
-- @2 <=? numLayers@.
class LSTMLayerStackRandomizable (flag :: Bool) inputSize hiddenSize numLayers directionality dtype device where
  -- | Samples a stack in 'IO'. The 'Proxy' argument carries no data; it
  -- only fixes @flag@.
  lstmLayerStackSample ::
    Proxy flag ->
    LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device ->
    IO (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)

-- | Base case (@flag ~ 'False@, @numLayers ~ 1@): sample one 'LSTMLayer' and
-- wrap it with 'LSTMLayer1'.
instance
  ( A.Randomizable
      (LSTMLayerSpec inputSize hiddenSize directionality dtype device)
      (LSTMLayer inputSize hiddenSize directionality dtype device)
  ) =>
  LSTMLayerStackRandomizable 'False inputSize hiddenSize 1 directionality dtype device
  where
  -- Both arguments are ignored: the layer spec is constructed here, with
  -- its type indices fixed through type applications.
  lstmLayerStackSample _ _ =
    LSTMLayer1 <$> (sample $ LSTMLayerSpec @inputSize @hiddenSize @directionality @dtype @device)

-- | Recursive case (@flag ~ 'True@): sample the layer added by 'LSTMLayerK'
-- and, after it, the nested stack of @numLayers - 1@ layers.
instance
  ( 1 <= numLayers,
    A.Randomizable
      (LSTMLayerSpec (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device)
      (LSTMLayer (hiddenSize * NumberOfDirections directionality) hiddenSize directionality dtype device),
    A.Randomizable
      (LSTMLayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
      (LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
  ) =>
  LSTMLayerStackRandomizable 'True inputSize hiddenSize numLayers directionality dtype device
  where
  -- Both arguments are ignored: the specs for the added layer and for the
  -- nested stack are constructed here. The explicit type applications on
  -- 'sample' fix the nested stack's layer count to @numLayers - 1@.
  lstmLayerStackSample _ _ =
    LSTMLayerK
      <$> (sample $ LSTMLayerSpec @(hiddenSize * NumberOfDirections directionality) @hiddenSize @directionality @dtype @device)
      <*> ( sample
              @(LSTMLayerStackSpec inputSize hiddenSize (numLayers - 1) directionality dtype device)
              @(LSTMLayerStack inputSize hiddenSize (numLayers - 1) directionality dtype device)
              $ LSTMLayerStackSpec
          )

-- | 'Randomizable' instance that samples an 'LSTMLayerStack' from an
-- 'LSTMLayerStackSpec'.
--
-- The work is delegated to 'LSTMLayerStackRandomizable'; the instance to use
-- is selected by @flag@, which the context fixes to @2 <=? numLayers@.
instance
  ( 1 <= numLayers,
    (2 <=? numLayers) ~ flag,
    RandDTypeIsValid device dtype,
    KnownDType dtype,
    KnownDevice device,
    LSTMLayerStackRandomizable flag inputSize hiddenSize numLayers directionality dtype device
  ) =>
  Randomizable
    (LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
    (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
  where
  -- The proxy only carries the type-level flag to the helper class.
  sample = lstmLayerStackSample (Proxy :: Proxy flag)

-- | Untyped ('A.Parameterized') instance for 'LSTMLayerStack'.
--
-- For 'LSTMLayerK' the parameters of the added layer (first field) come
-- first, followed by those of the nested stack (second field).
-- 'flattenParameters' and '_replaceParameters' use the same order, so they
-- stay consistent with each other.
--
-- NOTE(review): the typed 'LSTMLayerStackParameterized' @'True@ instance in
-- this module appends in the opposite order (nested stack first, then the
-- added layer). The order here is left unchanged so that existing callers
-- are not affected; confirm whether the difference is intentional.
instance A.Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device) where
  flattenParameters (LSTMLayer1 layer) =
    A.flattenParameters layer
  flattenParameters (LSTMLayerK layer rest) =
    A.flattenParameters layer
      ++ A.flattenParameters rest

  -- Parameters are consumed from the stream in the same order in which
  -- 'flattenParameters' emits them: added layer first, then nested stack.
  _replaceParameters (LSTMLayer1 layer) =
    LSTMLayer1 <$> A._replaceParameters layer
  _replaceParameters (LSTMLayerK layer rest) =
    LSTMLayerK
      <$> A._replaceParameters layer
      <*> A._replaceParameters rest

-- | Specification from which an 'LSTM' is sampled.
--
-- Input size, hidden size, number of layers, directionality, dtype and
-- device exist only at the type level; the single runtime payload is the
-- 'DropoutSpec' wrapped by the constructor.
newtype
  LSTMSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = LSTMSpec DropoutSpec
  deriving (Show, Generic)

-- | A multi-layer LSTM: a stack of LSTM layers bundled with a 'Dropout'
-- value.
--
-- The constructor carries a @1 <= numLayers@ context, so a value can only
-- be built when the layer count is at least one.
data
  LSTM
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  LSTM ::
    (1 <= numLayers) =>
    { -- The stacked layers, indexed by the same type-level sizes as the
      -- enclosing 'LSTM'.
      lstm_layer_stack :: LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device,
      -- Dropout stored alongside the stack.
      -- NOTE(review): presumably applied between layers in the forward
      -- pass -- confirm against the code that consumes this field.
      lstm_dropout :: Dropout
    } ->
    LSTM inputSize hiddenSize numLayers directionality dtype device

-- 'Show' is requested through standalone deriving: 'LSTM' is declared in
-- GADT syntax and its constructor carries a @1 <= numLayers@ context, so the
-- instance is stated separately from the data declaration.
deriving instance Show (LSTM inputSize hiddenSize numLayers directionality dtype device)

-- | Hand-written 'Generic' instance for 'LSTM'.
--
-- The representation is a plain product of the two record fields (layer
-- stack, then dropout). The instance context repeats @1 <= numLayers@
-- because 'to' has to call the 'LSTM' constructor, which demands it.
-- NOTE(review): presumably written by hand because the constructor context
-- rules out @deriving Generic@ -- confirm.
instance
  (1 <= numLayers) =>
  Generic (LSTM inputSize hiddenSize numLayers directionality dtype device)
  where
  type
    Rep (LSTM inputSize hiddenSize numLayers directionality dtype device) =
      Rec0 (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 Dropout

  -- Unpack both fields via RecordWildCards and wrap each in 'K1'.
  from (LSTM {..}) = K1 lstm_layer_stack :*: K1 lstm_dropout

  -- Inverse of 'from': rebuild the record from the two-field product.
  to (K1 layerStack :*: K1 dropout) = LSTM layerStack dropout

-- Typed 'Parameterized' instance for 'LSTM'. No methods are given, so the
-- class defaults are used. The 'HAppendFD' constraint ties the parameter
-- list of the layer stack and that of 'Dropout' to their concatenation
-- (@++@).
-- NOTE(review): the defaults presumably go through the 'Generic' instance of
-- 'LSTM' -- confirm in "Torch.Typed.Parameter".
instance
  ( 1 <= numLayers,
    Parameterized (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device))
      (Parameters Dropout)
      ( Parameters (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device)
          ++ Parameters Dropout
      )
  ) =>
  Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device)

-- TODO: when we have canonical initializers do this correctly:
-- https://github.com/pytorch/pytorch/issues/9221
-- https://discuss.pytorch.org/t/initializing-rnn-gru-and-lstm-correctly/23605

-- | Untyped ("Torch.NN") 'A.Parameterized' instance for 'LSTM'.
--
-- Only the layer stack contributes parameters; the dropout field is passed
-- through untouched.
instance A.Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device) where
  -- All parameters live in the layer stack, so delegate to it.
  flattenParameters LSTM {..} = A.flattenParameters lstm_layer_stack

  -- Replace the layer stack's parameters from the parameter stream and
  -- rebuild the record; the trailing @..@ re-uses the 'lstm_dropout' bound
  -- by the pattern match.
  _replaceParameters LSTM {..} = do
    lstm_layer_stack' <- A._replaceParameters lstm_layer_stack
    return $
      LSTM
        { lstm_layer_stack = lstm_layer_stack',
          ..
        }

-- | Helper to do xavier uniform initializations on weight matrices and
-- orthogonal initializations for the gates. (When implemented.)
--
-- Produces a tensor of shape @[4 * hiddenSize, featureSize]@. A 'randn'
-- tensor of that shape is drawn first and handed, in its untyped form and
-- together with the gain @5.0 / 3@ and its runtime shape, to
-- 'xavierUniformFIXME'; the untyped result is re-wrapped with
-- 'UnsafeMkTensor'.
xavierUniformLSTM ::
  forall device dtype hiddenSize featureSize.
  ( KnownDType dtype,
    KnownNat hiddenSize,
    KnownNat featureSize,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  IO (Tensor device dtype '[4 * hiddenSize, featureSize])
xavierUniformLSTM = do
  -- The annotation pins the shape, dtype and device of the random draw.
  init <- randn :: IO (Tensor device dtype '[4 * hiddenSize, featureSize])
  -- 'UnsafeMkTensor' trusts that the initializer preserves shape, dtype and
  -- device of its input; nothing re-checks this at runtime here.
  UnsafeMkTensor
    <$> xavierUniformFIXME
      (toDynamic init)
      (5.0 / 3)
      (shape @device @dtype @'[4 * hiddenSize, featureSize] init)

-- | Sample an 'LSTM' from an 'LSTMSpec'.
--
-- The layer stack is sampled from an 'LSTMLayerStackSpec' built with
-- explicit type applications (it carries no runtime data taken from the
-- 'LSTMSpec'); the dropout is sampled from the 'DropoutSpec' wrapped by the
-- 'LSTMSpec'. The @1 <= numLayers@ constraint is needed to call the 'LSTM'
-- constructor.
instance
  ( KnownDType dtype,
    KnownDevice device,
    KnownNat inputSize,
    KnownNat hiddenSize,
    KnownNat (NumberOfDirections directionality),
    RandDTypeIsValid device dtype,
    A.Randomizable
      (LSTMLayerStackSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTMLayerStack inputSize hiddenSize numLayers directionality dtype device),
    1 <= numLayers
  ) =>
  A.Randomizable
    (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
    (LSTM inputSize hiddenSize numLayers directionality dtype device)
  where
  -- Applicative composition: layer stack first, then dropout.
  sample (LSTMSpec dropoutSpec) =
    LSTM
      <$> A.sample (LSTMLayerStackSpec @inputSize @hiddenSize @numLayers @directionality @dtype @device)
      <*> A.sample dropoutSpec

-- | A specification for a long short-term memory (LSTM) layer together
-- with the scheme used for its initial memory-cell and hidden states.
--
-- The @initialization@ index records the scheme at the type level: the first
-- two constructors yield @'ConstantInitialization@, the third yields
-- @'LearnedInitialization@.
data
  LSTMWithInitSpec
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Weights drawn from Xavier-Uniform
  --   with zero-initialized biases and cell states.
  LSTMWithZerosInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | Weights drawn from Xavier-Uniform
  --   with zero-initialized biases
  --   and user-provided cell and hidden states.
  LSTMWithConstInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    -- | The initial values of the memory cell
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    -- | The initial values of the hidden state
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device
  -- | Weights drawn from Xavier-Uniform
  --   with zero-initialized biases
  --   and learned cell and hidden states.
  LSTMWithLearnedInitSpec ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    LSTMSpec inputSize hiddenSize numLayers directionality dtype device ->
    -- | The initial (learnable)
    -- values of the memory cell
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    -- | The initial (learnable)
    -- values of the hidden state
    Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize] ->
    LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device

-- 'Show' is requested through standalone deriving: the constructors of
-- 'LSTMWithInitSpec' fix the @initialization@ index (GADT-style return
-- types), so the instance is stated separately from the data declaration.
deriving instance Show (LSTMWithInitSpec inputSize hiddenSize numLayers directionality initialization dtype device)

-- | A long short-term memory (LSTM) layer with either fixed initial
-- states for the memory cells and hidden state or learnable
-- initial states for the memory cells and hidden state.
--
-- The @initialization@ index selects the constructor: with
-- @'ConstantInitialization@ the initial states are plain 'Tensor's, with
-- @'LearnedInitialization@ they are 'Parameter's. In both cases each state
-- has shape @[numLayers * NumberOfDirections directionality, hiddenSize]@.
data
  LSTMWithInit
    (inputSize :: Nat)
    (hiddenSize :: Nat)
    (numLayers :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  LSTMWithConstInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { -- The wrapped LSTM.
      lstmWithConstInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device,
      -- Fixed initial memory-cell state.
      lstmWithConstInit_c :: Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
      -- Fixed initial hidden state.
      lstmWithConstInit_h :: Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'ConstantInitialization
      dtype
      device
  LSTMWithLearnedInit ::
    forall inputSize hiddenSize numLayers directionality dtype device.
    { -- The wrapped LSTM.
      lstmWithLearnedInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device,
      -- Learnable initial memory-cell state.
      lstmWithLearnedInit_c :: Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize],
      -- Learnable initial hidden state.
      lstmWithLearnedInit_h :: Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize]
    } ->
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      'LearnedInitialization
      dtype
      device

-- 'Show' is requested through standalone deriving: each constructor of
-- 'LSTMWithInit' fixes the @initialization@ index (GADT-style return types),
-- so the instance is stated separately from the data declaration.
deriving instance Show (LSTMWithInit inputSize hiddenSize numLayers directionality initialization dtype device)

-- | Hand-written 'Generic' instance for the @'ConstantInitialization@ case
-- of 'LSTMWithInit'.
--
-- The representation is the product of the three record fields in
-- declaration order: the LSTM, the initial memory-cell tensor and the
-- initial hidden-state tensor. ':*:' associates to the right, so the
-- product nests as @lstm :*: (c :*: h)@.
instance Generic (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) where
  type
    Rep (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device) =
      Rec0 (LSTM inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
        :*: Rec0 (Tensor device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])

  -- Unpack all three fields via RecordWildCards and wrap each in 'K1'.
  from (LSTMWithConstInit {..}) =
    K1 lstmWithConstInit_lstm
      :*: K1 lstmWithConstInit_c
      :*: K1 lstmWithConstInit_h

  -- Inverse of 'from': rebuild the record from the three-field product.
  to (K1 lstm :*: K1 c :*: K1 h) = LSTMWithConstInit lstm c h

-- | Hand-written 'Generic' instance for the learned-initialization variant
-- of 'LSTMWithInit'.
--
-- The representation is a plain product of three fields, in this order:
--
--   1. the wrapped 'LSTM',
--   2. the learned initial cell state @c@ (a 'Parameter'),
--   3. the learned initial hidden state @h@ (a 'Parameter').
--
-- No datatype/constructor/selector metadata ('M1') is included; every field
-- is wrapped directly in 'K1' ('Rec0').
instance Generic (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) where
  type
    Rep (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device) =
      Rec0 (LSTM inputSize hiddenSize numLayers directionality dtype device)
        :*: Rec0 (Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])
        :*: Rec0 (Parameter device dtype '[numLayers * NumberOfDirections directionality, hiddenSize])

  -- Record wildcards bring the three field selectors into scope as values;
  -- they are emitted in the same order as the components of 'Rep'.
  from (LSTMWithLearnedInit {..}) =
    K1 lstmWithLearnedInit_lstm
      :*: K1 lstmWithLearnedInit_c
      :*: K1 lstmWithLearnedInit_h

  -- Inverse of 'from': unwrap the three 'K1's and rebuild the constructor.
  to (K1 lstm :*: K1 c :*: K1 h) = LSTMWithLearnedInit lstm c h

-- Typed 'Parameterized' instance for an LSTM with constant initial states.
--
-- The instance declares no methods, so the class's default implementations
-- are used (presumably the Generic-derived ones -- TODO confirm against the
-- class definition in Torch.Typed.Parameter).
--
-- The 'HAppendFD' constraint appends the empty list '[] to the parameters of
-- the wrapped 'LSTM'.
-- NOTE(review): this looks like it reflects the constant initial states being
-- plain tensors that contribute no parameters of their own -- confirm.
instance
  ( Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device))
      '[]
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device) ++ '[])
  ) =>
  Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)

-- Typed 'Parameterized' instance for an LSTM with learned initial states.
--
-- The instance declares no methods, so the class's default implementations
-- are used (presumably the Generic-derived ones -- TODO confirm against the
-- class definition in Torch.Typed.Parameter).
--
-- The 'HAppendFD' constraint appends two 'Parameter's of shape
-- '[numLayers * NumberOfDirections directionality, hiddenSize] to the
-- parameters of the wrapped 'LSTM'.
-- NOTE(review): these presumably correspond to the learned initial cell and
-- hidden states, in the field order of the Generic representation -- confirm.
instance
  ( Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    HAppendFD
      (Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device))
      '[ Parameter
           device
           dtype
           '[numLayers * NumberOfDirections directionality, hiddenSize],
         Parameter
           device
           dtype
           '[numLayers * NumberOfDirections directionality, hiddenSize]
       ]
      ( Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device)
          ++ '[ Parameter
                  device
                  dtype
                  '[numLayers * NumberOfDirections directionality, hiddenSize],
                Parameter
                  device
                  dtype
                  '[numLayers * NumberOfDirections directionality, hiddenSize]
              ]
      )
  ) =>
  Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)

-- | Sample an 'LSTMWithInit' whose initial cell and hidden states are fixed
-- (non-learned) tensors.
--
-- The wrapped 'LSTM' is drawn via its own 'A.Randomizable' instance; the
-- initial states are chosen according to the spec constructor:
--
--   * 'LSTMWithZerosInitSpec': both @c@ and @h@ are 'zeros' of shape
--     @[numLayers * NumberOfDirections directionality, hiddenSize]@.
--   * 'LSTMWithConstInitSpec': @c@ and @h@ are the tensors carried by the
--     spec, used as-is.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTM inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
    (LSTMWithInit inputSize hiddenSize numLayers directionality 'ConstantInitialization dtype device)
  where
  -- Zero-initialized cell and hidden states.
  sample (LSTMWithZerosInitSpec lstmSpec) =
    LSTMWithConstInit
      <$> A.sample lstmSpec
      <*> pure zeros
      <*> pure zeros
  -- Caller-supplied constant cell state @c@ and hidden state @h@.
  sample (LSTMWithConstInitSpec lstmSpec c h) =
    LSTMWithConstInit
      <$> A.sample lstmSpec
      <*> pure c
      <*> pure h

-- | Sample an 'LSTMWithInit' whose initial cell and hidden states are
-- learned parameters.
--
-- The wrapped 'LSTM' is drawn via its own 'A.Randomizable' instance. The
-- tensors @c@ and @h@ carried by 'LSTMWithLearnedInitSpec' are the starting
-- values of the initial cell and hidden states; each is turned into a
-- 'Parameter' with 'makeIndependent'.
instance
  ( KnownNat hiddenSize,
    KnownNat numLayers,
    KnownNat (NumberOfDirections directionality),
    KnownDType dtype,
    KnownDevice device,
    A.Randomizable
      (LSTMSpec inputSize hiddenSize numLayers directionality dtype device)
      (LSTM inputSize hiddenSize numLayers directionality dtype device)
  ) =>
  A.Randomizable
    (LSTMWithInitSpec inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
    (LSTMWithInit inputSize hiddenSize numLayers directionality 'LearnedInitialization dtype device)
  where
  -- Effects run left to right: sample the LSTM, then wrap @c@, then wrap @h@.
  -- @makeIndependent x@ is used directly; it is equivalent to the longer
  -- @makeIndependent =<< pure x@ by the monad left-identity law.
  sample (LSTMWithLearnedInitSpec lstmSpec c h) =
    LSTMWithLearnedInit
      <$> A.sample lstmSpec
      <*> makeIndependent c
      <*> makeIndependent h

instance A.Parameterized (LSTMWithInit inputSize hiddenSize numLayers directionality initialization dtype device) where
  flattenParameters :: LSTMWithInit
  inputSize
  hiddenSize
  numLayers
  directionality
  initialization
  dtype
  device
-> [Parameter]
flattenParameters LSTMWithConstInit {Tensor
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithConstInit_lstm :: forall (inputSize :: Natural) (hiddenSize :: Natural)
       (numLayers :: Natural) (directionality :: RNNDirectionality)
       (dtype :: DType) (device :: (DeviceType, Natural)).
LSTMWithInit
  inputSize
  hiddenSize
  numLayers
  directionality
  'ConstantInitialization
  dtype
  device
-> LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithConstInit_c :: forall (inputSize :: Natural) (hiddenSize :: Natural)
       (numLayers :: Natural) (directionality :: RNNDirectionality)
       (dtype :: DType) (device :: (DeviceType, Natural)).
LSTMWithInit
  inputSize
  hiddenSize
  numLayers
  directionality
  'ConstantInitialization
  dtype
  device
-> Tensor
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithConstInit_h :: forall (inputSize :: Natural) (hiddenSize :: Natural)
       (numLayers :: Natural) (directionality :: RNNDirectionality)
       (dtype :: DType) (device :: (DeviceType, Natural)).
LSTMWithInit
  inputSize
  hiddenSize
  numLayers
  directionality
  'ConstantInitialization
  dtype
  device
-> Tensor
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithConstInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithConstInit_c :: Tensor
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithConstInit_h :: Tensor
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
..} =
    LSTM inputSize hiddenSize numLayers directionality dtype device
-> [Parameter]
forall f. Parameterized f => f -> [Parameter]
A.flattenParameters LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithConstInit_lstm
  flattenParameters LSTMWithLearnedInit {Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithLearnedInit_lstm :: forall (inputSize :: Natural) (hiddenSize :: Natural)
       (numLayers :: Natural) (directionality :: RNNDirectionality)
       (dtype :: DType) (device :: (DeviceType, Natural)).
LSTMWithInit
  inputSize
  hiddenSize
  numLayers
  directionality
  'LearnedInitialization
  dtype
  device
-> LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithLearnedInit_c :: forall (inputSize :: Natural) (hiddenSize :: Natural)
       (numLayers :: Natural) (directionality :: RNNDirectionality)
       (dtype :: DType) (device :: (DeviceType, Natural)).
LSTMWithInit
  inputSize
  hiddenSize
  numLayers
  directionality
  'LearnedInitialization
  dtype
  device
-> Parameter
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithLearnedInit_h :: forall (inputSize :: Natural) (hiddenSize :: Natural)
       (numLayers :: Natural) (directionality :: RNNDirectionality)
       (dtype :: DType) (device :: (DeviceType, Natural)).
LSTMWithInit
  inputSize
  hiddenSize
  numLayers
  directionality
  'LearnedInitialization
  dtype
  device
-> Parameter
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithLearnedInit_lstm :: LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithLearnedInit_c :: Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithLearnedInit_h :: Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
..} =
    LSTM inputSize hiddenSize numLayers directionality dtype device
-> [Parameter]
forall f. Parameterized f => f -> [Parameter]
A.flattenParameters LSTM inputSize hiddenSize numLayers directionality dtype device
lstmWithLearnedInit_lstm
      [Parameter] -> [Parameter] -> [Parameter]
forall a. [a] -> [a] -> [a]
++ (Parameter
   device
   dtype
   '[numLayers * NumberOfDirections directionality, hiddenSize]
 -> Parameter)
-> [Parameter
      device
      dtype
      '[numLayers * NumberOfDirections directionality, hiddenSize]]
-> [Parameter]
forall a b. (a -> b) -> [a] -> [b]
forall (f :: Type -> Type) a b. Functor f => (a -> b) -> f a -> f b
fmap Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
-> Parameter
forall (device :: (DeviceType, Natural)) (dtype :: DType)
       (shape :: [Natural]).
Parameter device dtype shape -> Parameter
untypeParam [Item
  [Parameter
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]]
Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithLearnedInit_c, Item
  [Parameter
     device
     dtype
     '[numLayers * NumberOfDirections directionality, hiddenSize]]
Parameter
  device
  dtype
  '[numLayers * NumberOfDirections directionality, hiddenSize]
lstmWithLearnedInit_h]
  -- Rebuild an 'LSTMWithInit' from the parameter stream.
  --
  -- Constant initialization: only the wrapped LSTM draws parameters from the
  -- stream. The remaining record fields (the constant initial-state tensors
  -- bound by the wildcard pattern) are copied over unchanged by the trailing
  -- record wildcard.
  _replaceParameters LSTMWithConstInit {..} = do
    lstmWithConstInit_lstm' <- A._replaceParameters lstmWithConstInit_lstm
    return $
      LSTMWithConstInit
        { lstmWithConstInit_lstm = lstmWithConstInit_lstm',
          ..
        }
  -- Learned initialization: the wrapped LSTM draws its parameters first, then
  -- two further parameters are taken from the stream, for the @_c@ field and
  -- then the @_h@ field. This is the same order in which the
  -- 'flattenParameters' clause for this constructor appends them.
  _replaceParameters LSTMWithLearnedInit {..} = do
    lstmWithLearnedInit_lstm' <- A._replaceParameters lstmWithLearnedInit_lstm
    lstmWithLearnedInit_c' <- A.nextParameter
    lstmWithLearnedInit_h' <- A.nextParameter
    return $
      LSTMWithLearnedInit
        { lstmWithLearnedInit_lstm = lstmWithLearnedInit_lstm',
          -- 'UnsafeMkParameter' re-attaches the type-level device, dtype and
          -- shape to the untyped parameter; nothing here checks that the
          -- incoming parameter actually matches them.
          lstmWithLearnedInit_c = UnsafeMkParameter lstmWithLearnedInit_c',
          lstmWithLearnedInit_h = UnsafeMkParameter lstmWithLearnedInit_h'
        }

-- | Run the wrapped 'LSTM' on an input tensor, starting from the initial
-- states stored in the given 'LSTMWithInit'.
--
-- The first argument is passed through to @lstm@ as its dropout switch,
-- together with the dropout probability read from the model's 'Dropout'
-- field. The second argument is the model; both constructors are handled,
-- and for the learned variant the initial-state parameters are first turned
-- into tensors with 'toDependent'. The third argument is the input tensor.
--
-- The stored initial states have no batch dimension. Before the call they
-- are expanded by a leading @batchSize@ dimension and reshaped to @hxShape@.
--
-- The result is whatever triple @lstm@ returns: one tensor of @outputShape@
-- and two tensors of @hxShape@.
lstmForward ::
  forall
    shapeOrder
    batchSize
    seqLen
    directionality
    initialization
    numLayers
    inputSize
    outputSize
    hiddenSize
    inputShape
    outputShape
    hxShape
    parameters
    tensorParameters
    dtype
    device.
  ( KnownNat (NumberOfDirections directionality),
    KnownNat numLayers,
    KnownNat batchSize,
    KnownNat hiddenSize,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
    parameters ~ Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device),
    Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
    tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor],
    HMap' ToDependent parameters tensorParameters
  ) =>
  Bool ->
  LSTMWithInit
    inputSize
    hiddenSize
    numLayers
    directionality
    initialization
    dtype
    device ->
  Tensor device dtype inputShape ->
  ( Tensor device dtype outputShape,
    Tensor device dtype hxShape,
    Tensor device dtype hxShape
  )
lstmForward dropoutOn (LSTMWithConstInit lstmModel@(LSTM _ (Dropout dropoutProb)) cc hc) input =
  lstm
    @shapeOrder
    @directionality
    @numLayers
    @seqLen
    @batchSize
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @tensorParameters
    @dtype
    @device
    -- Flatten the model into an HList of parameters and turn each one into a
    -- plain tensor.
    (hmap' ToDependent . flattenParameters $ lstmModel)
    dropoutProb
    dropoutOn
    (cc', hc')
    input
  where
    -- Batched versions of the stored constant initial states.
    --
    -- NOTE(review): 'expand' prepends the batch dimension, giving
    -- [batchSize, layers * directions, hiddenSize], and 'reshape' then
    -- reinterprets that buffer as [layers * directions, batchSize, hiddenSize].
    -- A reshape does not swap axes, so for batchSize > 1 with more than one
    -- layer/direction the rows of a non-uniform initial state appear to end
    -- up interleaved across batch entries. Harmless for all-equal (e.g. zero)
    -- states. A @transpose \@0 \@1@ (made contiguous) may be what is
    -- intended -- TODO confirm before changing.
    cc' :: Tensor device dtype hxShape
    cc' =
      reshape @hxShape
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: What does the bool do? (presumably ATen's `implicit` flag -- confirm)
        $ cc
    hc' :: Tensor device dtype hxShape
    hc' =
      reshape @hxShape
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: What does the bool do? (presumably ATen's `implicit` flag -- confirm)
        $ hc
lstmForward dropoutOn (LSTMWithLearnedInit lstmModel@(LSTM _ (Dropout dropoutProb)) cc hc) input =
  lstm
    @shapeOrder
    @directionality
    @numLayers
    @seqLen
    @batchSize
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @tensorParameters
    @dtype
    @device
    -- Same parameter preparation as in the constant-initialization clause.
    (hmap' ToDependent . flattenParameters $ lstmModel)
    dropoutProb
    dropoutOn
    (cc', hc')
    input
  where
    -- Batched versions of the learned initial states: each parameter is
    -- first converted to a tensor with 'toDependent', then expanded and
    -- reshaped exactly as in the constant-initialization clause.
    --
    -- NOTE(review): the same expand-then-reshape concern applies here, and
    -- matters more because learned states are generally not uniform across
    -- layers -- TODO confirm whether a transpose is intended.
    cc' :: Tensor device dtype hxShape
    cc' =
      reshape @hxShape
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: What does the bool do? (presumably ATen's `implicit` flag -- confirm)
        . toDependent
        $ cc
    hc' :: Tensor device dtype hxShape
    hc' =
      reshape @hxShape
        . expand
          @'[batchSize, numLayers * NumberOfDirections directionality, hiddenSize]
          False -- TODO: What does the bool do? (presumably ATen's `implicit` flag -- confirm)
        . toDependent
        $ hc

lstmForwardWithDropout,
  lstmForwardWithoutDropout ::
    forall
      shapeOrder
      batchSize
      seqLen
      directionality
      initialization
      numLayers
      inputSize
      outputSize
      hiddenSize
      inputShape
      outputShape
      hxShape
      parameters
      tensorParameters
      dtype
      device.
    ( KnownNat (NumberOfDirections directionality),
      KnownNat numLayers,
      KnownNat batchSize,
      KnownNat hiddenSize,
      KnownRNNShapeOrder shapeOrder,
      KnownRNNDirectionality directionality,
      outputSize ~ (hiddenSize * NumberOfDirections directionality),
      inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
      outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
      hxShape ~ '[numLayers * NumberOfDirections directionality, batchSize, hiddenSize],
      parameters ~ Parameters (LSTM inputSize hiddenSize numLayers directionality dtype device),
      Parameterized (LSTM inputSize hiddenSize numLayers directionality dtype device),
      tensorParameters ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
      ATen.Castable (HList tensorParameters) [D.ATenTensor],
      HMap' ToDependent parameters tensorParameters
    ) =>
    LSTMWithInit
      inputSize
      hiddenSize
      numLayers
      directionality
      initialization
      dtype
      device ->
    Tensor device dtype inputShape ->
    ( Tensor device dtype outputShape,
      Tensor device dtype hxShape,
      Tensor device dtype hxShape
    )
-- ^ Forward propagate the `LSTM` module and apply dropout on the outputs of each layer.
--
-- >>> input :: CPUTensor 'D.Float '[5,16,10] <- randn
-- >>> spec = LSTMWithZerosInitSpec @10 @30 @3 @'Bidirectional @'D.Float @'(D.CPU, 0) (LSTMSpec (DropoutSpec 0.5))
-- >>> model <- A.sample spec
-- >>> :t lstmForwardWithDropout @'BatchFirst model input
-- lstmForwardWithDropout @'BatchFirst model input
--   :: (Tensor '(D.CPU, 0) 'D.Float [5, 16, 60],
--       Tensor '(D.CPU, 0) 'D.Float [6, 5, 30],
--       Tensor '(D.CPU, 0) 'D.Float [6, 5, 30])
-- >>> (a,b,c) = lstmForwardWithDropout @'BatchFirst model input
-- >>> ((dtype a, shape a), (dtype b, shape b), (dtype c, shape c))
-- ((Float,[5,16,60]),(Float,[6,5,30]),(Float,[6,5,30]))
lstmForwardWithDropout =
  -- Partial application of 'lstmForward' with its leading Bool (the dropout
  -- switch) fixed to True. Every type variable of the shared signature is
  -- forwarded explicitly, in the order of 'lstmForward's forall.
  lstmForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @parameters
    @tensorParameters
    @dtype
    @device
    True
-- | Forward propagate the `LSTM` module (without applying dropout on the outputs of each layer).
--
-- This is 'lstmForward' with its leading 'Bool' flag fixed to 'False'; all
-- type arguments are forwarded unchanged.
--
-- >>> input :: CPUTensor 'D.Float '[5,16,10] <- randn
-- >>> spec = LSTMWithZerosInitSpec @10 @30 @3 @'Bidirectional @'D.Float @'(D.CPU, 0) (LSTMSpec (DropoutSpec 0.5))
-- >>> model <- A.sample spec
-- >>> :t lstmForwardWithoutDropout @'BatchFirst model input
-- lstmForwardWithoutDropout @'BatchFirst model input
--   :: (Tensor '(D.CPU, 0) 'D.Float [5, 16, 60],
--       Tensor '(D.CPU, 0) 'D.Float [6, 5, 30],
--       Tensor '(D.CPU, 0) 'D.Float [6, 5, 30])
-- >>> (a,b,c) = lstmForwardWithoutDropout @'BatchFirst model input
-- >>> ((dtype a, shape a), (dtype b, shape b), (dtype c, shape c))
-- ((Float,[5,16,60]),(Float,[6,5,30]),(Float,[6,5,30]))
lstmForwardWithoutDropout ::
  forall
    (shapeOrder :: RNNShapeOrder)
    (batchSize :: Nat)
    (seqLen :: Nat)
    (directionality :: RNNDirectionality)
    (initialization :: RNNInitialization)
    (numLayers :: Nat)
    (inputSize :: Nat)
    (outputSize :: Nat)
    (hiddenSize :: Nat)
    (inputShape :: [Nat])
    (outputShape :: [Nat])
    (hxShape :: [Nat])
    (parameters :: [Type])
    (tensorParameters :: [Type])
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat)).
  ( KnownNat (NumberOfDirections directionality),
    KnownNat numLayers,
    KnownNat batchSize,
    KnownNat hiddenSize,
    KnownRNNShapeOrder shapeOrder,
    KnownRNNDirectionality directionality,
    outputSize ~ (hiddenSize * NumberOfDirections directionality),
    inputShape ~ RNNShape shapeOrder seqLen batchSize inputSize,
    outputShape ~ RNNShape shapeOrder seqLen batchSize outputSize,
    hxShape
      ~ '[ numLayers * NumberOfDirections directionality,
           batchSize,
           hiddenSize
         ],
    parameters
      ~ Parameters
          (LSTM inputSize hiddenSize numLayers directionality dtype device),
    Parameterized
      (LSTM inputSize hiddenSize numLayers directionality dtype device),
    tensorParameters
      ~ LSTMR inputSize hiddenSize numLayers directionality dtype device,
    ATen.Castable (HList tensorParameters) [D.ATenTensor],
    HMap' ToDependent parameters tensorParameters
  ) =>
  -- | The LSTM model to run.
  LSTMWithInit
    inputSize
    hiddenSize
    numLayers
    directionality
    initialization
    dtype
    device ->
  -- | Input tensor, laid out according to @shapeOrder@.
  Tensor device dtype inputShape ->
  -- | The output sequence followed by two tensors of shape @hxShape@
  -- (presumably the final hidden and cell states; confirm against 'lstmForward').
  ( Tensor device dtype outputShape,
    Tensor device dtype hxShape,
    Tensor device dtype hxShape
  )
lstmForwardWithoutDropout =
  lstmForward
    @shapeOrder
    @batchSize
    @seqLen
    @directionality
    @initialization
    @numLayers
    @inputSize
    @outputSize
    @hiddenSize
    @inputShape
    @outputShape
    @hxShape
    @parameters
    @tensorParameters
    @dtype
    @device
    False