{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

module Torch.Typed.NN.Recurrent.Cell.LSTM where

import Data.List
  ( foldl',
    scanl',
  )
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import qualified Torch.NN as A
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (linear)
import Torch.Typed.NN.Dropout
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | A specification for a long, short-term memory (LSTM) cell.
--
-- The specification carries no runtime data: its single constructor is
-- nullary, and everything that determines the cell is encoded in the type
-- parameters.
--
-- * @inputDim@: type-level size of the input feature dimension
-- * @hiddenDim@: type-level size of the hidden and cell state dimension
-- * @dtype@: element type of the cell's parameters
-- * @device@: device type paired with a type-level natural
--   (presumably the device index; confirm against "Torch.Device")
data
  LSTMCellSpec
    (inputDim :: Nat)
    (hiddenDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  = -- | Weights and biases are drawn from the standard normal distribution (having mean 0 and variance 1)
    LSTMCellSpec
  deriving (Show, Eq, Ord, Generic, Enum, Bounded)

-- | A long, short-term memory cell.
--
-- Holds the four trainable parameters of the cell. Every parameter has a
-- leading dimension of @4 * hiddenDim@ (presumably the four LSTM gates
-- stacked along the first axis, as in PyTorch's @LSTMCell@; confirm against
-- the @lstmCell@ primitive this record is passed to).
--
-- The 'Parameterized' instance is derived, so the fields can be flattened
-- to and restored from a heterogeneous list of parameters.
data
  LSTMCell
    (inputDim :: Nat)
    (hiddenDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat)) = LSTMCell
  { -- | input-to-hidden weights, of shape @[4 * hiddenDim, inputDim]@
    lstmCell_w_ih :: Parameter device dtype '[4 * hiddenDim, inputDim],
    -- | hidden-to-hidden weights, of shape @[4 * hiddenDim, hiddenDim]@
    lstmCell_w_hh :: Parameter device dtype '[4 * hiddenDim, hiddenDim],
    -- | input-to-hidden bias, of shape @[4 * hiddenDim]@
    lstmCell_b_ih :: Parameter device dtype '[4 * hiddenDim],
    -- | hidden-to-hidden bias, of shape @[4 * hiddenDim]@
    lstmCell_b_hh :: Parameter device dtype '[4 * hiddenDim]
  }
  deriving (Show, Generic, Parameterized)

-- | Randomly initialise an 'LSTMCell' from an 'LSTMCellSpec'.
--
-- Each of the four parameters (input-to-hidden weights, hidden-to-hidden
-- weights, input-to-hidden bias, hidden-to-hidden bias, in that order) is
-- sampled with 'randn' and then turned into a trainable parameter with
-- 'makeIndependent'. The tensor shapes are fixed by the field types of
-- 'LSTMCell', so 'randn' needs no explicit shape argument.
instance
  ( KnownDevice device,
    KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (LSTMCellSpec inputDim hiddenDim dtype device)
    (LSTMCell inputDim hiddenDim dtype device)
  where
  sample LSTMCellSpec =
    LSTMCell
      <$> (makeIndependent =<< randn) -- lstmCell_w_ih
      <*> (makeIndependent =<< randn) -- lstmCell_w_hh
      <*> (makeIndependent =<< randn) -- lstmCell_b_ih
      <*> (makeIndependent =<< randn) -- lstmCell_b_hh

-- | A single recurrent step of an `LSTMCell`
lstmCellForward ::
  forall inputDim hiddenDim batchSize dtype device.
  ( KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    KnownNat batchSize
  ) =>
  -- | The cell
  LSTMCell inputDim hiddenDim dtype device ->
  -- | The current (Hidden, Cell) state
  ( Tensor device dtype '[batchSize, hiddenDim],
    Tensor device dtype '[batchSize, hiddenDim]
  ) ->
  -- | The input
  Tensor device dtype '[batchSize, inputDim] ->
  -- | The subsequent (Hidden, Cell) state
  ( Tensor device dtype '[batchSize, hiddenDim],
    Tensor device dtype '[batchSize, hiddenDim]
  )
lstmCellForward LSTMCell {..} =
  -- The record wildcard brings the cell's parameter fields into scope. Each
  -- one is converted from a `Parameter` to a `Tensor` of the same shape with
  -- `toDependent` and passed to `lstmCell` in the order it expects:
  -- `w_ih`, `w_hh`, `b_ih`, `b_hh`. The (Hidden, Cell) state and the input
  -- are the remaining two arguments of `lstmCell`, supplied by the caller
  -- through partial application.
  lstmCell
    (toDependent lstmCell_w_ih)
    (toDependent lstmCell_w_hh)
    (toDependent lstmCell_b_ih)
    (toDependent lstmCell_b_hh)

-- | @foldl'@ for lists of tensors using an `LSTMCell`.
--
-- Starting from the initial (Hidden, Cell) state, each input is passed
-- through `lstmCellForward` in list order and only the state reached after
-- the last input is returned. For an empty input list the initial state is
-- returned unchanged.
lstmCellFold ::
  forall inputDim hiddenDim batchSize dtype device.
  ( KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    KnownNat batchSize
  ) =>
  LSTMCell inputDim hiddenDim dtype device ->
  -- | The initial (Hidden, Cell) state
  ( Tensor device dtype '[batchSize, hiddenDim],
    Tensor device dtype '[batchSize, hiddenDim]
  ) ->
  -- | The list of inputs
  [Tensor device dtype '[batchSize, inputDim]] ->
  -- | The final (Hidden, Cell) state
  ( Tensor device dtype '[batchSize, hiddenDim],
    Tensor device dtype '[batchSize, hiddenDim]
  )
lstmCellFold cell = foldl' (lstmCellForward cell)

-- | @scanl'@ for lists of tensors using an `LSTMCell`.
--
-- Works like `lstmCellFold`, but keeps every intermediate state instead of
-- only the last one. As with @scanl'@ in general, the first element of the
-- result is the initial state itself, so the result has one more element
-- than the list of inputs.
lstmCellScan ::
  forall inputDim hiddenDim batchSize dtype device.
  ( KnownDType dtype,
    KnownNat inputDim,
    KnownNat hiddenDim,
    KnownNat batchSize
  ) =>
  LSTMCell inputDim hiddenDim dtype device ->
  -- | The initial (Hidden, Cell) state
  ( Tensor device dtype '[batchSize, hiddenDim],
    Tensor device dtype '[batchSize, hiddenDim]
  ) ->
  -- | The list of inputs
  [Tensor device dtype '[batchSize, inputDim]] ->
  -- | The initial state followed by the (Hidden, Cell) state after each input
  [ ( Tensor device dtype '[batchSize, hiddenDim],
      Tensor device dtype '[batchSize, hiddenDim]
    )
  ]
lstmCellScan cell = scanl' (lstmCellForward cell)