{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-}

module Torch.Typed.NN.Transformer where

import Control.Monad
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..))
import qualified Torch.NN as A
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (linear, log)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Linear
import Torch.Typed.NN.Normalization
import Torch.Typed.NN.Sparse
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (cos, exp, sin)

-- | Residual (skip) connection.
--
-- Runs the effectful transformation @f@ on the input @x@, adds the original
-- input back onto @f@'s output with 'add' (which broadcasts shapes and
-- promotes dtypes), and hands that sum to the continuation @g@.
--
-- In other words: @residual f g x = f x >>= \\x' -> g (x + x')@.
residual ::
  ( Monad m,
    BasicArithmeticDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype',
    BasicArithmeticDTypeIsValid device (DTypePromotion dtype dtype')
  ) =>
  -- | sub-layer whose output is added back onto its input
  (Tensor device dtype shape -> m (Tensor device dtype' shape')) ->
  -- | continuation applied to @input + f input@
  (Tensor device (DTypePromotion dtype dtype') (Broadcast shape shape') -> m b) ->
  -- | input tensor
  Tensor device dtype shape ->
  m b
residual f g x = f x >>= g . add x

--------------------------------------------------------------------------------
-- Relation-Aware Multi-Headed Attention Layer
--------------------------------------------------------------------------------

-- | Specification for a 'MultiheadAttention' layer.
--
-- Only the dropout spec is carried at the value level; the embedding widths,
-- number of heads, dtype and device are all fixed by the type parameters.
data
  MultiheadAttentionSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  MultiheadAttentionSpec ::
    -- | spec for dropout
    DropoutSpec ->
    MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Eq)

-- | Multi-head attention layer parameters.
--
-- Holds linear in-projections that map the query (@embedDim@), key
-- (@kEmbedDim@) and value (@vEmbedDim@) widths to @embedDim@, a linear
-- out-projection @embedDim -> embedDim@, and a dropout module. @numHeads@
-- does not occur in any field; it is a phantom type parameter here.
data
  MultiheadAttention
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  MultiheadAttention ::
    { -- | in-projection for query
      mhaQInProj :: Linear embedDim embedDim dtype device,
      -- | in-projection for key
      mhaKInProj :: Linear kEmbedDim embedDim dtype device,
      -- | in-projection for value
      mhaVInProj :: Linear vEmbedDim embedDim dtype device,
      -- | out-projection
      mhaOutProj :: Linear embedDim embedDim dtype device,
      -- | dropout
      mhaDropout :: Dropout
    } ->
    MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Generic, Parameterized)

multiheadAttention ::
  forall embedDim kEmbedDim vEmbedDim numHeads seqLen seqLen' batchSize headDim dtype device.
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    dtype ~ SumDType dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  -- | multi-head attention model ADT
  MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device ->
  -- | switch between training mode and evaluation mode (turns random dropout on and off)
  Bool ->
  -- | optional attention mask
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
  -- | optional key padding mask
  Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
  -- | optional key relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | optional value relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | query representation
  Tensor device dtype '[batchSize, seqLen', embedDim] ->
  -- | key representation
  Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
  -- | value representation
  Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
  -- | attention and attention averaged over heads
  IO
    ( Tensor device dtype '[batchSize, seqLen', embedDim],
      Tensor device dtype '[batchSize, seqLen', seqLen]
    )
multiheadAttention :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat) (seqLen :: Nat) (seqLen' :: Nat)
       (batchSize :: Nat) (headDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
 All
   KnownNat
   '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
     batchSize, headDim],
 KnownDType dtype,
 StandardFloatingPointDTypeValidation device dtype,
 MatMulDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype,
 SumDTypeIsValid device dtype, KnownDevice device) =>
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
     (Tensor device dtype '[batchSize, seqLen', embedDim],
      Tensor device dtype '[batchSize, seqLen', seqLen])
multiheadAttention MultiheadAttention {Linear embedDim embedDim dtype device
Linear kEmbedDim embedDim dtype device
Linear vEmbedDim embedDim dtype device
Dropout
mhaQInProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
mhaKInProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear kEmbedDim embedDim dtype device
mhaVInProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear vEmbedDim embedDim dtype device
mhaOutProj :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
mhaDropout :: forall (embedDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat).
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Dropout
mhaQInProj :: Linear embedDim embedDim dtype device
mhaKInProj :: Linear kEmbedDim embedDim dtype device
mhaVInProj :: Linear vEmbedDim embedDim dtype device
mhaOutProj :: Linear embedDim embedDim dtype device
mhaDropout :: Dropout
..} Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value = do
  Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights <-
    Dropout
