{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-| This module declares the ZeroPadding2D layer data type. -}
module TensorSafe.Layers.ZeroPadding2D (ZeroPadding2D) where

import           Data.Kind               (Type)
import           Data.Map
import           Data.Proxy
import           GHC.TypeLits

import           TensorSafe.Compile.Expr
import           TensorSafe.Layer

-- | A ZeroPadding2D layer with padding_rows and padding_cols arguments
data ZeroPadding2D :: Nat -> Nat -> Type where
    ZeroPadding2D :: ZeroPadding2D padding_rows padding_cols
    deriving Show

instance ( KnownNat padding_rows
         , KnownNat padding_cols
         ) => Layer (ZeroPadding2D padding_rows padding_cols) where
    layer = ZeroPadding2D
    compile _ _ =
        let padding_rows = show $ natVal (Proxy :: Proxy padding_rows)
            padding_cols = show $ natVal (Proxy :: Proxy padding_cols)
        in
            CNLayer DZeroPadding2D (fromList [
                ("padding_rows", padding_rows),
                ("padding_cols", padding_cols)
            ])