{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-}
module Torch.Typed.NN.Transformer where
import Control.Monad
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import System.IO.Unsafe (unsafePerformIO)
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..))
import qualified Torch.NN as A
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional hiding (linear, log)
import Torch.Typed.NN.Dropout
import Torch.Typed.NN.Linear
import Torch.Typed.NN.Normalization
import Torch.Typed.NN.Sparse
import Torch.Typed.Parameter
import Torch.Typed.Tensor
import Prelude hiding (cos, exp, sin)
-- | Residual (skip) connection combinator.
--
-- Runs the monadic sub-layer @f@ on the input @x@, adds the sub-layer's
-- output back onto @x@ with 'add', and passes the sum to the continuation
-- @g@.
--
-- The dtype and shape of the sum are whatever 'add' computes for the two
-- operands (dtype promotion and shape broadcasting), so @f@ may change the
-- dtype or shape as long as the result can still be added to @x@.
--
-- The type is left to inference, as in the original definition.
residual f g x = f x >>= (\x' -> g (x `add` x'))
-- | Specification of a 'MultiheadAttention' layer.
--
-- The embedding dimensions, number of heads, dtype and device are carried
-- purely at the type level; the only run-time payload is the 'DropoutSpec'.
data
  MultiheadAttentionSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- The single argument is the dropout specification of the layer.
  MultiheadAttentionSpec ::
    DropoutSpec ->
    MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Eq)
-- | Multi-head attention layer: three input projections, one output
-- projection and a dropout layer.
--
-- @embedDim@ is the output dimension of every projection; @kEmbedDim@ and
-- @vEmbedDim@ are the input dimensions of the key and value projections.
-- @numHeads@ does not appear in any field type; it is only tracked at the
-- type level.
--
-- 'Parameterized' is derived via 'Generic' (@DeriveAnyClass@).
data
  MultiheadAttention
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  MultiheadAttention ::
    { -- | Linear map @embedDim -> embedDim@; by its name, the input
      -- projection for the query.
      mhaQInProj :: Linear embedDim embedDim dtype device,
      -- | Linear map @kEmbedDim -> embedDim@; by its name, the input
      -- projection for the key.
      mhaKInProj :: Linear kEmbedDim embedDim dtype device,
      -- | Linear map @vEmbedDim -> embedDim@; by its name, the input
      -- projection for the value.
      mhaVInProj :: Linear vEmbedDim embedDim dtype device,
      -- | Linear map @embedDim -> embedDim@; by its name, the output
      -- projection.
      mhaOutProj :: Linear embedDim embedDim dtype device,
      -- | Dropout layer; 'multiheadAttention' applies it to the attention
      -- weights.
      mhaDropout :: Dropout
    } ->
    MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device
  deriving (Show, Generic, Parameterized)
multiheadAttention ::
forall embedDim kEmbedDim vEmbedDim numHeads seqLen seqLen' batchSize headDim dtype device.
( 1 <= numHeads,
embedDim ~ (headDim * numHeads),
All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
KnownDType dtype,
StandardFloatingPointDTypeValidation device dtype,
MatMulDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype,
dtype ~ SumDType dtype,
SumDTypeIsValid device dtype,
KnownDevice device
) =>
MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device ->
Bool ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
Tensor device dtype '[batchSize, seqLen', embedDim] ->
Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
IO
( Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen]
)
multiheadAttention :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat) (seqLen :: Nat) (seqLen' :: Nat)
(batchSize :: Nat) (headDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
All
KnownNat
'[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
batchSize, headDim],
KnownDType dtype,
StandardFloatingPointDTypeValidation device dtype,
MatMulDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype,
SumDTypeIsValid device dtype, KnownDevice device) =>
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
(Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
multiheadAttention MultiheadAttention {Linear embedDim embedDim dtype device
Linear kEmbedDim embedDim dtype device
Linear vEmbedDim embedDim dtype device
Dropout
mhaQInProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
mhaKInProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear kEmbedDim embedDim dtype device
mhaVInProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear vEmbedDim embedDim dtype device
mhaOutProj :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Linear embedDim embedDim dtype device
mhaDropout :: forall (embedDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat).
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Dropout
mhaQInProj :: Linear embedDim embedDim dtype device
mhaKInProj :: Linear kEmbedDim embedDim dtype device
mhaVInProj :: Linear vEmbedDim embedDim dtype device
mhaOutProj :: Linear embedDim embedDim dtype device
mhaDropout :: Dropout
..} Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value = do
Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights <-
Dropout
-> Bool
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
mhaDropout Bool
train
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))))
-> (Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dim :: Nat) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownNat dim, DimOutOfBoundCheck shape dim, KnownDType dtype,
StandardFloatingPointDTypeValidation device dtype) =>
Tensor device dtype shape -> Tensor device dtype shape
softmax @3
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> (Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskKeyPaddings
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> (Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_maskAttention
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))))
-> Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> IO
(Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[]))))
forall a b. (a -> b) -> a -> b
$ Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
_attentionWeights
(Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
-> IO
(Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights, Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', seqLen]
averageOverHeads Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
weights)
where
-- Unnormalised attention scores, before masking and softmax.
--
-- The projected query is reshaped to
-- '[batchSize, numHeads, seqLen', headDim] and the projected key to
-- '[batchSize, numHeads, seqLen, headDim]; the score tensor is the batched
-- product of the query with the key transposed on its last two axes.
-- NOTE(review): the resulting shape is presumably
-- '[batchSize, numHeads, seqLen', seqLen] -- confirm against 'MatMul'.
--
-- The inferred type signature is intentionally omitted: it is a large
-- unreduced 'CheckMatMul' type-family application that GHC infers on its own.
_attentionWeights =
  let -- Square root of the per-head dimension, taken from the type level.
      scaling = Prelude.sqrt . fromIntegral $ natValI @headDim :: Double
      -- Query: input projection, then 'divScalar' by 'scaling', then split
      -- into heads by the local 'reshape''.
      q = reshape' . divScalar scaling . forward mhaQInProj $ query
      -- Key: input projection, then split into heads (no scaling applied).
      k = reshape' . forward mhaKInProj $ key
      -- Plain content-based scores: q times k with axes 2 and 3 swapped.
      weights = matmul q (transpose @2 @3 k)
      -- Optional relation-aware term. 'kr' has shape
      -- '[batchSize, seqLen', seqLen, headDim]; the query is moved to
      -- '[batchSize, seqLen', numHeads, headDim] so that the product lines up
      -- per query position, and the result is transposed back on axes 1 and 2
      -- before being added to the content-based scores.
      weights' = case keyRelations of
        Nothing -> weights
        Just kr ->
          weights
            `add` transpose @1 @2
              ((transpose @1 @2 q) `matmul` (transpose @2 @3 kr))
   in weights'
-- Apply the optional additive attention mask to the attention scores.
--
-- With no mask the scores are returned unchanged. Otherwise the mask, of
-- shape '[batchSize, seqLen', seqLen], gets a singleton axis inserted at
-- position 1 (giving '[batchSize, 1, seqLen', seqLen]) and is added to the
-- scores, so the same mask is applied along the scores' axis 1.
-- NOTE(review): axis 1 of the scores is presumably the head axis -- confirm.
_maskAttention attentionWeights =
  case attentionMask of
    Nothing -> attentionWeights
    Just am -> attentionWeights `add` unsqueeze @1 am