-> Bool
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
     (Tensor
        device
        dtype
        (CheckMatMul
           '[batchSize, numHeads, seqLen', headDim]
           '[batchSize, numHeads, headDim, seqLen]
           (ComputeMatMul
              '[headDim, seqLen', numHeads, batchSize]
              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
mhaDropout Bool
train
      (Tensor
   device
   dtype
   (CheckMatMul
      '[batchSize, numHeads, seqLen', headDim]
      '[batchSize, numHeads, headDim, seqLen]
      (ComputeMatMul
         '[headDim, seqLen', numHeads, batchSize]
         (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
 -> IO
      (Tensor
         device
         dtype
         (CheckMatMul
            '[batchSize, numHeads, seqLen', headDim]
            '[batchSize, numHeads, headDim, seqLen]
            (ComputeMatMul
               '[headDim, seqLen', numHeads, batchSize]
               (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))))
-> (Tensor
      device
      dtype
      (CheckMatMul
         '[batchSize, numHeads, seqLen', headDim]
         '[batchSize, numHeads, headDim, seqLen]
         (ComputeMatMul
            '[headDim, seqLen', numHeads, batchSize]
            (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
    -> Tensor
         device
         dtype
         (CheckMatMul
            '[batchSize, numHeads, seqLen', headDim]
            '[batchSize, numHeads, headDim, seqLen]
            (ComputeMatMul
               '[headDim, seqLen', numHeads, batchSize]
               (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
     (Tensor
        device
        dtype
        (CheckMatMul
           '[batchSize, numHeads, seqLen', headDim]
           '[batchSize, numHeads, headDim, seqLen]
           (ComputeMatMul
              '[headDim, seqLen', numHeads, batchSize]
              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownNat dim, DimOutOfBoundCheck shape dim, KnownDType dtype,
 StandardFloatingPointDTypeValidation device dtype) =>
Tensor device dtype shape -> Tensor device dtype shape
softmax @3
      (Tensor
   device
   dtype
   (CheckMatMul
      '[batchSize, numHeads, seqLen', headDim]
      '[batchSize, numHeads, headDim, seqLen]
      (ComputeMatMul
         '[headDim, seqLen', numHeads, batchSize]
         (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
 -> Tensor
      device
      dtype
      (CheckMatMul
         '[batchSize, numHeads, seqLen', headDim]
         '[batchSize, numHeads, headDim, seqLen]
         (ComputeMatMul
            '[headDim, seqLen', numHeads, batchSize]
            (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> (Tensor
      device
      dtype
      (CheckMatMul
         '[batchSize, numHeads, seqLen', headDim]
         '[batchSize, numHeads, headDim, seqLen]
         (ComputeMatMul
            '[headDim, seqLen', numHeads, batchSize]
            (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
    -> Tensor
         device
         dtype
         (CheckMatMul
            '[batchSize, numHeads, seqLen', headDim]
            '[batchSize, numHeads, headDim, seqLen]
            (ComputeMatMul
               '[headDim, seqLen', numHeads, batchSize]
               (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskKeyPaddings
      (Tensor
   device
   dtype
   (CheckMatMul
      '[batchSize, numHeads, seqLen', headDim]
      '[batchSize, numHeads, headDim, seqLen]
      (ComputeMatMul
         '[headDim, seqLen', numHeads, batchSize]
         (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
 -> Tensor
      device
      dtype
      (CheckMatMul
         '[batchSize, numHeads, seqLen', headDim]
         '[batchSize, numHeads, headDim, seqLen]
         (ComputeMatMul
            '[headDim, seqLen', numHeads, batchSize]
            (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> (Tensor
      device
      dtype
      (CheckMatMul
         '[batchSize, numHeads, seqLen', headDim]
         '[batchSize, numHeads, headDim, seqLen]
         (ComputeMatMul
            '[headDim, seqLen', numHeads, batchSize]
            (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
    -> Tensor
         device
         dtype
         (CheckMatMul
            '[batchSize, numHeads, seqLen', headDim]
            '[batchSize, numHeads, headDim, seqLen]
            (ComputeMatMul
               '[headDim, seqLen', numHeads, batchSize]
               (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskAttention
      (Tensor
   device
   dtype
   (CheckMatMul
      '[batchSize, numHeads, seqLen', headDim]
      '[batchSize, numHeads, headDim, seqLen]
      (ComputeMatMul
         '[headDim, seqLen', numHeads, batchSize]
         (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
 -> IO
      (Tensor
         device
         dtype
         (CheckMatMul
            '[batchSize, numHeads, seqLen', headDim]
            '[batchSize, numHeads, headDim, seqLen]
            (ComputeMatMul
               '[headDim, seqLen', numHeads, batchSize]
               (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
     (Tensor
        device
        dtype
        (CheckMatMul
           '[batchSize, numHeads, seqLen', headDim]
           '[batchSize, numHeads, headDim, seqLen]
           (ComputeMatMul
              '[headDim, seqLen', numHeads, batchSize]
              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
forall a b. (a -> b) -> a -> b
$ Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_attentionWeights
  (Tensor device dtype '[batchSize, seqLen', embedDim],
 Tensor device dtype '[batchSize, seqLen', seqLen])
-> IO
     (Tensor device dtype '[batchSize, seqLen', embedDim],
      Tensor device dtype '[batchSize, seqLen', seqLen])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights, Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
averageOverHeads Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights)
  where
    _attentionWeights :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_attentionWeights =
      let scaling :: Double
scaling = Double -> Double
forall a. Floating a => a -> a
Prelude.sqrt (Double -> Double) -> (Int -> Double) -> Int -> Double
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> Double
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Double) -> Int -> Double
forall a b. (a -> b) -> a -> b
$ forall (n :: Nat). KnownNat n => Int
natValI @headDim :: Double
          q :: Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q = Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' (Tensor device dtype '[batchSize, seqLen', embedDim]
 -> Tensor device dtype '[batchSize, numHeads, seqLen', headDim])
-> (Tensor device dtype '[batchSize, seqLen', embedDim]
    -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Double
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall a (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
divScalar Double
scaling (Tensor device dtype '[batchSize, seqLen', embedDim]
 -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> (Tensor device dtype '[batchSize, seqLen', embedDim]
    -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Linear embedDim embedDim dtype device
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim embedDim dtype device
mhaQInProj (Tensor device dtype '[batchSize, seqLen', embedDim]
 -> Tensor device dtype '[batchSize, numHeads, seqLen', headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen', embedDim]
query
          k :: Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
k = Tensor device dtype '[batchSize, seqLen, embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' (Tensor device dtype '[batchSize, seqLen, embedDim]
 -> Tensor device dtype '[batchSize, numHeads, seqLen, headDim])
-> (Tensor device dtype '[batchSize, seqLen, kEmbedDim]
    -> Tensor device dtype '[batchSize, seqLen, embedDim])
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Linear kEmbedDim embedDim dtype device
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear kEmbedDim embedDim dtype device
mhaKInProj (Tensor device dtype '[batchSize, seqLen, kEmbedDim]
 -> Tensor device dtype '[batchSize, numHeads, seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key
          weights :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights = Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
-> Tensor device dtype '[batchSize, numHeads, headDim, seqLen]
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @2 @3 Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
k)
          weights' :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights' = case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations of
            Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
Nothing -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights
            Just Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
kr -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (SetValue
        (SetValue
           (CheckMatMul
              '[batchSize, seqLen', numHeads, headDim]
              '[batchSize, seqLen', headDim, seqLen]
              (ComputeMatMul
                 (ReverseImpl '[batchSize, seqLen', numHeads, headDim] '[])
                 (ReverseImpl '[batchSize, seqLen', headDim, seqLen] '[])))
           1
           (GetValue
              (CheckMatMul
                 '[batchSize, seqLen', numHeads, headDim]
                 '[batchSize, seqLen', headDim, seqLen]
                 (ComputeMatMul
                    (ReverseImpl '[batchSize, seqLen', numHeads, headDim] '[])
                    (ReverseImpl '[batchSize, seqLen', headDim, seqLen] '[])))
              2))
        2
        (GetValue
           (CheckMatMul
              '[batchSize, seqLen', numHeads, headDim]
              '[batchSize, seqLen', headDim, seqLen]
              (ComputeMatMul
                 (ReverseImpl '[batchSize, seqLen', numHeads, headDim] '[])
                 (ReverseImpl '[batchSize, seqLen', headDim, seqLen] '[])))
           1))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 ((forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 Tensor device dtype '[batchSize, numHeads, seqLen', headDim]
q) Tensor device dtype '[batchSize, seqLen', numHeads, headDim]
-> Tensor device dtype '[batchSize, seqLen', headDim, seqLen]
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, seqLen', numHeads, headDim]
        '[batchSize, seqLen', headDim, seqLen]
        (ComputeMatMul
           (ReverseImpl '[batchSize, seqLen', numHeads, headDim] '[])
           (ReverseImpl '[batchSize, seqLen', headDim, seqLen] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
`matmul` (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @2 @3 Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
kr))
       in Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights'
    _maskAttention :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskAttention Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
      case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask of
        Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
Nothing -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
        Just Tensor device dtype '[batchSize, seqLen', seqLen]
am -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, 1, seqLen', seqLen]
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @1 Tensor device dtype '[batchSize, seqLen', seqLen]
am
    _maskKeyPaddings :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskKeyPaddings Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
      case Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask of
        Maybe (Tensor device 'Bool '[batchSize, seqLen])
Nothing -> Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
        Just Tensor device 'Bool '[batchSize, seqLen]
kpm ->
          let keyPaddingMask' :: Tensor
  device
  'Bool
  (UnsqueezeCheck
     '[batchSize, 1, seqLen]
     2
     (UnsqueezeImpl '[batchSize, 1, seqLen] 2))
keyPaddingMask' = forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @2 (Tensor device 'Bool '[batchSize, 1, seqLen]
 -> Tensor
      device
      'Bool
      (UnsqueezeCheck
         '[batchSize, 1, seqLen]
         2
         (UnsqueezeImpl '[batchSize, 1, seqLen] 2)))
-> (Tensor device 'Bool '[batchSize, seqLen]
    -> Tensor device 'Bool '[batchSize, 1, seqLen])
-> Tensor device 'Bool '[batchSize, seqLen]
-> Tensor
     device
     'Bool
     (UnsqueezeCheck
        '[batchSize, 1, seqLen]
        2
        (UnsqueezeImpl '[batchSize, 1, seqLen] 2))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @1 (Tensor device 'Bool '[batchSize, seqLen]
 -> Tensor
      device
      'Bool
      (UnsqueezeCheck
         '[batchSize, 1, seqLen]
         2
         (UnsqueezeImpl '[batchSize, 1, seqLen] 2)))
-> Tensor device 'Bool '[batchSize, seqLen]
-> Tensor
     device
     'Bool
     (UnsqueezeCheck
        '[batchSize, 1, seqLen]
        2
        (UnsqueezeImpl '[batchSize, 1, seqLen] 2))
forall a b. (a -> b) -> a -> b
$ Tensor device 'Bool '[batchSize, seqLen]
kpm
           in Tensor
  device
  'Bool
  (UnsqueezeCheck
     '[batchSize, 1, seqLen]
     2
     (UnsqueezeImpl '[batchSize, 1, seqLen] 2))
-> Double
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall a (shape :: [Nat]) (shape' :: [Nat]) (shape'' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(Scalar a, shape'' ~ Broadcast shape shape') =>
Tensor device 'Bool shape'
-> a -> Tensor device dtype shape -> Tensor device dtype shape''
maskedFill Tensor
  device
  'Bool
  (UnsqueezeCheck
     '[batchSize, 1, seqLen]
     2
     (UnsqueezeImpl '[batchSize, 1, seqLen] 2))
keyPaddingMask' (-Double
1 Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/ Double
0 :: Double) Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights
    _attention :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
      let v :: Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v = Tensor device dtype '[batchSize, seqLen, embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' (Tensor device dtype '[batchSize, seqLen, embedDim]
 -> Tensor device dtype '[batchSize, numHeads, seqLen, headDim])
-> (Tensor device dtype '[batchSize, seqLen, vEmbedDim]
    -> Tensor device dtype '[batchSize, seqLen, embedDim])
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Linear vEmbedDim embedDim dtype device
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear vEmbedDim embedDim dtype device
mhaVInProj (Tensor device dtype '[batchSize, seqLen, vEmbedDim]
 -> Tensor device dtype '[batchSize, numHeads, seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value
          attention :: Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention = forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 (Tensor
   device
   dtype
   (CheckMatMul
      (CheckMatMul
         '[batchSize, numHeads, seqLen', headDim]
         '[batchSize, numHeads, headDim, seqLen]
         (ComputeMatMul
            '[headDim, seqLen', numHeads, batchSize]
            (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
      '[batchSize, numHeads, seqLen, headDim]
      (ComputeMatMul
         (ReverseImpl
            (CheckMatMul
               '[batchSize, numHeads, seqLen', headDim]
               '[batchSize, numHeads, headDim, seqLen]
               (ComputeMatMul
                  '[headDim, seqLen', numHeads, batchSize]
                  (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
            '[])
         '[headDim, seqLen, numHeads, batchSize]))
 -> Tensor
      device
      dtype
      (SetValue
         (SetValue
            (CheckMatMul
               (CheckMatMul
                  '[batchSize, numHeads, seqLen', headDim]
                  '[batchSize, numHeads, headDim, seqLen]
                  (ComputeMatMul
                     '[headDim, seqLen', numHeads, batchSize]
                     (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
               '[batchSize, numHeads, seqLen, headDim]
               (ComputeMatMul
                  (ReverseImpl
                     (CheckMatMul
                        '[batchSize, numHeads, seqLen', headDim]
                        '[batchSize, numHeads, headDim, seqLen]
                        (ComputeMatMul
                           '[headDim, seqLen', numHeads, batchSize]
                           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                     '[])
                  '[headDim, seqLen, numHeads, batchSize]))
            1
            (GetValue
               (CheckMatMul
                  (CheckMatMul
                     '[batchSize, numHeads, seqLen', headDim]
                     '[batchSize, numHeads, headDim, seqLen]
                     (ComputeMatMul
                        '[headDim, seqLen', numHeads, batchSize]
                        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                  '[batchSize, numHeads, seqLen, headDim]
                  (ComputeMatMul
                     (ReverseImpl
                        (CheckMatMul
                           '[batchSize, numHeads, seqLen', headDim]
                           '[batchSize, numHeads, headDim, seqLen]
                           (ComputeMatMul
                              '[headDim, seqLen', numHeads, batchSize]
                              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                        '[])
                     '[headDim, seqLen, numHeads, batchSize]))
               2))
         2
         (GetValue
            (CheckMatMul
               (CheckMatMul
                  '[batchSize, numHeads, seqLen', headDim]
                  '[batchSize, numHeads, headDim, seqLen]
                  (ComputeMatMul
                     '[headDim, seqLen', numHeads, batchSize]
                     (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
               '[batchSize, numHeads, seqLen, headDim]
               (ComputeMatMul
                  (ReverseImpl
                     (CheckMatMul
                        '[batchSize, numHeads, seqLen', headDim]
                        '[batchSize, numHeads, headDim, seqLen]
                        (ComputeMatMul
                           '[headDim, seqLen', numHeads, batchSize]
                           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                     '[])
                  '[headDim, seqLen, numHeads, batchSize]))
            1)))
-> Tensor
     device
     dtype
     (CheckMatMul
        (CheckMatMul
           '[batchSize, numHeads, seqLen', headDim]
           '[batchSize, numHeads, headDim, seqLen]
           (ComputeMatMul
              '[headDim, seqLen', numHeads, batchSize]
              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
        '[batchSize, numHeads, seqLen, headDim]
        (ComputeMatMul
           (ReverseImpl
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[])
           '[headDim, seqLen, numHeads, batchSize]))
-> Tensor
     device
     dtype
     (SetValue
        (SetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1
           (GetValue
              (CheckMatMul
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[batchSize, numHeads, seqLen, headDim]
                 (ComputeMatMul
                    (ReverseImpl
                       (CheckMatMul
                          '[batchSize, numHeads, seqLen', headDim]
                          '[batchSize, numHeads, headDim, seqLen]
                          (ComputeMatMul
                             '[headDim, seqLen', numHeads, batchSize]
                             (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                       '[])
                    '[headDim, seqLen, numHeads, batchSize]))
              2))
        2
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1))
forall a b. (a -> b) -> a -> b
$ Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
-> Tensor
     device
     dtype
     (CheckMatMul
        (CheckMatMul
           '[batchSize, numHeads, seqLen', headDim]
           '[batchSize, numHeads, headDim, seqLen]
           (ComputeMatMul
              '[headDim, seqLen', numHeads, batchSize]
              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
        '[batchSize, numHeads, seqLen, headDim]
        (ComputeMatMul
           (ReverseImpl
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[])
           '[headDim, seqLen, numHeads, batchSize]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v
          attention' :: Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention' = case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations of
            Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
Nothing -> Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention
            Just Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr -> Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
-> Tensor
     device
     dtype
     (CheckMatMul
        (SetValue
           (SetValue
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              1
              (GetValue
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 2))
           2
           (GetValue
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              1))
        '[batchSize, seqLen', seqLen, headDim]
        (ComputeMatMul
           (ReverseImpl
              (SetValue
                 (SetValue
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    1
                    (GetValue
                       (CheckMatMul
                          '[batchSize, numHeads, seqLen', headDim]
                          '[batchSize, numHeads, headDim, seqLen]
                          (ComputeMatMul
                             '[headDim, seqLen', numHeads, batchSize]
                             (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                       2))
                 2
                 (GetValue
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    1))
              '[])
           '[headDim, seqLen, seqLen', batchSize]))
-> Tensor
     device
     dtype
     (SetValue
        (SetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1
           (GetValue
              (CheckMatMul
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[batchSize, numHeads, seqLen, headDim]
                 (ComputeMatMul
                    (ReverseImpl
                       (CheckMatMul
                          '[batchSize, numHeads, seqLen', headDim]
                          '[batchSize, numHeads, headDim, seqLen]
                          (ComputeMatMul
                             '[headDim, seqLen', numHeads, batchSize]
                             (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                       '[])
                    '[headDim, seqLen, numHeads, batchSize]))
              2))
        2
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
       (device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
 shape'' ~ Broadcast shape shape',
 BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` (Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           '[batchSize, numHeads, seqLen', headDim]
           '[batchSize, numHeads, headDim, seqLen]
           (ComputeMatMul
              '[headDim, seqLen', numHeads, batchSize]
              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
        1
        (GetValue
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           2))
     2
     (GetValue
        (CheckMatMul
           '[batchSize, numHeads, seqLen', headDim]
           '[batchSize, numHeads, headDim, seqLen]
           (ComputeMatMul
              '[headDim, seqLen', numHeads, batchSize]
              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
        1))
-> Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
-> Tensor
     device
     dtype
     (CheckMatMul
        (SetValue
           (SetValue
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              1
              (GetValue
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 2))
           2
           (GetValue
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              1))
        '[batchSize, seqLen', seqLen, headDim]
        (ComputeMatMul
           (ReverseImpl
              (SetValue
                 (SetValue
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    1
                    (GetValue
                       (CheckMatMul
                          '[batchSize, numHeads, seqLen', headDim]
                          '[batchSize, numHeads, headDim, seqLen]
                          (ComputeMatMul
                             '[headDim, seqLen', numHeads, batchSize]
                             (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                       2))
                 2
                 (GetValue
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    1))
              '[])
           '[headDim, seqLen, seqLen', batchSize]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights) Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr)
       in Linear embedDim embedDim dtype device
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim embedDim dtype device
mhaOutProj (Tensor device dtype '[batchSize, seqLen', embedDim]
 -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> (Tensor
      device
      dtype
      (SetValue
         (SetValue
            (CheckMatMul
               (CheckMatMul
                  '[batchSize, numHeads, seqLen', headDim]
                  '[batchSize, numHeads, headDim, seqLen]
                  (ComputeMatMul
                     '[headDim, seqLen', numHeads, batchSize]
                     (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
               '[batchSize, numHeads, seqLen, headDim]
               (ComputeMatMul
                  (ReverseImpl
                     (CheckMatMul
                        '[batchSize, numHeads, seqLen', headDim]
                        '[batchSize, numHeads, headDim, seqLen]
                        (ComputeMatMul
                           '[headDim, seqLen', numHeads, batchSize]
                           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                     '[])
                  '[headDim, seqLen, numHeads, batchSize]))
            1
            (GetValue
               (CheckMatMul
                  (CheckMatMul
                     '[batchSize, numHeads, seqLen', headDim]
                     '[batchSize, numHeads, headDim, seqLen]
                     (ComputeMatMul
                        '[headDim, seqLen', numHeads, batchSize]
                        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                  '[batchSize, numHeads, seqLen, headDim]
                  (ComputeMatMul
                     (ReverseImpl
                        (CheckMatMul
                           '[batchSize, numHeads, seqLen', headDim]
                           '[batchSize, numHeads, headDim, seqLen]
                           (ComputeMatMul
                              '[headDim, seqLen', numHeads, batchSize]
                              (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                        '[])
                     '[headDim, seqLen, numHeads, batchSize]))
               2))
         2
         (GetValue
            (CheckMatMul
               (CheckMatMul
                  '[batchSize, numHeads, seqLen', headDim]
                  '[batchSize, numHeads, headDim, seqLen]
                  (ComputeMatMul
                     '[headDim, seqLen', numHeads, batchSize]
                     (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
               '[batchSize, numHeads, seqLen, headDim]
               (ComputeMatMul
                  (ReverseImpl
                     (CheckMatMul
                        '[batchSize, numHeads, seqLen', headDim]
                        '[batchSize, numHeads, headDim, seqLen]
                        (ComputeMatMul
                           '[headDim, seqLen', numHeads, batchSize]
                           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                     '[])
                  '[headDim, seqLen, numHeads, batchSize]))
            1))
    -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor
     device
     dtype
     (SetValue
        (SetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1
           (GetValue
              (CheckMatMul
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[batchSize, numHeads, seqLen, headDim]
                 (ComputeMatMul
                    (ReverseImpl
                       (CheckMatMul
                          '[batchSize, numHeads, seqLen', headDim]
                          '[batchSize, numHeads, headDim, seqLen]
                          (ComputeMatMul
                             '[headDim, seqLen', numHeads, batchSize]
                             (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                       '[])
                    '[headDim, seqLen, numHeads, batchSize]))
              2))
        2
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @'[batchSize, seqLen', embedDim] (Tensor
   device
   dtype
   (SetValue
      (SetValue
         (CheckMatMul
            (CheckMatMul
               '[batchSize, numHeads, seqLen', headDim]
               '[batchSize, numHeads, headDim, seqLen]
               (ComputeMatMul
                  '[headDim, seqLen', numHeads, batchSize]
                  (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
            '[batchSize, numHeads, seqLen, headDim]
            (ComputeMatMul
               (ReverseImpl
                  (CheckMatMul
                     '[batchSize, numHeads, seqLen', headDim]
                     '[batchSize, numHeads, headDim, seqLen]
                     (ComputeMatMul
                        '[headDim, seqLen', numHeads, batchSize]
                        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                  '[])
               '[headDim, seqLen, numHeads, batchSize]))
         1
         (GetValue
            (CheckMatMul
               (CheckMatMul
                  '[batchSize, numHeads, seqLen', headDim]
                  '[batchSize, numHeads, headDim, seqLen]
                  (ComputeMatMul
                     '[headDim, seqLen', numHeads, batchSize]
                     (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
               '[batchSize, numHeads, seqLen, headDim]
               (ComputeMatMul
                  (ReverseImpl
                     (CheckMatMul
                        '[batchSize, numHeads, seqLen', headDim]
                        '[batchSize, numHeads, headDim, seqLen]
                        (ComputeMatMul
                           '[headDim, seqLen', numHeads, batchSize]
                           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                     '[])
                  '[headDim, seqLen, numHeads, batchSize]))
            2))
      2
      (GetValue
         (CheckMatMul
            (CheckMatMul
               '[batchSize, numHeads, seqLen', headDim]
               '[batchSize, numHeads, headDim, seqLen]
               (ComputeMatMul
                  '[headDim, seqLen', numHeads, batchSize]
                  (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
            '[batchSize, numHeads, seqLen, headDim]
            (ComputeMatMul
               (ReverseImpl
                  (CheckMatMul
                     '[batchSize, numHeads, seqLen', headDim]
                     '[batchSize, numHeads, headDim, seqLen]
                     (ComputeMatMul
                        '[headDim, seqLen', numHeads, batchSize]
                        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                  '[])
               '[headDim, seqLen, numHeads, batchSize]))
         1))
 -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor
     device
     dtype
     (SetValue
        (SetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1
           (GetValue
              (CheckMatMul
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[batchSize, numHeads, seqLen, headDim]
                 (ComputeMatMul
                    (ReverseImpl
                       (CheckMatMul
                          '[batchSize, numHeads, seqLen', headDim]
                          '[batchSize, numHeads, headDim, seqLen]
                          (ComputeMatMul
                             '[headDim, seqLen', numHeads, batchSize]
                             (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                       '[])
                    '[headDim, seqLen, numHeads, batchSize]))
              2))
        2
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           1))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall a b. (a -> b) -> a -> b
$ Tensor
  device
  dtype
  (SetValue
     (SetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1
        (GetValue
           (CheckMatMul
              (CheckMatMul
                 '[batchSize, numHeads, seqLen', headDim]
                 '[batchSize, numHeads, headDim, seqLen]
                 (ComputeMatMul
                    '[headDim, seqLen', numHeads, batchSize]
                    (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
              '[batchSize, numHeads, seqLen, headDim]
              (ComputeMatMul
                 (ReverseImpl
                    (CheckMatMul
                       '[batchSize, numHeads, seqLen', headDim]
                       '[batchSize, numHeads, headDim, seqLen]
                       (ComputeMatMul
                          '[headDim, seqLen', numHeads, batchSize]
                          (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                    '[])
                 '[headDim, seqLen, numHeads, batchSize]))
           2))
     2
     (GetValue
        (CheckMatMul
           (CheckMatMul
              '[batchSize, numHeads, seqLen', headDim]
              '[batchSize, numHeads, headDim, seqLen]
              (ComputeMatMul
                 '[headDim, seqLen', numHeads, batchSize]
                 (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
           '[batchSize, numHeads, seqLen, headDim]
           (ComputeMatMul
              (ReverseImpl
                 (CheckMatMul
                    '[batchSize, numHeads, seqLen', headDim]
                    '[batchSize, numHeads, headDim, seqLen]
                    (ComputeMatMul
                       '[headDim, seqLen', numHeads, batchSize]
                       (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
                 '[])
              '[headDim, seqLen, numHeads, batchSize]))
        1))
attention'
    averageOverHeads :: Tensor
  device
  dtype
  (CheckMatMul
     '[batchSize, numHeads, seqLen', headDim]
     '[batchSize, numHeads, headDim, seqLen]
     (ComputeMatMul
        '[headDim, seqLen', numHeads, batchSize]
        (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
averageOverHeads =
      let numHeads' :: Int
numHeads' = forall (n :: Nat). KnownNat n => Int
natValI @numHeads
       in Int
-> Tensor device dtype '[batchSize, seqLen', seqLen]
-> Tensor device dtype '[batchSize, seqLen', seqLen]
forall a (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Scalar a =>
a -> Tensor device dtype shape -> Tensor device dtype shape
divScalar Int
numHeads' (Tensor device dtype '[batchSize, seqLen', seqLen]
 -> Tensor device dtype '[batchSize, seqLen', seqLen])
-> (Tensor
      device
      dtype
      (CheckMatMul
         '[batchSize, numHeads, seqLen', headDim]
         '[batchSize, numHeads, headDim, seqLen]
         (ComputeMatMul
            '[headDim, seqLen', numHeads, batchSize]
            (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
    -> Tensor device dtype '[batchSize, seqLen', seqLen])
-> Tensor
     device
     dtype
     (CheckMatMul
        '[batchSize, numHeads, seqLen', headDim]
        '[batchSize, numHeads, headDim, seqLen]
        (ComputeMatMul
           '[headDim, seqLen', numHeads, batchSize]
           (ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (d :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(KnownNat d, shape' ~ DropValue shape d,
 SumDTypeIsValid device dtype, dtype' ~ SumDType dtype) =>
Tensor device dtype shape -> Tensor device dtype' shape'
sumDim @1
    reshape' ::
      forall seqLen''.
      KnownNat seqLen'' =>
      Tensor device dtype '[batchSize, seqLen'', embedDim] ->
      Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
    reshape' :: forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' = forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
       (dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 (Tensor device dtype '[batchSize, seqLen'', numHeads, headDim]
 -> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim])
-> (Tensor device dtype '[batchSize, seqLen'', embedDim]
    -> Tensor device dtype '[batchSize, seqLen'', numHeads, headDim])
-> Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @'[batchSize, seqLen'', numHeads, headDim]

-- | Random initialisation of a 'MultiheadAttention' layer from its spec.
--
-- Each of the four linear sub-layers is sampled from a default 'LinearSpec'.
-- Their feature sizes are fixed entirely by the result type. In
-- constructor-field order they are:
--
--   * @embedDim  -> embedDim@ (presumably the query in-projection),
--   * @kEmbedDim -> embedDim@ (presumably the key in-projection),
--   * @vEmbedDim -> embedDim@ (presumably the value in-projection),
--   * @embedDim  -> embedDim@ (presumably the output projection).
--
-- The dropout sub-layer is sampled from the spec's 'DropoutSpec'. All
-- sub-layers are sampled left to right in constructor-field order.
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device)
    (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device)
  where
  sample (MultiheadAttentionSpec mhaDropoutSpec) =
    MultiheadAttention
      <$> A.sample LinearSpec -- Linear embedDim embedDim
      <*> A.sample LinearSpec -- Linear kEmbedDim embedDim
      <*> A.sample LinearSpec -- Linear vEmbedDim embedDim
      <*> A.sample LinearSpec -- Linear embedDim embedDim
      <*> A.sample mhaDropoutSpec -- Dropout

--------------------------------------------------------------------------------
-- Transformer MLP Layer
--------------------------------------------------------------------------------

-- | Configuration for a 'TransformerMLP': two dropout settings and the
-- layer-norm epsilon.
--
-- No field mentions @embedDim@, @ffnDim@, @dtype@ or @device@. Those type
-- parameters only tie a spec to the matching 'TransformerMLP' type.
data
  TransformerMLPSpec
    (embedDim :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerMLPSpec ::
    forall embedDim ffnDim dtype device.
    { -- | spec for relu dropout
      dropout0Spec :: DropoutSpec,
      -- | spec for other dropout
      dropout1Spec :: DropoutSpec,
      -- | epsilon for layer norm
      epsSpec :: Double
    } ->
    TransformerMLPSpec embedDim ffnDim dtype device
  deriving (Show, Eq)

-- | The two-layer MLP of a transformer layer, together with its dropouts and
-- layer norm.
--
-- 'linear0' widens the last dimension from @embedDim@ to @ffnDim@, and
-- 'linear1' projects it back to @embedDim@. The layer norm 'ln' is therefore
-- defined over the original embedding size.
--
-- 'Parameterized' is derived through its default methods (DeriveAnyClass),
-- which presumably rely on the derived 'Generic' instance.
data
  TransformerMLP
    (embedDim :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerMLP ::
    forall embedDim ffnDim dtype device.
    { -- | first fully connected layer (@embedDim -> ffnDim@)
      linear0 :: Linear embedDim ffnDim dtype device,
      -- | second fully connected layer (@ffnDim -> embedDim@)
      linear1 :: Linear ffnDim embedDim dtype device,
      -- | relu dropout, applied to the @ffnDim@-wide hidden activations
      dropout0 :: Dropout,
      -- | other dropout, applied to the output of 'linear1'
      dropout1 :: Dropout,
      -- | layer norm over the last (@embedDim@) dimension
      ln :: LayerNorm '[embedDim] dtype device
    } ->
    TransformerMLP embedDim ffnDim dtype device
  deriving (Show, Generic, Parameterized)

transformerMLP ::
  forall embedDim ffnDim seqLen batchSize dtype device.
  ( BasicArithmeticDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device dtype,
    KnownNat embedDim,
    IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]
  ) =>
  -- | MLP model ADT for transformer
  TransformerMLP embedDim ffnDim dtype device ->
  -- | switch between training mode and evaluation mode (turns random dropout on and off)
  Bool ->
  Tensor device dtype '[seqLen, batchSize, embedDim] -> -- input
  IO (Tensor device dtype '[seqLen, batchSize, embedDim]) -- output
transformerMLP :: forall (embedDim :: Nat) (ffnDim :: Nat) (seqLen :: Nat)
       (batchSize :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)).
(BasicArithmeticDTypeIsValid device dtype,
 StandardFloatingPointDTypeValidation device dtype,
 KnownNat embedDim,
 IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]) =>
TransformerMLP embedDim ffnDim dtype device
-> Bool
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
transformerMLP TransformerMLP {LayerNorm '[embedDim] dtype device
Linear embedDim ffnDim dtype device
Linear ffnDim embedDim dtype device
Dropout
linear0 :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> Linear embedDim ffnDim dtype device
linear1 :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> Linear ffnDim embedDim dtype device
dropout0 :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device -> Dropout
dropout1 :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device -> Dropout
ln :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> LayerNorm '[embedDim] dtype device
linear0 :: Linear embedDim ffnDim dtype device
linear1 :: Linear ffnDim embedDim dtype device
dropout0 :: Dropout
dropout1 :: Dropout
ln :: LayerNorm '[embedDim] dtype device
..} Bool
train Tensor device dtype '[seqLen, batchSize, embedDim]
input =
  (Tensor device dtype '[seqLen, batchSize, embedDim]
 -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]))
-> (Tensor
      device
      (DTypePromotionImpl dtype dtype (CmpDType dtype dtype))
      (Broadcast
         '[seqLen, batchSize, embedDim] '[seqLen, batchSize, embedDim])
    -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]))
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
forall {device :: (DeviceType, Nat)} {dtype :: DType}
       {dtype' :: DType} {m :: Type -> Type} {shape :: [Nat]}
       {shape' :: [Nat]} {b}.
(BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid
   device (DTypePromotionImpl dtype dtype' (CmpDType dtype dtype')),
 Monad m) =>
(Tensor device dtype shape -> m (Tensor device dtype' shape'))
-> (Tensor
      device
      (DTypePromotionImpl dtype dtype' (CmpDType dtype dtype'))
      (CheckBroadcast
         shape
         shape'
         (ComputeBroadcast
            (ReverseImpl shape '[]) (ReverseImpl shape' '[])))
    -> m b)
-> Tensor device dtype shape
-> m b
residual Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
f (Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor device dtype '[seqLen, batchSize, embedDim]
 -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]))
-> (Tensor
      device
      (DTypePromotionImpl dtype dtype (CmpDType dtype dtype))
      (CheckBroadcast
         '[seqLen, batchSize, embedDim]
         '[seqLen, batchSize, embedDim]
         (ComputeBroadcast
            '[embedDim, batchSize, seqLen]
            (ReverseImpl '[seqLen, batchSize, embedDim] '[])))
    -> Tensor device dtype '[seqLen, batchSize, embedDim])
-> Tensor
     device
     (DTypePromotionImpl dtype dtype (CmpDType dtype dtype))
     (CheckBroadcast
        '[seqLen, batchSize, embedDim]
        '[seqLen, batchSize, embedDim]
        (ComputeBroadcast
           '[embedDim, batchSize, seqLen]
           (ReverseImpl '[seqLen, batchSize, embedDim] '[])))
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LayerNorm '[embedDim] dtype device
-> Tensor
     device
     (DTypePromotionImpl dtype dtype (CmpDType dtype dtype))
     (CheckBroadcast
        '[seqLen, batchSize, embedDim]
        '[seqLen, batchSize, embedDim]
        (ComputeBroadcast
           '[embedDim, batchSize, seqLen]
           (ReverseImpl '[seqLen, batchSize, embedDim] '[])))
-> Tensor device dtype '[seqLen, batchSize, embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward LayerNorm '[embedDim] dtype device
ln) Tensor device dtype '[seqLen, batchSize, embedDim]
input
  where
    f :: Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
f Tensor device dtype '[seqLen, batchSize, embedDim]
x =
      Dropout
-> Bool
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
dropout1 Bool
train
        (Tensor device dtype '[seqLen, batchSize, embedDim]
 -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]))
-> (Tensor device dtype '[seqLen, batchSize, ffnDim]
    -> Tensor device dtype '[seqLen, batchSize, embedDim])
-> Tensor device dtype '[seqLen, batchSize, ffnDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Linear ffnDim embedDim dtype device
-> Tensor device dtype '[seqLen, batchSize, ffnDim]
-> Tensor device dtype '[seqLen, batchSize, embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear ffnDim embedDim dtype device
linear1
        (Tensor device dtype '[seqLen, batchSize, ffnDim]
 -> IO (Tensor device dtype '[seqLen, batchSize, embedDim]))
-> IO (Tensor device dtype '[seqLen, batchSize, ffnDim])
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Dropout
-> Bool
-> Tensor device dtype '[seqLen, batchSize, ffnDim]
-> IO (Tensor device dtype '[seqLen, batchSize, ffnDim])
forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
dropout0 Bool
train
          (Tensor device dtype '[seqLen, batchSize, ffnDim]
 -> IO (Tensor device dtype '[seqLen, batchSize, ffnDim]))
-> (Tensor device dtype '[seqLen, batchSize, embedDim]
    -> Tensor device dtype '[seqLen, batchSize, ffnDim])
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, ffnDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor device dtype '[seqLen, batchSize, ffnDim]
-> Tensor device dtype '[seqLen, batchSize, ffnDim]
forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)) t.
(StandardFloatingPointDTypeValidation device dtype,
 IsUnnamed t device dtype shape) =>
t -> t
relu
          (Tensor device dtype '[seqLen, batchSize, ffnDim]
 -> Tensor device dtype '[seqLen, batchSize, ffnDim])
-> (Tensor device dtype '[seqLen, batchSize, embedDim]
    -> Tensor device dtype '[seqLen, batchSize, ffnDim])
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> Tensor device dtype '[seqLen, batchSize, ffnDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Linear embedDim ffnDim dtype device
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> Tensor device dtype '[seqLen, batchSize, ffnDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim ffnDim dtype device
linear0
        (Tensor device dtype '[seqLen, batchSize, embedDim]
 -> IO (Tensor device dtype '[seqLen, batchSize, ffnDim]))
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
-> IO (Tensor device dtype '[seqLen, batchSize, ffnDim])
forall (m :: Type -> Type) a b. Monad m => (a -> m b) -> m a -> m b
=<< Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure Tensor device dtype '[seqLen, batchSize, embedDim]
x

instance
  ( All KnownNat '[embedDim, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerMLPSpec embedDim ffnDim dtype device)
    (TransformerMLP embedDim ffnDim dtype device)
  where
  -- Randomly initialise a 'TransformerMLP' from its spec.
  --
  -- Sub-modules are sampled one after another in constructor order:
  -- the embedDim -> ffnDim projection, the ffnDim -> embedDim
  -- projection, the two dropouts (from 'dropout0Spec' and
  -- 'dropout1Spec'), and finally the layer norm over the embedding
  -- dimension, whose epsilon comes from 'epsSpec'.
  sample TransformerMLPSpec {..} = do
    fc0 <- A.sample (LinearSpec :: LinearSpec embedDim ffnDim dtype device)
    fc1 <- A.sample (LinearSpec :: LinearSpec ffnDim embedDim dtype device)
    drop0 <- A.sample dropout0Spec
    drop1 <- A.sample dropout1Spec
    norm <- A.sample (LayerNormSpec epsSpec :: LayerNormSpec '[embedDim] dtype device)
    pure (TransformerMLP fc0 fc1 drop0 drop1 norm)

--------------------------------------------------------------------------------
-- Relation-Aware Transformer Layer
--------------------------------------------------------------------------------

-- | Hyper-parameter specification from which a 'TransformerLayer' is sampled.
data
  TransformerLayerSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLayerSpec ::
    forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
    { -- | spec of the multi-head attention sub-layer
      mhaSpec :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device,
      -- | spec of the dropout that accompanies the attention sub-layer
      attnDropoutSpec :: DropoutSpec,
      -- | epsilon, presumably for the layer's own layer norm (TODO confirm
      -- against the 'A.Randomizable' instance); the prime likely avoids a
      -- clash with the 'epsSpec' field of 'TransformerMLPSpec'
      epsSpec' :: Double,
      -- | spec of the feed-forward (MLP) block
      mlpSpec :: TransformerMLPSpec embedDim ffnDim dtype device
    } ->
    TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Eq)

-- | Parameters of one transformer layer. It holds a multi-head attention
-- sub-layer with its dropout and layer norm, plus a 'TransformerMLP'
-- feed-forward block.
data
  TransformerLayer
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLayer ::
    forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
    { -- | multi-head attention sub-layer
      transformerLayer_mha :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device,
      -- | dropout belonging to the attention sub-layer
      transformerLayer_attnDropout :: Dropout,
      -- | layer norm over the embedding dimension
      transformerLayer_ln :: LayerNorm '[embedDim] dtype device,
      -- | feed-forward (MLP) block
      transformerLayer_mlp :: TransformerMLP embedDim ffnDim dtype device
    } ->
    TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Generic, Parameterized)

transformerLayer ::
  forall (numHeads :: Nat) (ffnDim :: Nat) (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (headDim :: Nat) (seqLen :: Nat) (seqLen' :: Nat) (batchSize :: Nat) dtype device.
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
    IsSuffixOf '[embedDim] '[batchSize, seqLen', embedDim],
    KnownDType dtype,
    dtype ~ SumDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  -- | transformer layer model ADT
  TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device ->
  -- | switch between training mode and evaluation mode (turns random dropout on and off)
  Bool ->
  -- | optional attention mask
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
  -- | optional key padding mask
  Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
  -- | optional key relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | optional value relations
  Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
  -- | query representation
  Tensor device dtype '[batchSize, seqLen', embedDim] ->
  -- | key representation
  Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
  -- | value representation
  Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
  -- | transformer layer output representation
  IO (Tensor device dtype '[batchSize, seqLen', embedDim])
transformerLayer :: forall (numHeads :: Nat) (ffnDim :: Nat) (embedDim :: Nat)
       (kEmbedDim :: Nat) (vEmbedDim :: Nat) (headDim :: Nat)
       (seqLen :: Nat) (seqLen' :: Nat) (batchSize :: Nat)
       (dtype :: DType) (device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
 All
   KnownNat
   '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
     batchSize, headDim],
 IsSuffixOf '[embedDim] '[batchSize, seqLen', embedDim],
 KnownDType dtype, dtype ~ SumDType dtype,
 StandardFloatingPointDTypeValidation device dtype,
 MatMulDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype,
 SumDTypeIsValid device dtype, KnownDevice device) =>
TransformerLayer
  embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
transformerLayer TransformerLayer {LayerNorm '[embedDim] dtype device
Dropout
TransformerMLP embedDim ffnDim dtype device
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_mha :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerLayer
  embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> MultiheadAttention
     embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_attnDropout :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerLayer
  embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> Dropout
transformerLayer_ln :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerLayer
  embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> LayerNorm '[embedDim] dtype device
transformerLayer_mlp :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
TransformerLayer
  embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> TransformerMLP embedDim ffnDim dtype device
transformerLayer_mha :: MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_attnDropout :: Dropout
transformerLayer_ln :: LayerNorm '[embedDim] dtype device
transformerLayer_mlp :: TransformerMLP embedDim ffnDim dtype device
..} Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value =
  let f :: Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
f Tensor device dtype '[batchSize, seqLen', embedDim]
query' =
        MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
     (Tensor device dtype '[batchSize, seqLen', embedDim],
      Tensor device dtype '[batchSize, seqLen', seqLen])
forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
       (numHeads :: Nat) (seqLen :: Nat) (seqLen' :: Nat)
       (batchSize :: Nat) (headDim :: Nat) (dtype :: DType)
       (device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
 All
   KnownNat
   '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
     batchSize, headDim],
 KnownDType dtype,
 StandardFloatingPointDTypeValidation device dtype,
 MatMulDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype,
 SumDTypeIsValid device dtype, KnownDevice device) =>
MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
     (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
     (Tensor device dtype '[batchSize, seqLen', embedDim],
      Tensor device dtype '[batchSize, seqLen', seqLen])
multiheadAttention MultiheadAttention
  embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_mha Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query' Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value
          IO
  (Tensor device dtype '[batchSize, seqLen', embedDim],
   Tensor device dtype '[batchSize, seqLen', seqLen])
-> ((Tensor device dtype '[batchSize, seqLen', embedDim],
     Tensor device dtype '[batchSize, seqLen', seqLen])
    -> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: Type -> Type) a b. Monad m => m a -> (a -> m b) -> m b
>>= Dropout
-> Bool
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall (shape :: [Nat]) (dtype :: DType)
       (device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
transformerLayer_attnDropout Bool
train (Tensor device dtype '[batchSize, seqLen', embedDim]
 -> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> ((Tensor device dtype '[batchSize, seqLen', embedDim],
     Tensor device dtype '[batchSize, seqLen', seqLen])
    -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> (Tensor device dtype '[batchSize, seqLen', embedDim],
    Tensor device dtype '[batchSize, seqLen', seqLen])
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Tensor device dtype '[batchSize, seqLen', embedDim],
 Tensor device dtype '[batchSize, seqLen', seqLen])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall a b. (a, b) -> a
fst
   in (Tensor device dtype '[batchSize, seqLen', embedDim]
 -> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> (Tensor
      device
      (DTypePromotionImpl dtype dtype (CmpDType dtype dtype))
      (CheckBroadcast
         '[batchSize, seqLen', embedDim]
         '[batchSize, seqLen', embedDim]
         (ComputeBroadcast
            (ReverseImpl '[batchSize, seqLen', embedDim] '[])
            (ReverseImpl '[batchSize, seqLen', embedDim] '[])))
    -> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall {device :: (DeviceType, Nat)} {dtype :: DType}
       {dtype' :: DType} {m :: Type -> Type} {shape :: [Nat]}
       {shape' :: [Nat]} {b}.
(BasicArithmeticDTypeIsValid device dtype,
 BasicArithmeticDTypeIsValid device dtype',
 BasicArithmeticDTypeIsValid
   device (DTypePromotionImpl dtype dtype' (CmpDType dtype dtype')),
 Monad m) =>
(Tensor device dtype shape -> m (Tensor device dtype' shape'))
-> (Tensor
      device
      (DTypePromotionImpl dtype dtype' (CmpDType dtype dtype'))
      (CheckBroadcast
         shape
         shape'
         (ComputeBroadcast
            (ReverseImpl shape '[]) (ReverseImpl shape' '[])))
    -> m b)
-> Tensor device dtype shape
-> m b
residual Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
f (Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor device dtype '[batchSize, seqLen', embedDim]
 -> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> (Tensor device dtype '[batchSize, seqLen', embedDim]
    -> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LayerNorm '[embedDim] dtype device
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward LayerNorm '[embedDim] dtype device
transformerLayer_ln) Tensor device dtype '[batchSize, seqLen', embedDim]
query IO (Tensor device dtype '[batchSize, seqLen', embedDim])
-> (Tensor device dtype '[batchSize, seqLen', embedDim]
    -> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: Type -> Type) a b. Monad m => m a -> (a -> m b) -> m b
>>= TransformerMLP embedDim ffnDim dtype device
-> Bool
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall (embedDim :: Nat) (ffnDim :: Nat) (seqLen :: Nat)
       (batchSize :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)).
(BasicArithmeticDTypeIsValid device dtype,
 StandardFloatingPointDTypeValidation device dtype,
 KnownNat embedDim,
 IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]) =>
TransformerMLP embedDim ffnDim dtype device
-> Bool
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
transformerMLP TransformerMLP embedDim ffnDim dtype device
transformerLayer_mlp Bool
train

-- | Sample a fresh 'TransformerLayer' from its 'TransformerLayerSpec'.
--
-- The sub-modules are sampled in constructor-field order:
--
--   1. the multi-head attention block (from 'mhaSpec'),
--   2. the attention dropout (from 'attnDropoutSpec'),
--   3. a 'LayerNorm' over @'[embedDim]@ whose epsilon is 'epsSpec'',
--   4. the position-wise MLP (from 'mlpSpec').
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
    (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
  where
  sample TransformerLayerSpec {..} =
    TransformerLayer
      <$> A.sample mhaSpec
      <*> A.sample attnDropoutSpec
      <*> A.sample (LayerNormSpec epsSpec')
      <*> A.sample mlpSpec

--------------------------------------------------------------------------------
-- Transformer Language Model (GPT-2)
--------------------------------------------------------------------------------

-- | Hyperparameter specification for a transformer language model.
--
-- The type parameters fix the architecture (number of layers, heads, MLP
-- width, padding index, vocabulary size, embedding width, dtype and device);
-- the record holds only the dropout spec and one layer spec that is reused for
-- every layer. NOTE(review): presumably sampled into a 'TransformerLM' via
-- 'A.Randomizable' (instance not in this chunk) — confirm.
data
  TransformerLMSpec
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLMSpec ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- | dropout spec
      lmDropoutSpec :: DropoutSpec,
      -- | spec for each and every transformer layer; query, key and value
      -- widths are all 'embedDim' (self-attention)
      lmLayerSpec :: TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device
    } ->
    TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Show, Eq)

-- | A GPT-2 style transformer language model.
--
-- Fields, in constructor order: token embedding, positional embedding,
-- dropout, a type-level-length list of 'numAttnLayers' identical-shaped
-- transformer layers, and a final linear projection from 'embedDim' back to
-- 'numEmbeds' (the size of the token embedding table).
data
  TransformerLM
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLM ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- | token embedding ('Learned', with 'paddingIdx' as the padding index)
      tEmbedding :: Embedding ('Just paddingIdx) numEmbeds embedDim 'Learned dtype device,
      -- | positional embedding; the type fixes a table of 2048 positions.
      -- NOTE(review): 'Constant presumably means this table is not trained — confirm.
      tPosEmbedding :: Embedding 'Nothing 2048 embedDim 'Constant dtype device,
      -- | transformer dropout
      tDropout :: Dropout,
      -- | transformer layers, exactly 'numAttnLayers' of them
      tLayers :: HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)),
      -- | final output projection from 'embedDim' to 'numEmbeds'
      tProj :: Linear embedDim numEmbeds dtype device
    } ->
    TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Generic)

-- Standalone deriving is used because the type of 'tLayers' is computed by the
-- 'HReplicateR' type family, so the 'Show' requirement on that layer list has
-- to be spelled out as an explicit instance context.
deriving instance
  Show (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))) =>
  Show (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)

-- | Parameter collection for 'TransformerLM'.
--
-- The instance body is empty, so the class's default methods are used.
-- NOTE(review): presumably the Generic-based defaults — confirm in
-- "Torch.Typed.Parameter". The context requires the layer list to be
-- 'Parameterized', and requires that its parameters can be appended with the
-- two parameters of the output projection (a @[numEmbeds, embedDim]@ weight
-- and a @[numEmbeds]@ bias).
instance
  ( layerTypes ~ HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device),
    projParams ~ '[Parameter device dtype '[numEmbeds, embedDim], Parameter device dtype '[numEmbeds]],
    Parameterized (HList layerTypes),
    HAppendFD
      (Parameters (HList layerTypes))
      projParams
      (Parameters (HList layerTypes) ++ projParams)
  ) =>
  Parameterized (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)

-- | Per-call settings shared by every transformer layer when the layer list is
-- folded over with its 'Apply'' instance.
data
  FoldLayers
    (batchSize :: Nat)
    (seqLen :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat)) = FoldLayers
  { -- | switch between training mode and evaluation mode (turns random dropout on and off)
    flTrain :: Bool,
    -- | optional attention mask, one @[seqLen, seqLen]@ matrix per batch element
    flAttentionMask :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen]),
    -- | optional key padding mask, one boolean per (batch element, key position)
    flKeyPaddingMask :: Maybe (Tensor device 'D.Bool '[batchSize, seqLen])
  }

-- | One step of the fold over the transformer layers.
--
-- The accumulator @mx@ is the IO action producing the previous layer's output.
-- It is run first, and its result @x@ is passed to 'transformerLayer' as
-- query, key and value at once (self-attention). The training flag and masks
-- come from the 'FoldLayers' record, and no relation tensors are supplied.
instance
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim],
    IsSuffixOf '[embedDim] '[batchSize, seqLen, embedDim],
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    dtype ~ SumDType dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  Apply'
    (FoldLayers batchSize seqLen dtype device)
    ( TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device,
      IO (Tensor device dtype '[batchSize, seqLen, embedDim])
    )
    (IO (Tensor device dtype '[batchSize, seqLen, embedDim]))
  where
  apply' FoldLayers {..} (layer, mx) = do
    x <- mx
    transformerLayer
      layer
      flTrain
      flAttentionMask
      flKeyPaddingMask
      Nothing -- keyRelations: none
      Nothing -- valueRelations: none
      x -- query
      x -- key
      x -- value

-- | Run a 'TransformerLM' on a batch of token ids.
--
-- The computation, in order:
--
-- 1. Look up token embeddings with 'tEmbedding'.
--    Add positional embeddings from 'tPosEmbedding' for positions
--    @0 .. seqLen - 1@, broadcast over the batch dimension.
-- 2. Apply 'tDropout'. The @train@ flag is passed straight to
--    'dropoutForward'.
-- 3. Build a causal, additive attention mask: @-Infinity@ strictly above
--    the main diagonal and @0@ elsewhere, broadcast to
--    @[batchSize, seqLen, seqLen]@.
-- 4. Build a key-padding mask that is 'True' wherever a token equals
--    @paddingIdx@.
-- 5. Thread the result through every layer in 'tLayers' with 'hfoldrM'.
--    Each step is handled by the 'FoldLayers' instance, which receives
--    @train@ and both masks.
-- 6. Project to @numEmbeds@ features with 'tProj'. No softmax is applied
--    here, so the result holds unnormalised scores.
--
-- NOTE(review): 'tPosEmbedding' is a fixed table of 2048 positions, but
-- nothing in this signature bounds @seqLen@. Sequences longer than 2048
-- presumably fail at runtime in the embedding lookup. Confirm.
transformerLM ::
  forall
    numAttnLayers
    numHeads
    ffnDim
    paddingIdx
    numEmbeds
    embedDim
    seqLen
    batchSize
    dtype
    device.
  ( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
    paddingIdx + 1 <= numEmbeds,
    1 <= seqLen,
    HFoldrM
      IO
      (FoldLayers batchSize seqLen dtype device)
      (Tensor device dtype '[batchSize, seqLen, embedDim])
      (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
      (Tensor device dtype '[batchSize, seqLen, embedDim]),
    BasicArithmeticDTypeIsValid device dtype,
    ComparisonDTypeIsValid device dtype,
    ComparisonDTypeIsValid device 'D.Int64,
    KnownDType dtype,
    KnownDevice device
  ) =>
  TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device ->
  Bool ->
  Tensor device 'D.Int64 '[batchSize, seqLen] ->
  IO (Tensor device dtype '[batchSize, seqLen, numEmbeds])
transformerLM TransformerLM {..} train xTokens = do
  let -- Token embeddings: one embedDim-vector per input token.
      x :: Tensor device dtype '[batchSize, seqLen, embedDim]
      x = embed tEmbedding xTokens
      -- Positional embeddings for positions 0 .. seqLen - 1.
      -- 'linspace' produces Float values, so they are cast to Int64 before
      -- being used as embedding indices.
      -- The [seqLen, embedDim] result is then expanded over the batch.
      positions :: Tensor device dtype '[batchSize, seqLen, embedDim]
      positions =
        expand @'[batchSize, seqLen, embedDim] True
          . embed tPosEmbedding
          . Torch.Typed.Tensor.toDType @D.Int64
          . linspace @seqLen (0 :: Int)
          $ natValI @(seqLen - 1)
  x' <- dropoutForward tDropout train (x `add` positions)
  let -- 'True' strictly above the main diagonal (triu with offset 1),
      -- i.e. at positions later than the query position.
      attentionMask :: Tensor device 'D.Bool '[1, seqLen, seqLen]
      attentionMask =
        unsqueeze @0
          . Torch.Typed.Tensor.toDType @D.Bool
          . triu 1
          $ ones @'[seqLen, seqLen] @D.Int8 @device
      -- Additive mask: -Infinity where attention is forbidden, 0 elsewhere.
      attentionMask' :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
      attentionMask' =
        pure . maskedFill attentionMask (-1 / 0 :: Double) $
          zeros @'[batchSize, seqLen, seqLen] @dtype @device
      -- 'True' at every position that holds the padding token.
      keyPaddingMask :: Maybe (Tensor device 'D.Bool '[batchSize, seqLen])
      keyPaddingMask =
        pure $
          xTokens
            ==. (fromInteger . natVal $ Proxy @paddingIdx :: Tensor device 'D.Int64 '[])
  y <- hfoldrM (FoldLayers train attentionMask' keyPaddingMask) x' tLayers
  pure (forward tProj y)

-- | Exposes a 'TransformerLM' through the generic 'HasForward' interface.
-- The constraints mirror those of 'transformerLM'.
instance
  ( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
    paddingIdx + 1 <= numEmbeds,
    1 <= seqLen,
    HFoldrM
      IO
      (FoldLayers batchSize seqLen dtype device)
      (Tensor device dtype '[batchSize, seqLen, embedDim])
      (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
      (Tensor device dtype '[batchSize, seqLen, embedDim]),
    BasicArithmeticDTypeIsValid device dtype,
    ComparisonDTypeIsValid device dtype,
    ComparisonDTypeIsValid device 'D.Int64,
    KnownDType dtype,
    KnownDevice device
  ) =>
  HasForward (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) (Tensor device 'D.Int64 '[batchSize, seqLen]) (Tensor device dtype '[batchSize, seqLen, numEmbeds])
  where
  -- Inference pass (train = False), exposed as a pure function.
  -- NOTE(review): 'unsafePerformIO' is only sound if 'transformerLM' is
  -- deterministic and has no observable side effects when train is False.
  -- This presumably holds because the flag disables dropout in
  -- 'dropoutForward' and inside each layer; confirm against
  -- "Torch.Typed.NN.Dropout".
  forward model input = unsafePerformIO $ transformerLM model False input

  -- Training pass (train = True). It stays in IO because dropout
  -- presumably samples random masks in this mode.
  forwardStoch model input = transformerLM model True input

-- | Sinusoidal position table of shape @[numEmbeds, embedDim]@.
--
-- Row @p@ (for @p = 0 .. numEmbeds - 1@) interleaves sines and cosines:
--
-- * column @2i@ holds @sin (p * f_i)@;
-- * column @2i + 1@ holds @cos (p * f_i)@;
-- * @f_i = exp (-(log 10000) * i / (embedDim / 2))@ for
--   @i = 0 .. embedDim / 2 - 1@.
--
-- The interleaving comes from stacking the sine and cosine tables along a
-- new trailing axis of size 2. That axis is then flattened into the
-- feature dimension with 'reshape'.
sinusoidal ::
  forall numEmbeds embedDim device.
  ( All KnownNat '[numEmbeds, embedDim],
    1 <= numEmbeds,
    1 <= Div embedDim 2,
    (Div embedDim 2 * 2) ~ embedDim,
    StandardFloatingPointDTypeValidation device 'D.Float,
    BasicArithmeticDTypeIsValid device 'D.Float,
    KnownDevice device
  ) =>
  Tensor device 'D.Float '[numEmbeds, embedDim]
sinusoidal =
  let -- Column vector of positions 0 .. numEmbeds - 1.
      positions :: Tensor device 'D.Float '[numEmbeds, 1]
      positions =
        unsqueeze @1
          . linspace @numEmbeds (0 :: Int)
          $ natValI @(numEmbeds - 1)
      -- Number of sin/cos pairs, as a Double for the exponent below.
      halfDim :: Double
      halfDim = fromInteger . natVal $ Proxy @(Div embedDim 2)
      -- Per-pair frequencies exp (-(log 10000) * i / halfDim).
      scalingFactors :: Tensor device 'D.Float '[Div embedDim 2]
      scalingFactors =
        exp
          . mulScalar (- log (10000 :: Double) / halfDim)
          . linspace @(Div embedDim 2) (0 :: Int)
          $ natValI @((Div embedDim 2) - 1)
      -- Broadcast product [numEmbeds, 1] * [Div embedDim 2].
      radians :: Tensor device 'D.Float '[numEmbeds, Div embedDim 2]
      radians = mul positions scalingFactors
      -- sin and cos side by side on a new trailing axis of size 2.
      weights :: Tensor device 'D.Float '[numEmbeds, Div embedDim 2, 2]
      weights = stack @2 (sin radians :. cos radians :. HNil)
   in reshape weights

instance
  ( paddingIdx <= numEmbeds,
    1 <= numEmbeds - paddingIdx,
    1 <= Div embedDim 2,
    (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
    (Div embedDim 2 * 2) ~ embedDim,
    All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim],
    HReplicate numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device),
    A.Randomizable
      (HList (HReplicateR numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device)))
      (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))),
    KnownDType dtype,
    RandDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device 'D.Float,
    BasicArithmeticDTypeIsValid device 'D.Float,
    KnownDevice device
  ) =>
  A.Randomizable
    (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
    (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
  where
  -- | Draw a fresh 'TransformerLM' from its spec.
  --
  -- The five constructor fields are sampled left to right in 'IO' and
  -- combined applicatively:
  --
  --   1. the token embedding, sampled from
  --      'LearnedEmbeddingWithRandomInitSpec' with padding index
  --      @'Just paddingIdx@;
  --   2. the positional embedding, a constant ('ConstEmbeddingSpec') table
  --      built from 'sinusoidal' (a @'D.Float@ tensor of shape
  --      @[2048, embedDim]@) and converted with 'toDType' to the model's
  --      @dtype@. The 2048 rows are fixed by the type of this field
  --      (@Embedding 'Nothing 2048 embedDim 'Constant dtype device@), so
  --      they cannot be changed from the spec;
  --   3. the dropout layer, sampled from 'lmDropoutSpec';
  --   4. @numAttnLayers@ transformer layers, sampled from an 'HList' built
  --      by replicating 'lmLayerSpec' with 'hreplicate'. Presumably each
  --      copy gets its own random draw; this depends on the
  --      'A.Randomizable' instance for 'HList' (TODO confirm);
  --   5. the output projection, a 'Linear' layer from @embedDim@ to
  --      @numEmbeds@ sampled from 'LinearSpec'.
  --
  -- The type variables @paddingIdx@ and @numAttnLayers@ used in the type
  -- applications below are the ones bound by the instance head
  -- (ScopedTypeVariables).
  sample TransformerLMSpec {..} =
    TransformerLM
      <$> A.sample (LearnedEmbeddingWithRandomInitSpec @('Just paddingIdx))
      <*> A.sample
        (ConstEmbeddingSpec @'Nothing (Torch.Typed.Tensor.toDType sinusoidal))
      <*> A.sample lmDropoutSpec
      <*> A.sample (hreplicate @numAttnLayers lmLayerSpec)
      <*> A.sample LinearSpec