{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Lowering of Caffe 'LayerParameter's to a Lua 'Block' that builds a
-- Torch @nn.Sequential@ container, plus rendering of that block to source.
module NN.Backend.Torch.Codegen where

import Control.Applicative
import Control.Lens hiding (assign)
import Control.Monad.State.Strict
import Gen.Caffe.LayerParameter as LP
import Language.Lua.PrettyPrinter
import Language.Lua.Syntax
import Text.Printf

import NN.Backend.Torch.Lua
import NN.Backend.Torch.Torch

-- | Accumulated code-generation state.
data TorchState = TorchState
  { _statements :: [Stat]
    -- ^ Lua statements emitted so far, in emission order.
  , _sequential :: Maybe String
    -- ^ Variable bound to the @nn.Sequential@ container; 'Nothing' until
    -- 'initialize' has run.
  , _criteria   :: [String]
    -- ^ Variables bound to criterion modules, in insertion order.
  , _count      :: Int
    -- ^ Counter used by 'fresh' to make variable names unique.
  }

makeLenses ''TorchState

-- | The code-generation monad: a strict 'State' over 'TorchState'.
newtype Torch a = Torch { _unTorch :: State TorchState a }
  deriving (Functor, Applicative, Monad, MonadState TorchState)

-- | Emit the program prologue: a @require "nn"@ call, then an assignment
-- binding a fresh @seqN@ variable to the @nn.Sequential@ module expression.
-- The variable name is recorded in '_sequential'.
initialize :: Torch ()
initialize = do
  seq' <- fresh "seq"
  sequential ?= seq'
  statements <>= [require "nn"]
  statements <>= [assign seq' $ torchExp (TorchModule "nn" "Sequential" [])]
  where
    require module' = funCall "require" [toLua $ L module']

-- | Generate a unique variable name: @prefix@ followed by the current
-- counter value. The counter is then incremented.
fresh :: String -> Torch String
fresh prefix = do
  -- '<<+=' increments the counter and yields the value it had before.
  c <- count <<+= 1
  return $ printf "%s%d" prefix c

-- | Name of the @nn.Sequential@ container variable.
--
-- Fails with an explicit error if 'initialize' has not run yet. A
-- @Just x <- ...@ pattern bind cannot be used here: it requires a
-- 'MonadFail' instance, which 'State' (and therefore 'Torch') does not have
-- on GHC >= 8.8.
getSequential :: Torch String
getSequential = use sequential >>= maybe uninitialized return
  where
    uninitialized =
      error "NN.Backend.Torch.Codegen: sequential container used before initialize"

-- | Emit code for one module.
--
-- A 'Criterion' is bound to a fresh @criterionN@ variable, and the name is
-- remembered so 'finalize' can include it. An 'Inner' module is passed to an
-- @add@ method call on the sequential container.
insertModule :: Module Exp -> Torch ()
insertModule (Criterion exp') = do
  name' <- fresh "criterion"
  criteria <>= [name']
  statements <>= [assign name' exp']
insertModule (Inner exp') = do
  seq' <- getSequential
  statements <>= [methCall seq' "add" [exp']]

-- | Build the final 'Block' from all emitted statements.
--
-- The block's optional return list is 'return'' applied to the container
-- name, followed by each criterion name in insertion order.
finalize :: Torch Block
finalize = do
  seq' <- getSequential
  criteria' <- use criteria
  statements' <- use statements
  return $ Block statements' (Just (map return' (seq' : criteria')))

-- | Generate the whole program for a list of layers, in three steps:
--
-- 1. the prologue,
-- 2. the modules produced by 'torchModules' for each layer, in order,
-- 3. 'finalize'.
runTorch :: [LayerParameter] -> Torch Block
runTorch layers = do
  initialize
  -- Prelude 'mapM_' rather than 'forM_': mtl >= 2.3 no longer re-exports
  -- "Control.Monad" from "Control.Monad.State.Strict".
  mapM_ insertModule exps
  finalize
  where
    exps = concatMap torchExps layers
    torchExps lp = (torchExp <$>) <$> torchModules lp

-- | Lower layers to a Lua 'Block', starting from an empty 'TorchState'.
lower :: [LayerParameter] -> Block
lower layers = evalState (_unTorch (runTorch layers)) emptyTorch
  where
    emptyTorch = TorchState [] Nothing [] 0

-- | Render a 'Block' to Lua source text (ribbon fraction 0.4, width 150).
codegen :: Block -> String
codegen block = pprint block & renderPretty 0.4 150 & displayS & \f -> f ""