{-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-| This module declares the BatchNormalization layer data type. -} module TensorSafe.Layers.BatchNormalization where import Data.Kind (Type) import Data.Map import Data.Proxy import GHC.TypeLits import TensorSafe.Compile.Expr import TensorSafe.Layer -- | A classic BatchNormalization layer with axis, momentum and epsilon parameters data BatchNormalization :: Nat -> Nat -> Nat -> Type where BatchNormalization :: BatchNormalization axis momentum epsilon deriving Show instance ( KnownNat axis , KnownNat momentum , KnownNat epsilon ) => Layer (BatchNormalization axis momentum epsilon) where layer = BatchNormalization compile _ _ = let axis = show $ natVal (Proxy :: Proxy axis) momentum = show $ natVal (Proxy :: Proxy momentum) epsilon = show $ natVal (Proxy :: Proxy epsilon) in CNLayer DBatchNormalization (fromList [ ("axis", axis), ("epsilon", epsilon), ("momentum", momentum) ])