{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Torch.Typed.NN.Sparse where

import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.HList
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Tag distinguishing the two flavours of embedding layer.
--
-- It is used at the type level (via @DataKinds@) as the @embeddingType@ index
-- of 'EmbeddingSpec' and 'Embedding':
--
-- * 'Constant' – the weights are held as a plain tensor.
-- * 'Learned'  – the weights are held as a parameter.
data EmbeddingType = Constant | Learned deriving (Show, Generic)

-- | Specification from which an 'Embedding' is created via 'sample'.
--
-- The type indices mirror those of 'Embedding'. The @embeddingType@ index
-- is fixed by the constructor: 'ConstEmbeddingSpec' yields @'Constant@, the
-- two @Learned...@ constructors yield @'Learned@.
data
  EmbeddingSpec
    (paddingIdx :: Maybe Nat)
    (numEmbeds :: Nat)
    (embedSize :: Nat)
    (embeddingType :: EmbeddingType)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Constant embedding backed by the given @[numEmbeds, embedSize]@ tensor.
  ConstEmbeddingSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    Tensor device dtype '[numEmbeds, embedSize] ->
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device
  -- | Learned embedding whose initial weights are drawn with @randn@ when
  -- sampled (carries no payload).
  LearnedEmbeddingWithRandomInitSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    EmbeddingSpec
      paddingIdx
      numEmbeds
      embedSize
      'Learned
      dtype
      device
  -- | Learned embedding whose initial weights are the given
  -- @[numEmbeds, embedSize]@ tensor.
  LearnedEmbeddingWithCustomInitSpec ::
    forall paddingIdx numEmbeds embedSize dtype device.
    Tensor device dtype '[numEmbeds, embedSize] ->
    EmbeddingSpec paddingIdx numEmbeds embedSize 'Learned dtype device

-- | Standalone-derived 'Show'; a @deriving@ clause cannot be attached to a
-- GADT-style declaration whose constructors refine the type indices.
deriving instance Show (EmbeddingSpec paddingIdx numEmbeds embedSize embeddingType dtype device)

-- | An embedding layer holding a @[numEmbeds, embedSize]@ weight table.
--
-- The @embeddingType@ index is fixed by the constructor: 'ConstEmbedding'
-- stores a plain 'Tensor', 'LearnedEmbedding' stores a 'Parameter'.
data
  Embedding
    (paddingIdx :: Maybe Nat)
    (numEmbeds :: Nat)
    (embedSize :: Nat)
    (embeddingType :: EmbeddingType)
    (dtype :: D.DType)
    (device :: (D.DeviceType, Nat))
  where
  -- | Embedding with constant weights (field 'constEmbedWeights').
  ConstEmbedding ::
    forall paddingIdx numEmbeds embedSize dtype device.
    --  . (PaddingIdxCheck paddingIdx numEmbeds)
    {constEmbedWeights :: Tensor device dtype '[numEmbeds, embedSize]} ->
    Embedding
      paddingIdx
      numEmbeds
      embedSize
      'Constant
      dtype
      device
  -- | Embedding with learned weights (field 'learnedEmbedWeights').
  LearnedEmbedding ::
    forall paddingIdx numEmbeds embedSize dtype device.
    --  . (PaddingIdxCheck paddingIdx numEmbeds)
    {learnedEmbedWeights :: Parameter device dtype '[numEmbeds, embedSize]} ->
    Embedding
      paddingIdx
      numEmbeds
      embedSize
      'Learned
      dtype
      device

-- | Standalone-derived 'Show' for both the constant and the learned variant.
deriving instance Show (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device)

-- | Hand-written 'Generic' instance for the constant variant.
--
-- The representation is a single 'K1' field holding the weight tensor, with
-- no metadata wrappers around it.
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device) =
      Rec0 (Tensor device dtype '[numEmbeds, embedSize])

  -- Unwrap the only field and box it as the generic representation.
  from (ConstEmbedding weights) = K1 weights

  -- Inverse of 'from': unbox the field and rebuild the constructor.
  to = ConstEmbedding . unK1

-- | Hand-written 'Generic' instance for the learned variant.
--
-- The representation is a single 'K1' field holding the weight parameter,
-- with no metadata wrappers around it.
instance Generic (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) where
  type
    Rep (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device) =
      Rec0 (Parameter device dtype '[numEmbeds, embedSize])

  -- Unwrap the only field and box it as the generic representation.
  from (LearnedEmbedding weights) = K1 weights

  -- Inverse of 'from': unbox the field and rebuild the constructor.
  to = LearnedEmbedding . unK1

-- Empty instance body: every method comes from the class defaults.
-- NOTE(review): presumably those defaults are driven by the hand-written
-- 'Generic' instance for this variant — confirm in Torch.Typed.Parameter.
instance Parameterized (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)

-- Empty instance body for the learned variant, likewise relying on the
-- class defaults.
instance Parameterized (Embedding paddingIdx numEmbeds embedSize 'Learned dtype device)

-- | Apply an embedding layer to a tensor of @Int64@ indices.
--
-- The result has the shape of the index tensor with @embedSize@ appended as
-- the last dimension (@shape' ~ Reverse (embedSize ': Reverse shape)@).
--
-- Both variants delegate to 'embedding' with the padding index taken from the
-- type level. The learned variant first converts its parameter to a tensor
-- with 'toDependent'.
embed ::
  forall paddingIdx shape numEmbeds embedSize embeddingType dtype device shape'.
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  Embedding paddingIdx numEmbeds embedSize embeddingType dtype device ->
  Tensor device 'D.Int64 shape ->
  Tensor device dtype shape'
embed (ConstEmbedding weights) input =
  -- NOTE(review): the two 'False' flags are presumably scaleGradByFreq and
  -- sparse — confirm against Torch.Typed.Functional.embedding.
  embedding @paddingIdx
    False
    False
    weights
    input
embed (LearnedEmbedding weights) input =
  embedding @paddingIdx
    False
    False
    (toDependent weights)
    input

-- | Forward pass of an embedding layer: an index tensor in, looked-up
-- embeddings out. See 'embed' for the shape relation.
instance
  ( KnownMaybeNat paddingIdx,
    PaddingIdxCheck paddingIdx numEmbeds,
    shape' ~ Reverse (embedSize ': (Reverse shape))
  ) =>
  HasForward (Embedding paddingIdx numEmbeds embedSize embeddingType dtype device) (Tensor device 'D.Int64 shape) (Tensor device dtype shape')
  where
  forward = embed

  -- The lookup involves no randomness, so the stochastic pass is the
  -- deterministic one lifted into 'IO'.
  forwardStoch model input = pure (forward model input)

-- | Build a constant embedding from its spec. Nothing is randomised: the
-- tensor carried by the spec becomes the weight table unchanged.
instance
  Randomizable
    (EmbeddingSpec paddingIdx numEmbeds embedSize 'Constant dtype device)
    (Embedding paddingIdx numEmbeds embedSize 'Constant dtype device)
  where
  sample (ConstEmbeddingSpec tensor) = pure (ConstEmbedding tensor)

-- | Build a learned embedding without a padding index.
--
-- * 'LearnedEmbeddingWithRandomInitSpec' draws the initial weights with
--   'randn'.
-- * 'LearnedEmbeddingWithCustomInitSpec' uses the supplied tensor.
--
-- In both cases the weights are turned into a parameter with
-- 'makeIndependent'.
instance
  ( KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec 'Nothing numEmbeds embedSize 'Learned dtype device)
    (Embedding 'Nothing numEmbeds embedSize 'Learned dtype device)
  where
  sample LearnedEmbeddingWithRandomInitSpec =
    LearnedEmbedding <$> (makeIndependent =<< randn)
  sample (LearnedEmbeddingWithCustomInitSpec tensor) =
    LearnedEmbedding <$> makeIndependent tensor

-- | Build a learned embedding that has a padding index.
--
-- * 'LearnedEmbeddingWithRandomInitSpec' draws the weights with 'randn' and
--   then overwrites row @paddingIdx@ with zeros.
-- * 'LearnedEmbeddingWithCustomInitSpec' uses the supplied tensor as is; the
--   padding row is /not/ zeroed in this case.
--
-- The arithmetic constraints require @paddingIdx < numEmbeds@ and let the
-- type checker see that the three mask pieces concatenate back to
-- @numEmbeds@ rows.
instance
  ( paddingIdx <= numEmbeds,
    1 <= numEmbeds - paddingIdx,
    (((numEmbeds - paddingIdx) - 1) + (1 + paddingIdx)) ~ numEmbeds,
    KnownNat paddingIdx,
    KnownNat numEmbeds,
    KnownNat embedSize,
    KnownDType dtype,
    KnownDevice device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable
    (EmbeddingSpec ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
    (Embedding ('Just paddingIdx) numEmbeds embedSize 'Learned dtype device)
  where
  sample LearnedEmbeddingWithRandomInitSpec = do
    -- Boolean mask of shape [numEmbeds, embedSize] that is set only on row
    -- paddingIdx: paddingIdx rows of zeros, one row of ones, then zeros for
    -- the remaining rows.
    let mask =
          cat @0
            ( zeros @'[paddingIdx, embedSize] @'D.Bool @device
                :. ones @'[1, embedSize] @'D.Bool @device
                :. zeros @'[numEmbeds - paddingIdx - 1, embedSize] @'D.Bool @device
                :. HNil
            )
    initial <- randn @'[numEmbeds, embedSize] @dtype @device
    -- Zero the padding row before the tensor becomes a parameter.
    LearnedEmbedding <$> makeIndependent (maskedFill mask (0 :: Int) initial)
  sample (LearnedEmbeddingWithCustomInitSpec tensor) =
    LearnedEmbedding <$> makeIndependent tensor