{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | A dropout layer for typed tensors: a specification type, the layer
-- itself, its forward function, and the 'HasForward' / 'Randomizable'
-- instances that plug it into the rest of the NN machinery.
module Torch.Typed.NN.Dropout where

import GHC.Generics
import System.IO.Unsafe
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Specification from which a 'Dropout' layer is built via 'sample'.
--
-- The single field 'dropoutProbSpec' is copied verbatim into the layer's
-- 'dropoutProb' field.
data DropoutSpec where
  DropoutSpec ::
    {dropoutProbSpec :: Double} ->
    DropoutSpec
  deriving (Show, Eq)

-- | The dropout layer.
--
-- It carries only the 'Double' that is handed to 'dropout' as its first
-- argument (presumably the drop probability -- confirm against
-- "Torch.Typed.Functional").
data Dropout where
  Dropout ::
    {dropoutProb :: Double} ->
    Dropout
  deriving (Show, Generic, Parameterized)

-- | Apply the layer to a tensor by delegating to 'dropout'.
--
-- The 'Bool' argument is forwarded unchanged as the second argument of
-- 'dropout'; the instances below pass 'True' from 'forwardStoch' and
-- 'False' from 'forward'. The result has the same device, dtype and shape
-- as the input, as fixed by the type signature.
dropoutForward ::
  forall shape dtype device.
  Dropout ->
  Bool ->
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
dropoutForward Dropout {..} dropoutTrain = dropout dropoutProb dropoutTrain

instance HasForward Dropout (Tensor device dtype shape) (Tensor device dtype shape) where
  -- Pure forward pass: runs 'dropoutForward' with the flag set to 'False'
  -- and unwraps the 'IO' result with 'unsafePerformIO'.
  -- NOTE(review): this is only sound if 'dropout' is deterministic and free
  -- of observable side effects when the flag is 'False' -- confirm against
  -- the implementation in Torch.Typed.Functional.
  -- The binder is named @layer@ (not @dropout@) so it does not shadow the
  -- imported 'dropout' function.
  forward layer input = unsafePerformIO (dropoutForward layer False input)

  -- Effectful forward pass: runs 'dropoutForward' with the flag set to
  -- 'True' and stays in 'IO'.
  forwardStoch layer input = dropoutForward layer True input

instance Randomizable DropoutSpec Dropout where
  -- No randomness is drawn here: the layer is built directly from the
  -- value stored in the spec.
  sample DropoutSpec {..} = pure (Dropout dropoutProbSpec)