{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
-- | Named wrappers for building layer symbols.
--
-- Every function in this module delegates to the operation of the same name
-- in "MXNet.Core.Base.Internal.TH.Symbol" (imported as @S@). Optional
-- operator arguments are passed through untouched as an 'HMap'; the keys each
-- wrapper lists are spelled out in the @MatchKVList@ constraint of its
-- signature.
module MXNet.NN.Layer where

import MXNet.Core.Types.Internal
import MXNet.Core.Base.HMap
import qualified MXNet.Core.Base.Internal.TH.Symbol as S
import qualified MXNet.Core.Base.Internal as I
import MXNet.NN.Utils

-- | Create a symbol variable with the given name.
--
-- Calls @I.mxSymbolCreateVariable@ and runs the result through @I.checked@
-- (presumably converting a failed call into an error -- confirm against
-- @I.checked@).
variable :: String -> IO SymbolHandle
variable name = I.checked (I.mxSymbolCreateVariable name)

-- | Build a convolution symbol.
--
-- Two variables named @name ++ "-w"@ and @name ++ "-b"@ are created, in that
-- order, and handed to @S.convolution@ together with the input symbol, the
-- kernel shape rendered by @formatShape@, the filter count and the optional
-- arguments.
--
-- NOTE(review): the @"-b"@ variable is created and passed on regardless of
-- whether @"no_bias"@ is present in the optional arguments -- confirm this is
-- what @S.convolution@ expects.
convolution
    :: ( MatchKVList kvs
           '[ "stride"     ':= String
            , "dilate"     ':= String
            , "pad"        ':= String
            , "num_group"  ':= Int
            , "workspace"  ':= Int
            , "no_bias"    ':= Bool
            , "cudnn_tune" ':= String
            , "cudnn_off"  ':= Bool
            , "layout"     ':= String
            ]
       , ShowKV kvs
       )
    => String        -- ^ name of the symbol; also the prefix of its two variables
    -> SymbolHandle  -- ^ input symbol
    -> [Int]         -- ^ kernel shape, rendered with @formatShape@
    -> Int           -- ^ number of filters, forwarded unchanged
    -> HMap kvs      -- ^ optional arguments, forwarded unchanged
    -> IO SymbolHandle
convolution name dat kernel_shape num_filter args = do
    weight <- param "-w"
    bias   <- param "-b"
    S.convolution name dat weight bias (formatShape kernel_shape) num_filter args
  where
    -- A variable whose name is the layer name followed by the given suffix.
    param suffix = variable (name ++ suffix)

-- | Build a fully-connected symbol.
--
-- Two variables named @name ++ "-w"@ and @name ++ "-b"@ are created, in that
-- order, and handed to @S.fullyconnected@ together with the input symbol, the
-- neuron count and the optional arguments.
--
-- NOTE(review): the @"-b"@ variable is created and passed on regardless of
-- whether @"no_bias"@ is present in the optional arguments -- confirm this is
-- what @S.fullyconnected@ expects.
fullyConnected
    :: ( MatchKVList kvs
           '[ "no_bias" ':= Bool
            , "flatten" ':= Bool
            ]
       , ShowKV kvs
       )
    => String        -- ^ name of the symbol; also the prefix of its two variables
    -> SymbolHandle  -- ^ input symbol
    -> Int           -- ^ number of neurons, forwarded unchanged
    -> HMap kvs      -- ^ optional arguments, forwarded unchanged
    -> IO SymbolHandle
fullyConnected name dat num_neuron args = do
    weight <- param "-w"
    bias   <- param "-b"
    S.fullyconnected name dat weight bias num_neuron args
  where
    -- A variable whose name is the layer name followed by the given suffix.
    param suffix = variable (name ++ suffix)

-- | The pooling methods that 'pooling' can be asked to use.
data PoolingMethod
    = PoolingMax
    | PoolingAvg
    | PoolingSum

-- | The string passed to @S.pooling@ for each 'PoolingMethod'.
poolingMethodToStr :: PoolingMethod -> String
poolingMethodToStr method = case method of
    PoolingMax -> "max"
    PoolingAvg -> "avg"
    PoolingSum -> "sum"

-- | Build a pooling symbol.
--
-- The shape is rendered with @formatShape@ and the method with
-- 'poolingMethodToStr' before both are passed to @S.pooling@.
pooling
    :: ( MatchKVList kvs
           '[ "global_pool"        ':= Bool
            , "cudnn_off"          ':= Bool
            , "pooling_convention" ':= String
            , "stride"             ':= String
            , "pad"                ':= String
            ]
       , ShowKV kvs
       )
    => String         -- ^ name of the symbol
    -> SymbolHandle   -- ^ input symbol
    -> [Int]          -- ^ shape, rendered with @formatShape@
    -> PoolingMethod  -- ^ pooling method
    -> HMap kvs       -- ^ optional arguments, forwarded unchanged
    -> IO SymbolHandle
pooling name input shape method args =
    S.pooling name input shapeStr methodStr args
  where
    shapeStr  = formatShape shape
    methodStr = poolingMethodToStr method

-- | Build a flatten symbol; a direct alias for @S.flatten@.
flatten :: String -> SymbolHandle -> IO SymbolHandle
flatten name input = S.flatten name input

-- | The activation functions that 'activation' can be asked to use.
data ActivationType
    = Relu
    | Sigmoid
    | Tanh
    | SoftRelu

-- | The string passed to @S.activation@ for each 'ActivationType'.
activationTypeToStr :: ActivationType -> String
activationTypeToStr typ = case typ of
    Relu     -> "relu"
    Sigmoid  -> "sigmoid"
    Tanh     -> "tanh"
    SoftRelu -> "softrelu"

-- | Build an activation symbol, passing the activation type to
-- @S.activation@ as the string given by 'activationTypeToStr'.
activation :: String -> SymbolHandle -> ActivationType -> IO SymbolHandle
activation name input typ = S.activation name input $ activationTypeToStr typ

-- | Build a softmax-output symbol; a direct alias for @S.softmaxoutput@.
--
-- The two symbol arguments are forwarded in the order given (presumably the
-- data symbol followed by the label symbol -- confirm against
-- @S.softmaxoutput@).
softmaxoutput
    :: ( MatchKVList kvs
           '[ "grad_scale"     ':= Float
            , "ignore_label"   ':= Float
            , "multi_output"   ':= Bool
            , "use_ignore"     ':= Bool
            , "preserve_shape" ':= Bool
            , "normalization"  ':= String
            , "out_grad"       ':= Bool
            , "smooth_alpha"   ':= Float
            ]
       , ShowKV kvs
       )
    => String        -- ^ name of the symbol
    -> SymbolHandle  -- ^ first symbol argument
    -> SymbolHandle  -- ^ second symbol argument
    -> HMap kvs      -- ^ optional arguments, forwarded unchanged
    -> IO SymbolHandle
softmaxoutput name input label args = S.softmaxoutput name input label args