{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Torch.Typed.NN.Sparse where
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor
-- | Storage kind of an embedding table, used promoted (via DataKinds) as a
-- type index of 'EmbeddingSpec' and 'Embedding'.
--
-- * 'Constant' — the weight matrix is held as a plain 'Tensor'.
-- * 'Learned'  — the weight matrix is held as a trainable 'Parameter'.
data EmbeddingType = Constant | Learned
  deriving (Show, Generic)
-- | Recipe consumed by 'sample' to build an 'Embedding'.
--
-- Type indices, in order: an optional padding-row index, the number of rows
-- and the row width of the @[numEmbeds, embedSize]@ weight matrix, the
-- 'EmbeddingType' (plain tensor vs. trainable parameter), the element dtype,
-- and the device.
data EmbeddingSpec
       (paddingIdx :: Maybe Nat)
       (numEmbeds :: Nat)
       (embedSize :: Nat)
       (embeddingType :: EmbeddingType)
       (dtype :: D.DType)
       (device :: (D.DeviceType, Nat))
  where
    -- | Use the supplied weight matrix as a 'Constant' table.
    ConstEmbeddingSpec
      :: forall paddingIdx numEmbeds embedSize dtype device.
         Tensor device dtype '[numEmbeds, embedSize]
      -> EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device
    -- | 'Learned' table whose weights are drawn at random when sampled.
    LearnedEmbeddingWithRandomInitSpec
      :: forall paddingIdx numEmbeds embedSize dtype device.
         EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device
    -- | 'Learned' table initialised from the supplied weight matrix.
    LearnedEmbeddingWithCustomInitSpec
      :: forall paddingIdx numEmbeds embedSize dtype device.
         Tensor device dtype '[numEmbeds, embedSize]
      -> EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device

-- Shows the constructor together with its weight tensor, if it carries one.
deriving instance Show (EmbeddingSpec p n e t d dev)
-- | An embedding table: a @[numEmbeds, embedSize]@ weight matrix that
-- 'embed' indexes with 'D.Int64' tensors.
--
-- The 'EmbeddingType' index selects the storage. A 'Constant' table keeps a
-- plain 'Tensor', which is not exposed as a 'Parameter'. A 'Learned' table
-- keeps a 'Parameter'. @paddingIdx@ is only a type-level tag on this type;
-- 'embed' forwards it to 'embedding'.
data Embedding
       (paddingIdx :: Maybe Nat)
       (numEmbeds :: Nat)
       (embedSize :: Nat)
       (embeddingType :: EmbeddingType)
       (dtype :: D.DType)
       (device :: (D.DeviceType, Nat))
  where
    -- | Table backed by a plain weight tensor.
    ConstEmbedding ::
      forall paddingIdx numEmbeds embedSize dtype device.
      {constEmbedWeights :: Tensor device dtype '[numEmbeds, embedSize]} ->
      Embedding paddingIdx numEmbeds embedSize 'Constant dtype device
    -- | Table backed by a trainable weight parameter.
    LearnedEmbedding ::
      forall paddingIdx numEmbeds embedSize dtype device.
      {learnedEmbedWeights :: Parameter device dtype '[numEmbeds, embedSize]} ->
      Embedding paddingIdx numEmbeds embedSize 'Learned dtype device

deriving instance Show (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device)
-- | Hand-written 'Generic' instance for the 'Constant' variant. Stock
-- deriving does not support GADT constructors with a refined result type.
-- The representation is just the single weight field ('Rec0'), with no
-- metadata wrappers.
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) =
      Rec0 (Tensor device dtype '[numEmbeds, embedSize])
  from (ConstEmbedding weights) = K1 weights
  -- K1 is a newtype, so this is the same as @ConstEmbedding . unK1@.
  to (K1 weights) = ConstEmbedding weights
-- | Hand-written 'Generic' instance for the 'Learned' variant. Stock
-- deriving does not support GADT constructors with a refined result type.
-- The representation is just the single weight 'Parameter' ('Rec0'), with
-- no metadata wrappers.
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) =
      Rec0 (Parameter device dtype '[numEmbeds, embedSize])
  from (LearnedEmbedding weights) = K1 weights
  -- K1 is a newtype, so this is the same as @LearnedEmbedding . unK1@.
  to (K1 weights) = LearnedEmbedding weights
-- Both variants rely entirely on the default methods of 'Parameterized'.
-- NOTE(review): presumably these defaults are Generic-driven and use the
-- hand-written 'Generic' instances for 'Embedding'; confirm.
instance
  Parameterized
    (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)

instance
  Parameterized
    (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device)
-- | Look up rows of an embedding table.
--
-- The result has the index tensor's shape with @embedSize@ appended. At the
-- type level this is written @Reverse (embedSize ': Reverse shape)@.
-- A 'Learned' table is read through 'toDependent'; a 'Constant' table's
-- tensor is used directly. The padding index is forwarded to 'embedding' as
-- a type application. Both 'Bool' arguments of 'embedding' are 'False'.
-- NOTE(review): presumably these are scale-grad-by-frequency and sparse
-- gradients; confirm in "Torch.Typed.Functional".
embed ::
  forall paddingIdx shape numEmbeds embedSize embeddingType dtype device shape'.
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  Embedding paddingIdx numEmbeds embedSize embeddingType dtype device ->
  Tensor device 'D.Int64 shape ->
  Tensor device dtype shape'
embed (ConstEmbedding weights) input =
  embedding @paddingIdx False False weights input
embed (LearnedEmbedding weights) input =
  embedding @paddingIdx False False (toDependent weights) input
-- | The forward pass of an 'Embedding' is 'embed'. 'forwardStoch' just
-- returns the deterministic result in 'IO'; the lookup has no stochastic
-- mode.
instance
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  HasForward
    (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device)
    (Tensor device 'D.Int64 shape)
    (Tensor device dtype shape')
  where
    forward = embed
    forwardStoch model input = pure (forward model input)
-- | Sampling a 'Constant' spec involves no randomness: the supplied tensor
-- is wrapped as-is.
instance
  Randomizable
    (EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device)
    (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)
  where
    sample (ConstEmbeddingSpec tensor) = pure (ConstEmbedding tensor)
-- | Build a 'Learned' table without a padding index.
--
-- * Random init: draw the weight matrix with 'randn' and wrap it with
--   'makeIndependent'.
-- * Custom init: wrap the supplied tensor with 'makeIndependent'.
instance
  ( KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec 'Nothing numEmbeds embedSize 'Learned dtype device)
    (Embedding 'Nothing numEmbeds embedSize 'Learned dtype device)
  where
    sample LearnedEmbeddingWithRandomInitSpec =
      LearnedEmbedding
        <$> (makeIndependent =<< randn @'[numEmbeds, embedSize] @dtype @device)
    sample (LearnedEmbeddingWithCustomInitSpec tensor) =
      -- Previously @makeIndependent =<< pure tensor@; identical by the monad laws.
      LearnedEmbedding <$> makeIndependent tensor
-- | Build a 'Learned' table with padding index @paddingIdx@.
--
-- * Random init: draw the weight matrix with 'randn', then 'maskedFill' it
--   with 0 using a Bool mask. The mask is all 'False' except row
--   @paddingIdx@, so only the padding row is set to zero.
-- * Custom init: wrap the supplied tensor unchanged. Its padding row is not
--   zeroed.
--
-- The arithmetic constraints let the type checker accept that the three mask
-- blocks (@paddingIdx@ rows, 1 row, and @numEmbeds - paddingIdx - 1@ rows)
-- concatenate to @numEmbeds@ rows.
instance
  ( paddingIdx <= numEmbeds,
    1 <= numEmbeds - paddingIdx,
    (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
    KnownNat paddingIdx,
    KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
    (Embedding ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
  where
    sample LearnedEmbeddingWithRandomInitSpec = do
      -- True exactly on row paddingIdx, stacked along dimension 0.
      let paddingRowMask =
            cat @0
              ( zeros @'[paddingIdx, embedSize] @'D.Bool @device
                  :. ones @'[1, embedSize] @'D.Bool @device
                  :. zeros @'[numEmbeds - paddingIdx - 1, embedSize] @'D.Bool @device
                  :. HNil
              )
      weights <- randn @'[numEmbeds, embedSize] @dtype @device
      LearnedEmbedding
        <$> makeIndependent (maskedFill paddingRowMask (0 :: Int) weights)
    sample (LearnedEmbeddingWithCustomInitSpec tensor) =
      -- Previously @makeIndependent =<< pure tensor@; identical by the monad laws.
      LearnedEmbedding <$> makeIndependent tensor