{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Types that select what the code generator works on: the Torch C libraries
-- ('LibType'), the kind of source files ('CodeGenType') and the template types
-- generated for them ('TemplateType'), plus helpers that derive names and
-- paths from them.
module CodeGen.Types.CLI
  ( LibType(..)
  , prefix
  , describe'
  , supported
  , supportedLibraries
  , outDir
  , outModule
  , srcDir
  , CodeGenType(..)
  , generatable
  , TemplateType(..)
  , generatedTypes
  ) where

import Data.Data
import Data.Typeable
import CodeGen.Prelude
import qualified Data.Text as T
import qualified Data.HashSet as HS

-- | All possible libraries that we intend to support (these are all src
-- libraries in ATen). Note that this ordering is used in codegen and must not
-- be changed.
data LibType
  = ATen
  | THCUNN
  | THCS
  | THC
  | THNN
  | THS
  | TH
  deriving (Eq, Ord, Show, Enum, Bounded, Read, Generic, Hashable, Data, Typeable)

-- | Name prefix for a library.
--
-- The CUDA libraries ('THC' and 'THCUNN') share a prefix: @\"THCuda\"@ when
-- @long@ is 'True', @\"THC\"@ otherwise. Every other library uses its
-- constructor name regardless of @long@.
--
-- NOTE(review): presumably this is the prefix of the generated C symbols;
-- confirm against callers.
prefix :: LibType -> Bool -> Text
prefix lt long = case lt of
  THC    -> cudaPrefix
  THCUNN -> cudaPrefix
  _      -> tshow lt
  where
    cudaPrefix = if long then "THCuda" else "THC"

-- | Short descriptions of each library we intend to support.
describe' :: LibType -> String
describe' = \case
  -- The two literals are joined with '++', so the trailing space on the first
  -- one is required to avoid producing "Torchand".
  ATen   -> "A simple TENsor library that exposes the Tensor operations in Torch "
         ++ "and PyTorch directly in C++11."
  TH     -> "Torch7"
  THC    -> "Cuda-based Torch7"
  THCS   -> "Cuda-based Sparse Tensor support with TH"
  THCUNN -> "Cuda-based THNN"
  THNN   -> "THNN"
  THS    -> "TH Sparse tensor support (ATen library)"

-- | Whether or not we currently support code generation for the library.
supported :: LibType -> Bool
supported lt = lt `HS.member` HS.fromList [TH, THC, THNN, THCUNN]

-- | Every supported library, in 'LibType' declaration order.
supportedLibraries :: [LibType]
supportedLibraries = filter supported [minBound .. maxBound]

-- | Where generated code will be placed: @output/raw/@, the lower-cased
-- library name, @src/@, then the library's module path with @/@ separators
-- (e.g. @output/raw/th/src/Torch/FFI/TH@ for 'TH').
outDir :: LibType -> FilePath
outDir lt = intercalate ""
  [ "output/raw/"
  , toLowers lt ++ "/"
  , "src/"
  , T.unpack (out "/" lt)
  ]

-- | Lower-cased 'show' of a value.
toLowers :: Show a => a -> String
toLowers = map toLower . show

-- | The prefix of the output module name (e.g. @Torch.FFI.TH.NN@ for 'THNN').
outModule :: LibType -> Text
outModule = out "."
-- | Module path of the generated bindings for a library, with the components
-- joined by the given separator (@"."@ for module names, @"/"@ for
-- directories).
--
-- The NN libraries are nested under the library they extend: 'THNN' maps to
-- @Torch.FFI.TH.NN@ and 'THCUNN' to @Torch.FFI.THC.NN@. Every other library
-- maps to @Torch.FFI@ followed by its constructor name.
out :: Text -> LibType -> Text
out sep = T.intercalate sep . \case
  THCUNN -> nnPath THC
  THNN   -> nnPath TH
  other  -> basePath other
  where
    -- Torch / FFI / Lib
    basePath lib = ["Torch", "FFI", tshow lib]
    -- Torch / FFI / Lib / NN
    nnPath lib = basePath lib ++ ["NN"]

-- | Where the source files are located, relative to the root of the hasktorch
-- project: @./deps/aten/src/@ plus the library name, plus @generic/@ for
-- 'GenericFiles' (e.g. @./deps/aten/src/TH/generic/@).
srcDir :: LibType -> CodeGenType -> FilePath
srcDir lt cgt = intercalate ""
  [ "./deps/aten/src/"
  , show lt ++ "/"
  , if cgt == GenericFiles then "generic/" else ""
  ]

-- | Type of code to generate
data CodeGenType
  = GenericFiles   -- ^ generic/ files which are used in C for type-generic code
  | ConcreteFiles  -- ^ concrete supporting files. These include utility
                   -- functions and random generators.
  deriving (Eq, Ord, Enum, Bounded)

-- | Parses the tokens produced by the 'Show' instance (@generic@ and
-- @concrete@).
--
-- The instance follows the 'readsPrec' contract:
--
-- * Leading whitespace is skipped, via 'lex'.
-- * The unconsumed input is returned rather than an exact match being
--   required.
--
-- As a result, values also parse inside lists and tuples, e.g.
-- @read "[generic,concrete]"@.
instance Read CodeGenType where
  readsPrec _ s =
    [ (cgt, rest)
    | (token, rest) <- lex s
    , cgt <- case token of
        "generic"  -> [GenericFiles]
        "concrete" -> [ConcreteFiles]
        _          -> []
    ]

instance Show CodeGenType where
  show = \case
    GenericFiles  -> "generic"
    ConcreteFiles -> "concrete"

-- | Whether or not we currently support generating this type of code.
-- Currently every 'CodeGenType' is supported.
generatable :: CodeGenType -> Bool
generatable = const True

-- ----------------------------------------
-- Types for representing templating
-- ----------------------------------------

-- | Template types to generate for; 'GenNothing' marks non-templated code.
data TemplateType
  = GenByte
  | GenChar
  | GenDouble
  | GenFloat
  | GenHalf
  | GenInt
  | GenLong
  | GenShort
  | GenNothing
  deriving (Eq, Ord, Bounded, Show, Generic, Hashable)

-- | Template types to iterate through for a library and kind of source file.
--
-- * Concrete files always get only 'GenNothing'.
-- * Generic files of the NN libraries ('THNN', 'THCUNN') get only 'GenDouble'
--   and 'GenFloat'.
-- * Generic files of every other library get every template type except
--   'GenNothing'.
generatedTypes :: LibType -> CodeGenType -> [TemplateType]
generatedTypes lt = \case
  ConcreteFiles -> [GenNothing]
  GenericFiles  -> case lt of
    THNN   -> nnTypes
    THCUNN -> nnTypes
    _      -> allTypes
  where
    nnTypes = [GenDouble, GenFloat]
    allTypes =
      [ GenByte
      , GenChar
      , GenDouble
      , GenFloat
      , GenHalf
      , GenInt
      , GenLong
      , GenShort
      ]