-- Apply the optional boolean key-padding mask to the attention scores.
--
-- With no mask the scores are returned unchanged. Otherwise the mask, of
-- shape '[batchSize, seqLen] and dtype 'Bool, is expanded with singleton axes
-- at positions 1 and 2 and passed to 'maskedFill' together with negative
-- infinity as the fill value.
-- NOTE(review): presumably positions where the mask is True are the ones
-- overwritten, so they vanish under the following softmax -- confirm against
-- 'maskedFill'.
_maskKeyPaddings attentionWeights =
  case keyPaddingMask of
    Nothing -> attentionWeights
    Just kpm ->
      let -- '[batchSize, seqLen] -> '[batchSize, 1, seqLen] -> one more
          -- singleton axis at position 2, so it can broadcast with the scores.
          keyPaddingMask' = unsqueeze @2 . unsqueeze @1 $ kpm
       in -- (-1 / 0 :: Double) evaluates to negative infinity.
          maskedFill keyPaddingMask' (-1 / 0 :: Double) attentionWeights
_attention :: Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
_attention Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights =
let v :: Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v = Tensor device dtype '[batchSize, seqLen, embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall (seqLen'' :: Nat).
KnownNat seqLen'' =>
Tensor device dtype '[batchSize, seqLen'', embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' (Tensor device dtype '[batchSize, seqLen, embedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim])
-> (Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, embedDim])
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Linear vEmbedDim embedDim dtype device
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear vEmbedDim embedDim dtype device
mhaVInProj (Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
forall a b. (a -> b) -> a -> b
$ Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value
attention :: Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention = forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 (Tensor
device
dtype
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
-> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1)))
-> Tensor
device
dtype
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
-> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
forall a b. (a -> b) -> a -> b
$ Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
-> Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
-> Tensor
device
dtype
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights Tensor device dtype '[batchSize, numHeads, seqLen, headDim]
v
attention' :: Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention' = case Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations of
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
Nothing -> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention
Just Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr -> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
-> Tensor
device
dtype
(CheckMatMul
(SetValue
(SetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
2))
2
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1))
'[batchSize, seqLen', seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(SetValue
(SetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
2))
2
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1))
'[])
'[headDim, seqLen, seqLen', batchSize]))
-> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` (Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
2))
2
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1))
-> Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
-> Tensor
device
dtype
(CheckMatMul
(SetValue
(SetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
2))
2
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1))
'[batchSize, seqLen', seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(SetValue
(SetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
2))
2
(GetValue
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
1))
'[])
'[headDim, seqLen, seqLen', batchSize]))
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ MatMul shape shape', MatMulDTypeIsValid device dtype) =>
Tensor device dtype shape
-> Tensor device dtype shape' -> Tensor device dtype shape''
matmul (forall (n :: Nat) (m :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat n, KnownNat m, shape' ~ Transpose shape n m) =>
Tensor device dtype shape -> Tensor device dtype shape'
transpose @1 @2 Tensor
device
dtype
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
attentionWeights) Tensor device dtype '[batchSize, seqLen', seqLen, headDim]
vr)
in Linear embedDim embedDim dtype device
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim embedDim dtype device
mhaOutProj (Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim])
-> (Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
-> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownShape shape', Numel shape ~ Numel shape') =>
Tensor device dtype shape -> Tensor device dtype shape'
reshape @'[batchSize, seqLen', embedDim] (Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
-> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall a b. (a -> b) -> a -> b
$ Tensor
device
dtype
(SetValue
(SetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
2))
2
(GetValue
(CheckMatMul
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[batchSize, numHeads, seqLen, headDim]
(ComputeMatMul
(ReverseImpl
(CheckMatMul
'[batchSize, numHeads, seqLen', headDim]
'[batchSize, numHeads, headDim, seqLen]
(ComputeMatMul
'[headDim, seqLen', numHeads, batchSize]
(ReverseImpl '[batchSize, numHeads, headDim, seqLen] '[])))
'[])
'[headDim, seqLen, numHeads, batchSize]))
1))
attention'
-- Average the attention weights over the heads: sum along dimension 1 and
-- divide by the type-level head count reflected to a value with 'natValI'.
-- NOTE(review): dimension 1 is taken to be the head axis because the inferred
-- input shape is the matmul of '[batchSize, numHeads, seqLen', headDim] with
-- '[batchSize, numHeads, headDim, seqLen] and the result is
-- '[batchSize, seqLen', seqLen] — confirm against the caller.
averageOverHeads =
  let numHeads' :: Int
      numHeads' = natValI @numHeads
   in divScalar numHeads' . sumDim @1
-- Split the embedding axis into (numHeads, headDim) and move the head axis in
-- front of the sequence axis:
--   '[batchSize, seqLen'', embedDim]
--     -> '[batchSize, seqLen'', numHeads, headDim]   (reshape)
--     -> '[batchSize, numHeads, seqLen'', headDim]   (transpose @1 @2)
-- The sequence length is universally quantified so the same helper can be
-- applied to tensors with different sequence lengths.
reshape' ::
  forall seqLen''.
  KnownNat seqLen'' =>
  Tensor device dtype '[batchSize, seqLen'', embedDim] ->
  Tensor device dtype '[batchSize, numHeads, seqLen'', headDim]
reshape' =
  transpose @1 @2
    . reshape @'[batchSize, seqLen'', numHeads, headDim]
-- | Randomly initialise a 'MultiheadAttention' layer from its specification.
--
-- The constructor receives, in order, four freshly sampled 'Linear' layers of
-- types @Linear embedDim embedDim@, @Linear kEmbedDim embedDim@,
-- @Linear vEmbedDim embedDim@ and @Linear embedDim embedDim@, followed by a
-- 'Dropout' sampled from the 'DropoutSpec' carried by the spec.  Each
-- 'LinearSpec' has no value-level fields here; its dimensions, dtype and
-- device are fixed by the type of the constructor argument it fills.
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device)
    (MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device)
  where
  sample (MultiheadAttentionSpec mhaDropoutSpec) =
    MultiheadAttention
      <$> A.sample LinearSpec
      <*> A.sample LinearSpec
      <*> A.sample LinearSpec
      <*> A.sample LinearSpec
      <*> A.sample mhaDropoutSpec
-- | Specification from which a 'TransformerMLP' block is built.
--
-- The type parameters fix the embedding dimension, the feed-forward
-- dimension, the dtype and the device; none of them appears in a value-level
-- field.
--
-- NOTE(review): which layer each field configures is inferred from the field
-- names only — confirm against the code that consumes this spec.
data
  TransformerMLPSpec
    (embedDim :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerMLPSpec ::
    forall embedDim ffnDim dtype device.
    { -- | Dropout specification; presumably for @dropout0@ (TODO confirm).
      dropout0Spec :: DropoutSpec,
      -- | Dropout specification; presumably for @dropout1@ (TODO confirm).
      dropout1Spec :: DropoutSpec,
      -- | Presumably the epsilon of the layer normalisation (TODO confirm).
      epsSpec :: Double
    } ->
    TransformerMLPSpec embedDim ffnDim dtype device
  deriving (Show, Eq)
data
TransformerMLP
(embedDim :: Nat)
(ffnDim :: Nat)
(dtype :: D.DType)
(device :: (D.DeviceType, Nat))
where
TransformerMLP ::
forall embedDim ffnDim dtype device.
{
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> Linear embedDim ffnDim dtype device
linear0 :: Linear embedDim ffnDim dtype device,
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> Linear ffnDim embedDim dtype device
linear1 :: Linear ffnDim embedDim dtype device,
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device -> Dropout
dropout0 :: Dropout,
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device -> Dropout
dropout1 :: Dropout,
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> LayerNorm '[embedDim] dtype device
ln :: LayerNorm '[embedDim] dtype device
} ->
TransformerMLP embedDim ffnDim dtype device
deriving (Int -> TransformerMLP embedDim ffnDim dtype device -> ShowS
[TransformerMLP embedDim ffnDim dtype device] -> ShowS
TransformerMLP embedDim ffnDim dtype device -> String
(Int -> TransformerMLP embedDim ffnDim dtype device -> ShowS)
-> (TransformerMLP embedDim ffnDim dtype device -> String)
-> ([TransformerMLP embedDim ffnDim dtype device] -> ShowS)
-> Show (TransformerMLP embedDim ffnDim dtype device)
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
Int -> TransformerMLP embedDim ffnDim dtype device -> ShowS
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
[TransformerMLP embedDim ffnDim dtype device] -> ShowS
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device -> String
forall a.
(Int -> a -> ShowS) -> (a -> String) -> ([a] -> ShowS) -> Show a
$cshowsPrec :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
Int -> TransformerMLP embedDim ffnDim dtype device -> ShowS
showsPrec :: Int -> TransformerMLP embedDim ffnDim dtype device -> ShowS
$cshow :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device -> String
show :: TransformerMLP embedDim ffnDim dtype device -> String
$cshowList :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
[TransformerMLP embedDim ffnDim dtype device] -> ShowS
showList :: [TransformerMLP embedDim ffnDim dtype device] -> ShowS
Show, (forall x.
TransformerMLP embedDim ffnDim dtype device
-> Rep (TransformerMLP embedDim ffnDim dtype device) x)
-> (forall x.
Rep (TransformerMLP embedDim ffnDim dtype device) x
-> TransformerMLP embedDim ffnDim dtype device)
-> Generic (TransformerMLP embedDim ffnDim dtype device)
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) x.
Rep (TransformerMLP embedDim ffnDim dtype device) x
-> TransformerMLP embedDim ffnDim dtype device
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) x.
TransformerMLP embedDim ffnDim dtype device
-> Rep (TransformerMLP embedDim ffnDim dtype device) x
forall x.
Rep (TransformerMLP embedDim ffnDim dtype device) x
-> TransformerMLP embedDim ffnDim dtype device
forall x.
TransformerMLP embedDim ffnDim dtype device
-> Rep (TransformerMLP embedDim ffnDim dtype device) x
forall a.
(forall x. a -> Rep a x) -> (forall x. Rep a x -> a) -> Generic a
$cfrom :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) x.
TransformerMLP embedDim ffnDim dtype device
-> Rep (TransformerMLP embedDim ffnDim dtype device) x
from :: forall x.
TransformerMLP embedDim ffnDim dtype device
-> Rep (TransformerMLP embedDim ffnDim dtype device) x
$cto :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)) x.
Rep (TransformerMLP embedDim ffnDim dtype device) x
-> TransformerMLP embedDim ffnDim dtype device
to :: forall x.
Rep (TransformerMLP embedDim ffnDim dtype device) x
-> TransformerMLP embedDim ffnDim dtype device
Generic, TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
-> TransformerMLP embedDim ffnDim dtype device
(TransformerMLP embedDim ffnDim dtype device
-> HList
(Parameters (TransformerMLP embedDim ffnDim dtype device)))
-> (TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
-> TransformerMLP embedDim ffnDim dtype device)
-> Parameterized (TransformerMLP embedDim ffnDim dtype device)
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
-> TransformerMLP embedDim ffnDim dtype device
forall f.
(f -> HList (Parameters f))
-> (f -> HList (Parameters f) -> f) -> Parameterized f
$cflattenParameters :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
flattenParameters :: TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
$creplaceParameters :: forall (embedDim :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
-> TransformerMLP embedDim ffnDim dtype device
replaceParameters :: TransformerMLP embedDim ffnDim dtype device
-> HList (Parameters (TransformerMLP embedDim ffnDim dtype device))
-> TransformerMLP embedDim ffnDim dtype device
Parameterized)
-- | Run the feed-forward (MLP) block of a transformer layer.
--
-- The input is sent through @linear0@, 'relu', @dropout0@, @linear1@ and
-- @dropout1@, in that order. This pipeline is handed to 'residual' together
-- with the layer norm @ln@ as the follow-up step, so @ln@ is applied to
-- whatever 'residual' builds from the original input and the pipeline output.
--
-- NOTE(review): 'residual' is expected to add the input to the pipeline
-- output before invoking the follow-up step -- confirm against its definition.
--
-- The 'Bool' argument is forwarded unchanged to both 'dropoutForward' calls.
-- NOTE(review): presumably 'True' selects training mode (dropout active) --
-- confirm against 'dropoutForward'.
--
-- Input and output share the shape @'[seqLen, batchSize, embedDim]@; only the
-- intermediate activations have @ffnDim@ as their last dimension.
transformerMLP ::
  forall embedDim ffnDim seqLen batchSize dtype device.
  ( BasicArithmeticDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device dtype,
    KnownNat embedDim,
    IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]
  ) =>
  TransformerMLP embedDim ffnDim dtype device ->
  Bool ->
  Tensor device dtype '[seqLen, batchSize, embedDim] ->
  IO (Tensor device dtype '[seqLen, batchSize, embedDim])
transformerMLP TransformerMLP {..} train input =
  residual f (pure . forward ln) input
  where
    -- Feed-forward pipeline: map embedDim -> ffnDim, apply the non-linearity
    -- and the first dropout, then map ffnDim -> embedDim and apply the
    -- second dropout. Dropout runs in IO, hence the monadic bind.
    f x =
      dropoutForward dropout0 train (relu (forward linear0 x))
        >>= dropoutForward dropout1 train . forward linear1
-- | Build a 'TransformerMLP' from its specification.
--
-- Both linear layers are sampled from a bare 'LinearSpec' (their sizes are
-- fixed by the type indices), the two dropout layers are sampled from
-- @dropout0Spec@ and @dropout1Spec@, and the layer norm is sampled from a
-- 'LayerNormSpec' carrying @epsSpec@.
instance
  ( All KnownNat '[embedDim, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerMLPSpec embedDim ffnDim dtype device)
    (TransformerMLP embedDim ffnDim dtype device)
  where
    -- The argument order below must match the field order of the
    -- 'TransformerMLP' constructor: linear0, linear1, dropout0, dropout1, ln.
    sample TransformerMLPSpec {..} =
      TransformerMLP
        <$> A.sample LinearSpec
        <*> A.sample LinearSpec
        <*> A.sample dropout0Spec
        <*> A.sample dropout1Spec
        <*> A.sample (LayerNormSpec epsSpec)
-- | Specification of a single transformer layer.
--
-- The type indices fix the embedding sizes of query (@embedDim@), key
-- (@kEmbedDim@) and value (@vEmbedDim@), the number of attention heads, the
-- hidden size of the feed-forward block (@ffnDim@), and the tensor
-- @dtype@ / @device@.
data
  TransformerLayerSpec
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
    TransformerLayerSpec ::
      forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
      { -- specification of the multi-head attention sub-layer
        mhaSpec :: MultiheadAttentionSpec embedDim kEmbedDim vEmbedDim numHeads dtype device,
        -- specification of the layer's own dropout
        -- NOTE(review): presumably applied to the attention output -- confirm
        attnDropoutSpec :: DropoutSpec,
        -- NOTE(review): presumably the epsilon of the layer's layer norm,
        -- mirroring epsSpec in TransformerMLPSpec -- confirm at the use site
        epsSpec' :: Double,
        -- specification of the feed-forward (MLP) sub-layer
        mlpSpec :: TransformerMLPSpec embedDim ffnDim dtype device
      } ->
      TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Eq)
-- | A single transformer layer: multi-head attention, a dropout, a layer
-- norm over @'[embedDim]@, and a feed-forward ('TransformerMLP') block.
--
-- The type indices have the same meaning as in 'TransformerLayerSpec'.
-- 'Parameterized' is derived through the 'Generic' representation, so the
-- layer's parameters are collected from its fields.
data
  TransformerLayer
    (embedDim :: Nat)
    (kEmbedDim :: Nat)
    (vEmbedDim :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
    TransformerLayer ::
      forall embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device.
      { -- multi-head attention sub-layer
        transformerLayer_mha :: MultiheadAttention embedDim kEmbedDim vEmbedDim numHeads dtype device,
        -- dropout owned by the layer itself
        -- NOTE(review): presumably applied to the attention output -- confirm
        transformerLayer_attnDropout :: Dropout,
        -- layer norm over the embedding dimension
        transformerLayer_ln :: LayerNorm '[embedDim] dtype device,
        -- feed-forward (MLP) sub-layer
        transformerLayer_mlp :: TransformerMLP embedDim ffnDim dtype device
      } ->
      TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
  deriving (Show, Generic, Parameterized)
transformerLayer ::
forall (numHeads :: Nat) (ffnDim :: Nat) (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat) (headDim :: Nat) (seqLen :: Nat) (seqLen' :: Nat) (batchSize :: Nat) dtype device.
( 1 <= numHeads,
embedDim ~ (headDim * numHeads),
All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen', batchSize, headDim],
IsSuffixOf '[embedDim] '[batchSize, seqLen', embedDim],
KnownDType dtype,
dtype ~ SumDType dtype,
StandardFloatingPointDTypeValidation device dtype,
MatMulDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype,
SumDTypeIsValid device dtype,
KnownDevice device
) =>
TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device ->
Bool ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen]) ->
Maybe (Tensor device 'D.Bool '[batchSize, seqLen]) ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim]) ->
Tensor device dtype '[batchSize, seqLen', embedDim] ->
Tensor device dtype '[batchSize, seqLen, kEmbedDim] ->
Tensor device dtype '[batchSize, seqLen, vEmbedDim] ->
IO (Tensor device dtype '[batchSize, seqLen', embedDim])
transformerLayer :: forall (numHeads :: Nat) (ffnDim :: Nat) (embedDim :: Nat)
(kEmbedDim :: Nat) (vEmbedDim :: Nat) (headDim :: Nat)
(seqLen :: Nat) (seqLen' :: Nat) (batchSize :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
All
KnownNat
'[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
batchSize, headDim],
IsSuffixOf '[embedDim] '[batchSize, seqLen', embedDim],
KnownDType dtype, dtype ~ SumDType dtype,
StandardFloatingPointDTypeValidation device dtype,
MatMulDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype,
SumDTypeIsValid device dtype, KnownDevice device) =>
TransformerLayer
embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
transformerLayer TransformerLayer {LayerNorm '[embedDim] dtype device
Dropout
TransformerMLP embedDim ffnDim dtype device
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_mha :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerLayer
embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_attnDropout :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerLayer
embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> Dropout
transformerLayer_ln :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerLayer
embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> LayerNorm '[embedDim] dtype device
transformerLayer_mlp :: forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat) (ffnDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
TransformerLayer
embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device
-> TransformerMLP embedDim ffnDim dtype device
transformerLayer_mha :: MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_attnDropout :: Dropout
transformerLayer_ln :: LayerNorm '[embedDim] dtype device
transformerLayer_mlp :: TransformerMLP embedDim ffnDim dtype device
..} Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value =
let f :: Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
f Tensor device dtype '[batchSize, seqLen', embedDim]
query' =
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
(Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
forall (embedDim :: Nat) (kEmbedDim :: Nat) (vEmbedDim :: Nat)
(numHeads :: Nat) (seqLen :: Nat) (seqLen' :: Nat)
(batchSize :: Nat) (headDim :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(1 <= numHeads, embedDim ~ (headDim * numHeads),
All
KnownNat
'[embedDim, kEmbedDim, vEmbedDim, numHeads, seqLen, seqLen',
batchSize, headDim],
KnownDType dtype,
StandardFloatingPointDTypeValidation device dtype,
MatMulDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype, dtype ~ SumDType dtype,
SumDTypeIsValid device dtype, KnownDevice device) =>
MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
-> Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Maybe
(Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen, kEmbedDim]
-> Tensor device dtype '[batchSize, seqLen, vEmbedDim]
-> IO
(Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
multiheadAttention MultiheadAttention
embedDim kEmbedDim vEmbedDim numHeads dtype device
transformerLayer_mha Bool
train Maybe (Tensor device dtype '[batchSize, seqLen', seqLen])
attentionMask Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
keyRelations Maybe (Tensor device dtype '[batchSize, seqLen', seqLen, headDim])
valueRelations Tensor device dtype '[batchSize, seqLen', embedDim]
query' Tensor device dtype '[batchSize, seqLen, kEmbedDim]
key Tensor device dtype '[batchSize, seqLen, vEmbedDim]
value
IO
(Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
-> ((Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: Type -> Type) a b. Monad m => m a -> (a -> m b) -> m b
>>= Dropout
-> Bool
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
transformerLayer_attnDropout Bool
train (Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> ((Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
-> Tensor device dtype '[batchSize, seqLen', embedDim])
-> (Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Tensor device dtype '[batchSize, seqLen', embedDim],
Tensor device dtype '[batchSize, seqLen', seqLen])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall a b. (a, b) -> a
fst
in (Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> (Tensor
device
(DTypePromotionImpl dtype dtype (CmpDType dtype dtype))
(CheckBroadcast
'[batchSize, seqLen', embedDim]
'[batchSize, seqLen', embedDim]
(ComputeBroadcast
(ReverseImpl '[batchSize, seqLen', embedDim] '[])
(ReverseImpl '[batchSize, seqLen', embedDim] '[])))
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall {device :: (DeviceType, Nat)} {dtype :: DType}
{dtype' :: DType} {m :: Type -> Type} {shape :: [Nat]}
{shape' :: [Nat]} {b}.
(BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid
device (DTypePromotionImpl dtype dtype' (CmpDType dtype dtype')),
Monad m) =>
(Tensor device dtype shape -> m (Tensor device dtype' shape'))
-> (Tensor
device
(DTypePromotionImpl dtype dtype' (CmpDType dtype dtype'))
(CheckBroadcast
shape
shape'
(ComputeBroadcast
(ReverseImpl shape '[]) (ReverseImpl shape' '[])))
-> m b)
-> Tensor device dtype shape
-> m b
residual Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
f (Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall a. a -> IO a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> (Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim])
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LayerNorm '[embedDim] dtype device
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> Tensor device dtype '[batchSize, seqLen', embedDim]
forall f a b. HasForward f a b => f -> a -> b
forward LayerNorm '[embedDim] dtype device
transformerLayer_ln) Tensor device dtype '[batchSize, seqLen', embedDim]
query IO (Tensor device dtype '[batchSize, seqLen', embedDim])
-> (Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim]))
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: Type -> Type) a b. Monad m => m a -> (a -> m b) -> m b
>>= TransformerMLP embedDim ffnDim dtype device
-> Bool
-> Tensor device dtype '[batchSize, seqLen', embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen', embedDim])
forall (embedDim :: Nat) (ffnDim :: Nat) (seqLen :: Nat)
(batchSize :: Nat) (dtype :: DType) (device :: (DeviceType, Nat)).
(BasicArithmeticDTypeIsValid device dtype,
StandardFloatingPointDTypeValidation device dtype,
KnownNat embedDim,
IsSuffixOf '[embedDim] '[seqLen, batchSize, embedDim]) =>
TransformerMLP embedDim ffnDim dtype device
-> Bool
-> Tensor device dtype '[seqLen, batchSize, embedDim]
-> IO (Tensor device dtype '[seqLen, batchSize, embedDim])
transformerMLP TransformerMLP embedDim ffnDim dtype device
transformerLayer_mlp Bool
train
-- | Sample a 'TransformerLayer' from its 'TransformerLayerSpec'.
--
-- Each component is sampled from the matching field of the spec, in field
-- order: the multi-head attention block, the attention dropout, the layer
-- normalisation (built from the spec's epsilon value) and the MLP.
instance
  ( All KnownNat '[embedDim, kEmbedDim, vEmbedDim, numHeads, ffnDim],
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  A.Randomizable
    (TransformerLayerSpec embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
    (TransformerLayer embedDim kEmbedDim vEmbedDim numHeads ffnDim dtype device)
  where
  -- The record wildcard brings mhaSpec, attnDropoutSpec, epsSpec' and mlpSpec
  -- into scope.
  sample TransformerLayerSpec {..} =
    TransformerLayer
      <$> A.sample mhaSpec
      <*> A.sample attnDropoutSpec
      -- The spec only stores the epsilon; wrap it into a LayerNormSpec here.
      <*> A.sample (LayerNormSpec epsSpec')
      <*> A.sample mlpSpec
-- | Specification from which a 'TransformerLM' is sampled.
--
-- The type parameters fix the number of attention layers, the number of
-- heads, the feed-forward dimension, the padding index, the number of
-- embeddings (vocabulary size), the embedding dimension, the tensor dtype
-- and the device.
data
  TransformerLMSpec
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLMSpec ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- | Dropout specification for the language model.
      lmDropoutSpec :: DropoutSpec,
      -- | Specification of a transformer layer whose query, key and value
      -- embedding dimensions are all @embedDim@.
      lmLayerSpec :: TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device
    } ->
    TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Show, Eq)
-- | Transformer language model.
--
-- Holds the token and positional embeddings, a dropout layer, a
-- heterogeneous list of @numAttnLayers@ transformer layers and a final
-- linear layer mapping @embedDim@ to @numEmbeds@.
data
  TransformerLM
    (numAttnLayers :: Nat)
    (numHeads :: Nat)
    (ffnDim :: Nat)
    (paddingIdx :: Nat)
    (numEmbeds :: Nat)
    (embedDim :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  TransformerLM ::
    forall numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device.
    { -- | Learned token embedding with padding index @paddingIdx@.
      tEmbedding :: Embedding ('Just paddingIdx) numEmbeds embedDim 'Learned dtype device,
      -- | Constant positional embedding with 2048 entries and no padding index.
      tPosEmbedding :: Embedding 'Nothing 2048 embedDim 'Constant dtype device,
      -- | Dropout layer of the model.
      tDropout :: Dropout,
      -- | The @numAttnLayers@ transformer layers, all sharing one layer type.
      tLayers :: HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device)),
      -- | Linear layer from @embedDim@ to @numEmbeds@.
      tProj :: Linear embedDim numEmbeds dtype device
    } ->
    TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device
  deriving (Generic)
-- | Standalone-derived 'Show' instance for 'TransformerLM'.
--
-- The only explicit requirement is that the heterogeneous list of
-- transformer layers can be shown.
deriving instance
  ( Show
      ( HList
          ( HReplicateR
              numAttnLayers
              ( TransformerLayer
                  embedDim
                  embedDim
                  embedDim
                  numHeads
                  ffnDim
                  dtype
                  device
              )
          )
      )
  ) =>
  Show (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
-- | 'Parameterized' instance for 'TransformerLM'.
--
-- No methods are defined, so the class defaults apply (presumably derived
-- from the 'Generic' instance of 'TransformerLM' -- confirm against
-- "Torch.Typed.Parameter"). The constraints require that the list of layers
-- is itself 'Parameterized' and that its parameters can be appended with a
-- @[numEmbeds, embedDim]@ and a @[numEmbeds]@ parameter.
instance
  ( layers
      ~ ( HReplicateR
            numAttnLayers
            ( TransformerLayer
                embedDim
                embedDim
                embedDim
                numHeads
                ffnDim
                dtype
                device
            )
        ),
    Parameterized
      ( HList
          layers
      ),
    HAppendFD
      (Parameters (HList layers))
      '[ Parameter device dtype '[numEmbeds, embedDim],
         Parameter device dtype '[numEmbeds]
       ]
      ( Parameters (HList layers)
          ++ '[ Parameter device dtype '[numEmbeds, embedDim],
                Parameter device dtype '[numEmbeds]
              ]
      )
  ) =>
  Parameterized (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
-- | Arguments shared by every layer when folding over the transformer
-- layers via the 'Apply'' instance for 'FoldLayers'.
data
  FoldLayers
    (batchSize :: Nat)
    (seqLen :: Nat)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat)) = FoldLayers
  { -- | Training flag handed to each transformer layer.
    flTrain :: Bool,
    -- | Optional attention mask handed to each transformer layer.
    flAttentionMask :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen]),
    -- | Optional boolean key padding mask handed to each transformer layer.
    flKeyPaddingMask :: Maybe (Tensor device 'D.Bool '[batchSize, seqLen])
  }
-- | One step of the fold over the transformer layers.
--
-- Runs the accumulated 'IO' action to obtain the current tensor, then feeds
-- that tensor to the given layer as query, key and value, together with the
-- training flag and masks stored in 'FoldLayers'.
instance
  ( 1 <= numHeads,
    embedDim ~ (headDim * numHeads),
    All KnownNat '[embedDim, numHeads, seqLen, batchSize, headDim],
    IsSuffixOf '[embedDim] '[batchSize, seqLen, embedDim],
    KnownDType dtype,
    StandardFloatingPointDTypeValidation device dtype,
    MatMulDTypeIsValid device dtype,
    BasicArithmeticDTypeIsValid device dtype,
    dtype ~ SumDType dtype,
    SumDTypeIsValid device dtype,
    KnownDevice device
  ) =>
  Apply'
    (FoldLayers batchSize seqLen dtype device)
    ( TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device,
      IO (Tensor device dtype '[batchSize, seqLen, embedDim])
    )
    (IO (Tensor device dtype '[batchSize, seqLen, embedDim]))
  where
  apply' FoldLayers {..} (layer, mx) =
    mx >>= \x ->
      transformerLayer
        layer
        flTrain
        flAttentionMask
        flKeyPaddingMask
        Nothing -- no key relations
        Nothing -- no value relations
        x -- query
        x -- key
        x -- value
transformerLM ::
forall
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
seqLen
batchSize
dtype
device.
( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
paddingIdx + 1 <= numEmbeds,
1 <= seqLen,
HFoldrM
IO
(FoldLayers batchSize seqLen dtype device)
(Tensor device dtype '[batchSize, seqLen, embedDim])
(HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
(Tensor device dtype '[batchSize, seqLen, embedDim]),
BasicArithmeticDTypeIsValid device dtype,
ComparisonDTypeIsValid device dtype,
ComparisonDTypeIsValid device 'D.Int64,
KnownDType dtype,
KnownDevice device
) =>
TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device ->
Bool ->
Tensor device 'D.Int64 '[batchSize, seqLen] ->
IO (Tensor device dtype '[batchSize, seqLen, numEmbeds])
transformerLM :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(seqLen :: Nat) (batchSize :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
(All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
(paddingIdx + 1) <= numEmbeds, 1 <= seqLen,
HFoldrM
IO
(FoldLayers batchSize seqLen dtype device)
(Tensor device dtype '[batchSize, seqLen, embedDim])
(HReplicateR
numAttnLayers
(TransformerLayer
embedDim embedDim embedDim numHeads ffnDim dtype device))
(Tensor device dtype '[batchSize, seqLen, embedDim]),
BasicArithmeticDTypeIsValid device dtype,
ComparisonDTypeIsValid device dtype,
ComparisonDTypeIsValid device 'Int64, KnownDType dtype,
KnownDevice device) =>
TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> Bool
-> Tensor device 'Int64 '[batchSize, seqLen]
-> IO (Tensor device dtype '[batchSize, seqLen, numEmbeds])
transformerLM TransformerLM {HList
(HReplicateR
numAttnLayers
(TransformerLayer
embedDim embedDim embedDim numHeads ffnDim dtype device))
Embedding 'Nothing 2048 embedDim 'Constant dtype device
Embedding
('Just paddingIdx) numEmbeds embedDim 'Learned dtype device
Linear embedDim numEmbeds dtype device
Dropout
tEmbedding :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> Embedding
('Just paddingIdx) numEmbeds embedDim 'Learned dtype device
tPosEmbedding :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> Embedding 'Nothing 2048 embedDim 'Constant dtype device
tDropout :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> Dropout
tLayers :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> HList
(HReplicateR
numAttnLayers
(TransformerLayer
embedDim embedDim embedDim numHeads ffnDim dtype device))
tProj :: forall (numAttnLayers :: Nat) (numHeads :: Nat) (ffnDim :: Nat)
(paddingIdx :: Nat) (numEmbeds :: Nat) (embedDim :: Nat)
(dtype :: DType) (device :: (DeviceType, Nat)).
TransformerLM
numAttnLayers
numHeads
ffnDim
paddingIdx
numEmbeds
embedDim
dtype
device
-> Linear embedDim numEmbeds dtype device
tEmbedding :: Embedding
('Just paddingIdx) numEmbeds embedDim 'Learned dtype device
tPosEmbedding :: Embedding 'Nothing 2048 embedDim 'Constant dtype device
tDropout :: Dropout
tLayers :: HList
(HReplicateR
numAttnLayers
(TransformerLayer
embedDim embedDim embedDim numHeads ffnDim dtype device))
tProj :: Linear embedDim numEmbeds dtype device
..} Bool
train Tensor device 'Int64 '[batchSize, seqLen]
xTokens = do
let x :: Tensor device dtype '[batchSize, seqLen, embedDim]
x = Embedding
('Just paddingIdx) numEmbeds embedDim 'Learned dtype device
-> Tensor device 'Int64 '[batchSize, seqLen]
-> Tensor device dtype '[batchSize, seqLen, embedDim]
forall (paddingIdx :: Maybe Nat) (shape :: [Nat])
(numEmbeds :: Nat) (embedSize :: Nat)
(embeddingType :: EmbeddingType) (dtype :: DType)
(device :: (DeviceType, Nat)) (shape' :: [Nat]).
(KnownMaybeNat paddingIdx, PaddingIdxCheck paddingIdx numEmbeds,
shape' ~ Reverse (embedSize : Reverse shape)) =>
Embedding paddingIdx numEmbeds embedSize embeddingType dtype device
-> Tensor device 'Int64 shape -> Tensor device dtype shape'
embed Embedding
('Just paddingIdx) numEmbeds embedDim 'Learned dtype device
tEmbedding Tensor device 'Int64 '[batchSize, seqLen]
xTokens
positions :: Tensor device dtype '[batchSize, seqLen, embedDim]
positions =
forall (shape' :: [Nat]) (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(KnownShape shape', shape' ~ Broadcast shape shape') =>
Bool -> Tensor device dtype shape -> Tensor device dtype shape'
expand @'[batchSize, seqLen, embedDim] Bool
True
(Tensor
device dtype (ReverseImpl (ReverseImpl '[seqLen] '[]) '[embedDim])
-> Tensor device dtype '[batchSize, seqLen, embedDim])
-> (Int
-> Tensor
device dtype (ReverseImpl (ReverseImpl '[seqLen] '[]) '[embedDim]))
-> Int
-> Tensor device dtype '[batchSize, seqLen, embedDim]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Embedding 'Nothing 2048 embedDim 'Constant dtype device
-> Tensor device 'Int64 '[seqLen]
-> Tensor
device dtype (ReverseImpl (ReverseImpl '[seqLen] '[]) '[embedDim])
forall (paddingIdx :: Maybe Nat) (shape :: [Nat])
(numEmbeds :: Nat) (embedSize :: Nat)
(embeddingType :: EmbeddingType) (dtype :: DType)
(device :: (DeviceType, Nat)) (shape' :: [Nat]).
(KnownMaybeNat paddingIdx, PaddingIdxCheck paddingIdx numEmbeds,
shape' ~ Reverse (embedSize : Reverse shape)) =>
Embedding paddingIdx numEmbeds embedSize embeddingType dtype device
-> Tensor device 'Int64 shape -> Tensor device dtype shape'
embed Embedding 'Nothing 2048 embedDim 'Constant dtype device
tPosEmbedding
(Tensor device 'Int64 '[seqLen]
-> Tensor
device dtype (ReverseImpl (ReverseImpl '[seqLen] '[]) '[embedDim]))
-> (Int -> Tensor device 'Int64 '[seqLen])
-> Int
-> Tensor
device dtype (ReverseImpl (ReverseImpl '[seqLen] '[]) '[embedDim])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dtype' :: DType) (dtype :: DType)
(device :: (DeviceType, Nat)) (shape :: [Nat]) t t'.
(KnownDType dtype', IsUnnamed t device dtype shape, Unnamed t',
t' ~ ReplaceDType'' t dtype') =>
t -> t'
Torch.Typed.Tensor.toDType @D.Int64
(Tensor device 'Float '[seqLen] -> Tensor device 'Int64 '[seqLen])
-> (Int -> Tensor device 'Float '[seqLen])
-> Int
-> Tensor device 'Int64 '[seqLen]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (steps :: Nat) (device :: (DeviceType, Nat)) start end.
(Scalar start, Scalar end, KnownNat steps,
TensorOptions '[steps] 'Float device) =>
start -> end -> Tensor device 'Float '[steps]
linspace @seqLen (Int
0 :: Int)
(Int -> Tensor device dtype '[batchSize, seqLen, embedDim])
-> Int -> Tensor device dtype '[batchSize, seqLen, embedDim]
forall a b. (a -> b) -> a -> b
$ forall (n :: Nat). KnownNat n => Int
natValI @(seqLen - 1)
Tensor device dtype '[batchSize, seqLen, embedDim]
x' <- Dropout
-> Bool
-> Tensor device dtype '[batchSize, seqLen, embedDim]
-> IO (Tensor device dtype '[batchSize, seqLen, embedDim])
forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
Dropout
-> Bool
-> Tensor device dtype shape
-> IO (Tensor device dtype shape)
dropoutForward Dropout
tDropout Bool
train (Tensor device dtype '[batchSize, seqLen, embedDim]
x Tensor device dtype '[batchSize, seqLen, embedDim]
-> Tensor device dtype '[batchSize, seqLen, embedDim]
-> Tensor device dtype '[batchSize, seqLen, embedDim]
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (dtype'' :: DType)
(device :: (DeviceType, Nat)).
(dtype'' ~ DTypePromotion dtype dtype',
shape'' ~ Broadcast shape shape',
BasicArithmeticDTypeIsValid device dtype,
BasicArithmeticDTypeIsValid device dtype',
BasicArithmeticDTypeIsValid device dtype'') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device dtype'' shape''
`add` Tensor device dtype '[batchSize, seqLen, embedDim]
positions)
let attentionMask :: Tensor device 'Bool '[1, seqLen, seqLen]
attentionMask =
forall (dim :: Nat) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(KnownNat dim, shape' ~ Unsqueeze shape dim) =>
Tensor device dtype shape -> Tensor device dtype shape'
unsqueeze @0
(Tensor device 'Bool '[seqLen, seqLen]
-> Tensor device 'Bool '[1, seqLen, seqLen])
-> (Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Bool '[seqLen, seqLen])
-> Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Bool '[1, seqLen, seqLen]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (dtype' :: DType) (dtype :: DType)
(device :: (DeviceType, Nat)) (shape :: [Nat]) t t'.
(KnownDType dtype', IsUnnamed t device dtype shape, Unnamed t',
t' ~ ReplaceDType'' t dtype') =>
t -> t'
Torch.Typed.Tensor.toDType @D.Bool
(Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Bool '[seqLen, seqLen])
-> (Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Int8 '[seqLen, seqLen])
-> Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Bool '[seqLen, seqLen]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int
-> Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Int8 '[seqLen, seqLen]
forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
(shape ~ MatrixOrMatrixBatch shape) =>
Int -> Tensor device dtype shape -> Tensor device dtype shape
triu Int
1
(Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Bool '[1, seqLen, seqLen])
-> Tensor device 'Int8 '[seqLen, seqLen]
-> Tensor device 'Bool '[1, seqLen, seqLen]
forall a b. (a -> b) -> a -> b
$ forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
TensorOptions shape dtype device =>
Tensor device dtype shape
ones @'[seqLen, seqLen] @D.Int8 @device
attentionMask' :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
attentionMask' =
Tensor device dtype '[batchSize, seqLen, seqLen]
-> Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
forall a. a -> Maybe a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor device dtype '[batchSize, seqLen, seqLen]
-> Maybe (Tensor device dtype '[batchSize, seqLen, seqLen]))
-> (Tensor device dtype '[batchSize, seqLen, seqLen]
-> Tensor device dtype '[batchSize, seqLen, seqLen])
-> Tensor device dtype '[batchSize, seqLen, seqLen]
-> Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Tensor device 'Bool '[1, seqLen, seqLen]
-> Double
-> Tensor device dtype '[batchSize, seqLen, seqLen]
-> Tensor device dtype '[batchSize, seqLen, seqLen]
forall a (shape :: [Nat]) (shape' :: [Nat]) (shape'' :: [Nat])
(dtype :: DType) (device :: (DeviceType, Nat)).
(Scalar a, shape'' ~ Broadcast shape shape') =>
Tensor device 'Bool shape'
-> a -> Tensor device dtype shape -> Tensor device dtype shape''
maskedFill Tensor device 'Bool '[1, seqLen, seqLen]
attentionMask (-Double
1 Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/ Double
0 :: Double) (Tensor device dtype '[batchSize, seqLen, seqLen]
-> Maybe (Tensor device dtype '[batchSize, seqLen, seqLen]))
-> Tensor device dtype '[batchSize, seqLen, seqLen]
-> Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
forall a b. (a -> b) -> a -> b
$
forall (shape :: [Nat]) (dtype :: DType)
(device :: (DeviceType, Nat)).
TensorOptions shape dtype device =>
Tensor device dtype shape
zeros @'[batchSize, seqLen, seqLen] @dtype @device
let keyPaddingMask :: Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask = Tensor device 'Bool '[batchSize, seqLen]
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
forall a. a -> Maybe a
forall (f :: Type -> Type) a. Applicative f => a -> f a
pure (Tensor device 'Bool '[batchSize, seqLen]
-> Maybe (Tensor device 'Bool '[batchSize, seqLen]))
-> Tensor device 'Bool '[batchSize, seqLen]
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
forall a b. (a -> b) -> a -> b
$ Tensor device 'Int64 '[batchSize, seqLen]
xTokens Tensor device 'Int64 '[batchSize, seqLen]
-> Tensor device 'Int64 '[]
-> Tensor device 'Bool '[batchSize, seqLen]
forall (shape'' :: [Nat]) (shape :: [Nat]) (shape' :: [Nat])
(dtype :: DType) (dtype' :: DType) (device :: (DeviceType, Nat)).
(shape'' ~ Broadcast shape shape',
ComparisonDTypeIsValid device dtype,
ComparisonDTypeIsValid device dtype') =>
Tensor device dtype shape
-> Tensor device dtype' shape' -> Tensor device 'Bool shape''
==. (Integer -> Tensor device 'Int64 '[]
forall a. Num a => Integer -> a
fromInteger (Integer -> Tensor device 'Int64 '[])
-> (Proxy paddingIdx -> Integer)
-> Proxy paddingIdx
-> Tensor device 'Int64 '[]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Proxy paddingIdx -> Integer
forall (n :: Nat) (proxy :: Nat -> Type).
KnownNat n =>
proxy n -> Integer
natVal (Proxy paddingIdx -> Tensor device 'Int64 '[])
-> Proxy paddingIdx -> Tensor device 'Int64 '[]
forall a b. (a -> b) -> a -> b
$ forall (t :: Nat). Proxy t
forall {k} (t :: k). Proxy t
Proxy @paddingIdx :: Tensor device 'D.Int64 '[])
Tensor device dtype '[batchSize, seqLen, embedDim]
y <- FoldLayers batchSize seqLen dtype device
-> Tensor device dtype '[batchSize, seqLen, embedDim]
-> HList
(HReplicateR
numAttnLayers
(TransformerLayer
embedDim embedDim embedDim numHeads ffnDim dtype device))
-> IO (Tensor device dtype '[batchSize, seqLen, embedDim])
forall {k} {k1} (m :: k -> Type) f acc (xs :: [k1]) (res :: k).
HFoldrM m f acc xs res =>
f -> acc -> HList xs -> m res
hfoldrM (Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> FoldLayers batchSize seqLen dtype device
forall (batchSize :: Nat) (seqLen :: Nat) (dtype :: DType)
(device :: (DeviceType, Nat)).
Bool
-> Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
-> Maybe (Tensor device 'Bool '[batchSize, seqLen])
-> FoldLayers batchSize seqLen dtype device
FoldLayers Bool
train Maybe (Tensor device dtype '[batchSize, seqLen, seqLen])
attentionMask' Maybe (Tensor device 'Bool '[batchSize, seqLen])
keyPaddingMask) Tensor device dtype '[batchSize, seqLen, embedDim]
x' HList
(HReplicateR
numAttnLayers
(TransformerLayer
embedDim embedDim embedDim numHeads ffnDim dtype device))
tLayers
Tensor device dtype '[batchSize, seqLen, numEmbeds]
-> IO (Tensor device dtype '[batchSize, seqLen, numEmbeds])
forall a. a -> IO a
forall (m :: Type -> Type) a. Monad m => a -> m a
return (Tensor device dtype '[batchSize, seqLen, numEmbeds]
-> IO (Tensor device dtype '[batchSize, seqLen, numEmbeds]))
-> Tensor device dtype '[batchSize, seqLen, numEmbeds]
-> IO (Tensor device dtype '[batchSize, seqLen, numEmbeds])
forall a b. (a -> b) -> a -> b
$ Linear embedDim numEmbeds dtype device
-> Tensor device dtype '[batchSize, seqLen, embedDim]
-> Tensor device dtype '[batchSize, seqLen, numEmbeds]
forall f a b. HasForward f a b => f -> a -> b
forward Linear embedDim numEmbeds dtype device
tProj Tensor device dtype '[batchSize, seqLen, embedDim]
y
-- | 'HasForward' instance for the transformer language model: maps a batch of
-- 'D.Int64' token ids of shape @[batchSize, seqLen]@ to a tensor of shape
-- @[batchSize, seqLen, numEmbeds]@.
--
-- Both methods delegate to 'transformerLM' and differ only in the 'Bool'
-- flag they hand over: 'forward' passes 'False', 'forwardStoch' passes 'True'.
instance
  ( All KnownNat '[paddingIdx, embedDim, seqLen, batchSize],
    paddingIdx + 1 <= numEmbeds,
    1 <= seqLen,
    HFoldrM
      IO
      (FoldLayers batchSize seqLen dtype device)
      (Tensor device dtype '[batchSize, seqLen, embedDim])
      (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))
      (Tensor device dtype '[batchSize, seqLen, embedDim]),
    BasicArithmeticDTypeIsValid device dtype,
    ComparisonDTypeIsValid device dtype,
    ComparisonDTypeIsValid device 'D.Int64,
    KnownDType dtype,
    KnownDevice device
  ) =>
  HasForward (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device) (Tensor device 'D.Int64 '[batchSize, seqLen]) (Tensor device dtype '[batchSize, seqLen, numEmbeds])
  where
  -- Pure entry point: runs the IO-based 'transformerLM' with the flag set to
  -- 'False' and extracts the result with 'unsafePerformIO'.
  -- NOTE(review): this is only sound if 'transformerLM' is deterministic and
  -- free of observable side effects when the flag is 'False' (presumably
  -- because dropout is then disabled) -- confirm against 'dropoutForward' and
  -- the layer fold.
  forward lm tokens = unsafePerformIO (transformerLM lm False tokens)

  -- Stochastic entry point: stays in 'IO' and runs 'transformerLM' with the
  -- flag set to 'True'.
  forwardStoch lm = transformerLM lm True
-- | Constant table of sinusoidal values of shape @[numEmbeds, embedDim]@.
--
-- Construction, step by step:
--
-- 1. @positionIdx@: 'linspace' with @numEmbeds@ steps from @0@ to
--    @numEmbeds - 1@, unsqueezed to shape @[numEmbeds, 1]@.
-- 2. @frequencies@: @exp (-(log 10000 / h) * k)@ where @h = Div embedDim 2@
--    and @k@ comes from 'linspace' with @h@ steps from @0@ to @h - 1@;
--    shape @[Div embedDim 2]@.
-- 3. @angles@: broadcast product of the two, shape
--    @[numEmbeds, Div embedDim 2]@.
-- 4. 'sin' and 'cos' of @angles@ are stacked along a new dimension 2
--    (shape @[numEmbeds, Div embedDim 2, 2]@) and reshaped to
--    @[numEmbeds, embedDim]@. Assuming 'reshape' is row-major, each row then
--    alternates sine and cosine entries -- TODO confirm.
--
-- The constraint @(Div embedDim 2 * 2) ~ embedDim@ restricts @embedDim@ to
-- even values.
sinusoidal ::
  forall numEmbeds embedDim device.
  ( All KnownNat '[numEmbeds, embedDim],
    1 <= numEmbeds,
    1 <= Div embedDim 2,
    (Div embedDim 2 * 2) ~ embedDim,
    StandardFloatingPointDTypeValidation device 'D.Float,
    BasicArithmeticDTypeIsValid device 'D.Float,
    KnownDevice device
  ) =>
  Tensor device 'D.Float '[numEmbeds, embedDim]
sinusoidal = reshape interleaved
  where
    -- Value of @Div embedDim 2@ as a 'Double', used as the divisor of the
    -- log-scale step below.
    halfDim :: Double
    halfDim = fromInteger (natVal (Proxy @(Div embedDim 2)))

    -- Shape [numEmbeds, 1].
    positionIdx =
      unsqueeze @1 (linspace @numEmbeds (0 :: Int) (natValI @(numEmbeds - 1)))

    -- Shape [Div embedDim 2].
    frequencies =
      exp
        ( mulScalar
            (negate (log (10000 :: Double) / halfDim))
            (linspace @(Div embedDim 2) (0 :: Int) (natValI @((Div embedDim 2) - 1)))
        )

    -- Shape [numEmbeds, Div embedDim 2]; the size-1 column of 'positionIdx'
    -- broadcasts against 'frequencies'.
    angles = mul positionIdx frequencies

    -- Shape [numEmbeds, Div embedDim 2, 2]: sine first, cosine second.
    interleaved = stack @2 (sin angles :. cos angles :. HNil)
-- | Builds a 'TransformerLM' from a 'TransformerLMSpec' by sampling each of
-- its five components in order:
--
-- 1. the token embedding, from 'LearnedEmbeddingWithRandomInitSpec' with
--    padding index @paddingIdx@;
-- 2. the position embedding, a constant 2048-row table taken from
--    'sinusoidal' and converted to @dtype@;
-- 3. the dropout layer, from the spec's @lmDropoutSpec@;
-- 4. the layer stack, from an 'HList' holding @numAttnLayers@ copies of the
--    spec's @lmLayerSpec@;
-- 5. the output projection, from 'LinearSpec'.
instance
  ( paddingIdx <= numEmbeds,
    1 <= numEmbeds - paddingIdx,
    1 <= Div embedDim 2,
    (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
    (Div embedDim 2 * 2) ~ embedDim,
    All KnownNat '[ffnDim, paddingIdx, numEmbeds, embedDim],
    HReplicate numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device),
    A.Randomizable
      (HList (HReplicateR numAttnLayers (TransformerLayerSpec embedDim embedDim embedDim numHeads ffnDim dtype device)))
      (HList (HReplicateR numAttnLayers (TransformerLayer embedDim embedDim embedDim numHeads ffnDim dtype device))),
    KnownDType dtype,
    RandDTypeIsValid device dtype,
    StandardFloatingPointDTypeValidation device 'D.Float,
    BasicArithmeticDTypeIsValid device 'D.Float,
    KnownDevice device
  ) =>
  A.Randomizable
    (TransformerLMSpec numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
    (TransformerLM numAttnLayers numHeads ffnDim paddingIdx numEmbeds embedDim dtype device)
  where
  sample TransformerLMSpec {lmDropoutSpec = dropoutSpec, lmLayerSpec = layerSpec} = do
    -- The order of these statements fixes the order of the sampling effects;
    -- it matches the constructor's field order.
    tokenEmbedding <- A.sample (LearnedEmbeddingWithRandomInitSpec @('Just paddingIdx))
    positionEmbedding <- A.sample (ConstEmbeddingSpec @'Nothing (Torch.Typed.Tensor.toDType sinusoidal))
    dropoutLayer <- A.sample dropoutSpec
    attnLayers <- A.sample (hreplicate @numAttnLayers layerSpec)
    projection <- A.sample LinearSpec
    pure (TransformerLM tokenEmbedding positionEmbedding dropoutLayer attnLayers projection)