{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

{-| Declares the 'Dropout' layer data type together with its 'Layer' instance. -}
module TensorSafe.Layers.Dropout (Dropout) where

import Data.Kind (Type)
import Data.Map
import Data.Proxy
import GHC.TypeLits
import TensorSafe.Compile.Expr
import TensorSafe.Layer

-- | A Dropout layer indexed by two type-level naturals, @rate@ and @seed@.
--
-- The single constructor carries no runtime fields; both parameters exist
-- only in the type and are recovered through 'KnownNat' when compiling.
--
-- NOTE(review): @rate@ has kind 'Nat', so only whole numbers can be written
-- here -- confirm how the consumer of the compiled layer interprets it.
data Dropout :: Nat -> Nat -> Type where
  Dropout :: Dropout rate seed
  deriving Show

-- | 'Layer' instance for 'Dropout'. Requires both type-level parameters to be
-- known naturals so that their values can be reflected to the term level.
instance (KnownNat rate, KnownNat seed) => Layer (Dropout rate seed) where
  -- The only inhabitant of the type.
  layer = Dropout

  -- Both arguments are ignored. The result is a 'CNLayer' tagged 'DDropout'
  -- whose parameter map holds the decimal renderings of @rate@ and @seed@
  -- under the keys "rate" and "seed".
  compile _ _ = CNLayer DDropout (fromList params)
    where
      -- The type variables refer to the instance head (ScopedTypeVariables).
      params =
        [ ("rate", show (natVal (Proxy :: Proxy rate)))
        , ("seed", show (natVal (Proxy :: Proxy seed)))
        ]