{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Torch.Typed.NN.Dropout where

import GHC.Generics
import System.IO.Unsafe
import Torch.NN (HasForward (..), Randomizable (..))
import Torch.Typed.Functional
import Torch.Typed.Parameter
import Torch.Typed.Tensor

-- | Hyper-parameters for a 'Dropout' layer.
--
-- The single field 'dropoutProbSpec' holds the dropout probability. The
-- 'Randomizable' instance in this module copies it unchanged into
-- 'dropoutProb' of the resulting 'Dropout'. No range check is performed here.
data DropoutSpec where
  DropoutSpec ::
    {dropoutProbSpec :: Double} ->
    DropoutSpec
  deriving (Show, Eq)

-- | A dropout layer.
--
-- Its only field is the plain 'Double' 'dropoutProb'; the layer holds no
-- tensors. The 'Parameterized' instance is derived generically: it uses the
-- 'Generic' instance and is enabled by @DeriveAnyClass@.
data Dropout where
  Dropout ::
    {dropoutProb :: Double} ->
    Dropout
  deriving (Show, Generic, Parameterized)

-- | Run a 'Dropout' layer on a tensor.
--
-- The layer's 'dropoutProb' and the training flag are passed unchanged to
-- 'Torch.Typed.Functional.dropout'. The result keeps the input's device, dtype
-- and shape.
dropoutForward ::
  forall shape dtype device.
  -- | layer supplying the dropout probability
  Dropout ->
  -- | training flag ('True' from 'forwardStoch', 'False' from 'forward')
  Bool ->
  -- | input tensor
  Tensor device dtype shape ->
  IO (Tensor device dtype shape)
dropoutForward Dropout {..} dropoutTrain = dropout dropoutProb dropoutTrain

-- | Inference ('forward') and stochastic/training ('forwardStoch') passes.
--
-- Both delegate to 'dropoutForward'. Only the training flag differs.
instance HasForward Dropout (Tensor device dtype shape) (Tensor device dtype shape) where
  -- 'forward' is pure, so the IO action is run with 'unsafePerformIO'.
  -- NOTE(review): this is only sound if the non-training path of
  -- 'Torch.Typed.Functional.dropout' has no observable side effects
  -- (e.g. no RNG draw) — confirm against that function.
  forward layer input = unsafePerformIO $ dropoutForward layer False input

  -- Training mode stays in IO: the result is allowed to vary between calls.
  forwardStoch layer input = dropoutForward layer True input

-- | Build a 'Dropout' layer from its spec.
--
-- No randomness is consumed: the spec's 'dropoutProbSpec' is copied unchanged
-- into 'dropoutProb'.
instance Randomizable DropoutSpec Dropout where
  sample DropoutSpec {..} = pure $ Dropout dropoutProbSpec