{-# LANGUAGE DeriveFunctor #-} module NN.Backend.Torch.Torch where import Gen.Caffe.ConvolutionParameter as CP import Gen.Caffe.DropoutParameter as DP import Gen.Caffe.InnerProductParameter as IP import Gen.Caffe.LayerParameter as LP import Gen.Caffe.PoolingParameter as PP import Gen.Caffe.PoolingParameter.PoolMethod as PP import Control.Applicative import Control.Lens import Data.Graph.Inductive.Graph hiding ((&)) import Data.Graph.Inductive.Query import Language.Lua.Syntax import NN.Backend.Torch.Lua import NN.DSL -- Modules are either sequential or criterion - which are treated -- differently by Torch data Module a = Criterion a | Inner a deriving (Functor, Show) data TorchModule = TorchModule Name Name [Exp] deriving (Show) torchExp :: TorchModule -> Exp torchExp module' = PrefixExp (PEFunCall (construct module')) where construct (TorchModule luaModule torchModule args) = NormalFunCall (PEVar (SelectName (var luaModule) torchModule)) (Args args) torchModules :: LayerParameter -> [Module TorchModule] torchModules lp = go (layerTy lp) where nn name' args = Inner $ TorchModule "nn" name' (toLua <$> args) criterion name' = Criterion $ TorchModule "nn" name' [] nn' name' = nn name' ([] :: [Float]) -- Ugly case anaysis, sorry. go Pool = [nn ty' [kW, kH, dW, dH]] where kW = poolP PP._kernel_size kH = kW dW = poolP PP._stride dH = dW ty' = case poolP PP._pool of Just MAX -> "SpatialMaxPooling" Just AVE -> "SpatialAveragePooling" _ -> error "Unsupported Pooling Type" poolP f = lp ^. LP._pooling_param ^? _Just . f . _Just go Conv = [nn "SpatialConvolutionMM" [nInputPlane, nOutputPlane, kW, kH, dW, dH, padding]] where kW = convP CP._kernel_size kH = kW dW = convP CP._stride dH = dW padding = convP CP._pad -- TODO - propagation pass to size the layers nInputPlane = Nothing nOutputPlane = convP CP._num_output convP f = lp ^. LP._convolution_param ^? _Just . f . _Just go ReLU = [nn' "Threshold"] go IP = [nn "Linear" [nInput, nOutput]] where -- TODO - propagation pass to size the layers nInput = Nothing nOutput = lp ^. LP._inner_product_param ^? _Just . IP._num_output . _Just go Dropout = [nn "Dropout" [ratio]] where Just ratio = lp ^. LP._dropout_param ^? _Just . DP._dropout_ratio . _Just go SoftmaxWithLoss = [nn' "LogSoftMax", criterion "ClassNLLCriterion"] go ty' = error $ "Unhandled layer type: " ++ show ty' torchLayers :: [LayerTy] torchLayers = [Pool, Conv, ReLU, IP, Dropout, SoftmaxWithLoss] -- Graph validation -- A graph is `sequential` if and only if -- - It has n-1 edges -- - It is connected -- - Every node has an out degree of zero or one. isSequential :: Net -> Bool isSequential gr = e == (n-1) && length (dff' gr) == 1 && and [l `elem` [0, 1] | i <- nodes gr, let l = (length . suc gr) i] where e = length (edges gr) n = length (nodes gr) clean :: Net -> Net clean gr = foldl (flip delNode) gr toDelete where toDelete = filter (\n -> layerTy (label n) `notElem` torchLayers) (nodes gr) label n = lab' (context gr n) linearize :: Net -> Maybe [LayerParameter] linearize gr = if isSequential gr then Just (topsort' gr) else Nothing