{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Shape-typed layer normalization module: a specification type, the
-- parameterized layer itself, its forward function, and 'HasForward' /
-- 'Randomizable' instances.
module Torch.Typed.NN.Normalization where

import GHC.Generics
import GHC.TypeLits
import qualified Torch.DType as D
import qualified Torch.Device as D
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Auxiliary
import Torch.Typed.Factories
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Hyper-parameters used by 'sample' to build a 'LayerNorm'.
--
-- The type parameters fix the shape of the learnable weight and bias
-- (@normalizedShape@), their element type (@dtype@) and their device.
data LayerNormSpec (normalizedShape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  LayerNormSpec ::
    forall normalizedShape dtype device.
    { -- | Epsilon copied verbatim into 'layerNormEps' of the sampled layer.
      layerNormEpsSpec :: Double
    } ->
    LayerNormSpec normalizedShape dtype device
  deriving (Show, Eq)

-- | Layer normalization with a learnable elementwise weight and bias, both
-- shaped @normalizedShape@.
data LayerNorm (normalizedShape :: [Nat]) (dtype :: D.DType) (device :: (D.DeviceType, Nat)) where
  LayerNorm ::
    { -- | Elementwise weight (gain) passed as the first tensor to 'layerNorm'.
      layerNormWeight :: Parameter device dtype normalizedShape,
      -- | Elementwise bias passed as the second tensor to 'layerNorm'.
      layerNormBias :: Parameter device dtype normalizedShape,
      -- | Epsilon passed to 'layerNorm'; presumably the numerical-stability
      -- term added to the variance (see 'layerNorm' to confirm).
      layerNormEps :: Double
    } ->
    LayerNorm normalizedShape dtype device
  deriving (Show, Generic, Parameterized)

-- | Apply a 'LayerNorm' to an input tensor.
--
-- 'IsSuffixOf' requires @normalizedShape@ to be a trailing part of the
-- input @shape@. The result has the same device, dtype and shape as the
-- input. The actual computation is delegated to 'layerNorm' from
-- "Torch.Typed.Functional".
layerNormForward ::
  forall normalizedShape shape dtype device.
  ( IsSuffixOf normalizedShape shape,
    KnownShape normalizedShape
  ) =>
  LayerNorm normalizedShape dtype device ->
  Tensor device dtype shape ->
  Tensor device dtype shape
layerNormForward LayerNorm {..} =
  -- The explicit type application pins layerNorm's normalized shape to this
  -- module's type parameter.
  layerNorm @normalizedShape
    (toDependent layerNormWeight)
    (toDependent layerNormBias)
    layerNormEps

-- | 'forward' is 'layerNormForward'. 'forwardStoch' runs the same
-- deterministic computation and merely returns it in 'IO'.
instance
  ( IsSuffixOf normalizedShape shape,
    KnownShape normalizedShape
  ) =>
  HasForward (LayerNorm normalizedShape dtype device) (Tensor device dtype shape) (Tensor device dtype shape)
  where
  forward = layerNormForward
  forwardStoch model = pure . forward model

-- | Sample a fresh 'LayerNorm' from a 'LayerNormSpec'.
--
-- The weight starts at all ones and the bias at all zeros, so a freshly
-- sampled layer applies the identity affine transform after normalizing.
-- This matches the initialization of PyTorch's @nn.LayerNorm@. Drawing the
-- gain and shift from a normal distribution would instead inject a random
-- per-feature scale and offset before any training has happened.
--
-- The 'RandDTypeIsValid' constraint is no longer needed by this
-- initialization. It is kept so the instance context stays unchanged for
-- existing code.
instance
  ( TensorOptions normalizedShape dtype device,
    RandDTypeIsValid device dtype
  ) =>
  Randomizable (LayerNormSpec normalizedShape dtype device) (LayerNorm normalizedShape dtype device)
  where
  sample LayerNormSpec {..} =
    LayerNorm
      <$> makeIndependent ones
      <*> makeIndependent zeros
      <*> pure layerNormEpsSpec