{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-| This module declares the 'Dense' (a.k.a. fully connected) layer data type and its 'Layer' instance. -}
module TensorSafe.Layers.Dense where

import           Data.Kind               (Type)
import           Data.Map
import           Data.Proxy
import           GHC.TypeLits

import           TensorSafe.Compile.Expr
import           TensorSafe.Layer


-- | A dense (fully connected) layer description.
--
-- Both type parameters are type-level naturals that exist only in the type:
-- @input@ is the input size and @output@ the output size. The single
-- constructor 'Dense' carries no runtime fields.
data Dense (input :: Nat) (output :: Nat) :: Type where
    -- | The only value of the type; the sizes are read from the type indices.
    Dense :: Dense input output
    deriving (Show)

-- | 'Layer' instance for 'Dense'.
--
-- 'layer' is the sole constructor. 'compile' ignores both of its arguments
-- and builds a 'CNLayer' tagged 'DDense' whose parameter map has two entries:
-- @"inputDim"@ and @"units"@, holding the @input@ and @output@ type-level
-- naturals rendered as decimal strings.
instance (KnownNat input, KnownNat output) => Layer (Dense input output) where
    layer = Dense
    compile _ _ = CNLayer DDense params
      where
        -- Reify the type-level sizes; the instance-head type variables are
        -- in scope here thanks to ScopedTypeVariables.
        inDim  = natVal (Proxy :: Proxy input)
        outDim = natVal (Proxy :: Proxy output)
        params = fromList
            [ ("inputDim", show inDim)
            , ("units", show outDim)
            ]