{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module TensorSafe.Layers.BatchNormalization where
import Data.Kind (Type)
import Data.Map
import Data.Proxy
import GHC.TypeLits
import TensorSafe.Compile.Expr
import TensorSafe.Layer
-- | Type-level description of a batch-normalization layer.
--
-- All configuration lives in the type: the three 'Nat' indices are the
-- layer's @axis@, @momentum@ and @epsilon@ settings. The value-level
-- constructor carries no data.
--
-- NOTE(review): because @momentum@ and @epsilon@ are naturals, fractional
-- settings cannot be written directly. Confirm how the downstream compiler
-- interprets these integers.
data BatchNormalization (axis :: Nat) (momentum :: Nat) (epsilon :: Nat) :: Type where
    BatchNormalization :: BatchNormalization axis momentum epsilon
    deriving Show
-- | Compiles a 'BatchNormalization' layer into a 'CNLayer' node.
--
-- Each type-level setting is reflected to the value level with 'natVal' and
-- rendered with 'show'. The results are stored under the keys @"axis"@,
-- @"epsilon"@ and @"momentum"@. Both arguments of 'compile' are ignored.
instance ( KnownNat axis
         , KnownNat momentum
         , KnownNat epsilon
         ) => Layer (BatchNormalization axis momentum epsilon) where
    layer = BatchNormalization

    compile _ _ = CNLayer DBatchNormalization settings
      where
        -- Reflect a type-level natural and render it as a decimal string.
        render :: KnownNat n => Proxy n -> String
        render = show . natVal

        -- The instance-head variables are in scope here via
        -- ScopedTypeVariables, so each proxy names the matching index.
        settings = fromList
            [ ("axis",     render (Proxy :: Proxy axis))
            , ("epsilon",  render (Proxy :: Proxy epsilon))
            , ("momentum", render (Proxy :: Proxy momentum))
            ]