{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module TensorSafe.Layers.ZeroPadding2D (ZeroPadding2D) where
import Data.Kind (Type)
import Data.Map
import Data.Proxy
import GHC.TypeLits
import TensorSafe.Compile.Expr
import TensorSafe.Layer
-- | Zero-padding layer for 2D inputs.
--
-- Both padding amounts live only at the type level, as naturals: the first
-- parameter is the row padding and the second is the column padding. The
-- single constructor carries no runtime fields.
data ZeroPadding2D (padding_rows :: Nat) (padding_cols :: Nat)
  = ZeroPadding2D
  deriving (Show)
-- | 'Layer' instance for 'ZeroPadding2D'.
--
-- Both type-level padding amounts must be 'KnownNat' so that they can be
-- read back as values when the layer is compiled.
instance
  ( KnownNat padding_rows
  , KnownNat padding_cols
  ) =>
  Layer (ZeroPadding2D padding_rows padding_cols)
  where
  -- The type has a single nullary constructor, so that is the layer value.
  layer = ZeroPadding2D

  -- Both arguments are ignored: the result is built only from the
  -- type-level naturals, each rendered as a decimal string with 'show'.
  compile _ _ =
    CNLayer DZeroPadding2D $
      fromList
        [ ("padding_rows", rowsParam)
        , ("padding_cols", colsParam)
        ]
    where
      rowsParam = show (natVal (Proxy :: Proxy padding_rows))
      colsParam = show (natVal (Proxy :: Proxy padding_cols))