{-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} module CodeGen.Types.HsOutput ( HModule(..) , ModuleSuffix(..) , FileSuffix(..) , TextPath(..) , makeModule , TypeCategory(..) , FunctionName(..) , CRep(..) , HsRep(..) , stripModule , CTensor(..) , CReal(..) , CAccReal(..) , CStorage(..) , HasAlias(..) ) where import CodeGen.Prelude import CodeGen.Types.CLI import CodeGen.Types.Parsed import Data.Semigroup hiding ((<>)) import qualified Data.Text as T -- ---------------------------------------- -- Types for rendering output -- ---------------------------------------- data HModule = HModule { lib :: LibType , extensions :: [Text] , imports :: [Text] , typeDefs :: [(Text, Text)] , header :: FilePath , typeTemplate :: TemplateType , suffix :: ModuleSuffix , fileSuffix :: FileSuffix , bindings :: [Function] , modOutDir :: TextPath , isTemplate :: CodeGenType } deriving Show newtype ModuleSuffix = ModuleSuffix { textSuffix :: Text } deriving newtype (IsString, Semigroup, Monoid, Ord, Read, Eq, Show) newtype FileSuffix = FileSuffix { textFileSuffix :: Text } deriving newtype (IsString, Semigroup, Monoid, Ord, Read, Eq, Show) newtype TextPath = TextPath { textPath :: Text } deriving newtype (IsString, Semigroup, Monoid, Ord, Read, Eq, Show) makeModule :: LibType -> TextPath -> CodeGenType -> FilePath -> ModuleSuffix -> FileSuffix -> TemplateType -> [Function] -> HModule makeModule lt a0 a1 a2 a3 a4 a5 a6 = HModule { lib = lt , extensions = ["ForeignFunctionInterface"] , imports = ["Foreign", "Foreign.C.Types", "Data.Word", "Data.Int"] <> torchtypes , typeDefs = [] , modOutDir = a0 , isTemplate = a1 , header = a2 , suffix = a3 , fileSuffix = a4 , typeTemplate = a5 , bindings = a6 } where torchtypes :: [Text] torchtypes = case lt of THC -> go [TH, THC] THCUNN -> go [TH, THC] THNN -> go [TH] rest -> go [rest] where go :: [LibType] -> [Text] go ls = (("Torch.Types." <>) . tshow) <$> ls data TypeCategory = ReturnValue | FunctionParam ------------------------------------------------------------------------------- -- | a concrete type for function names newtype FunctionName = FunctionName Text deriving stock (Show, Eq, Ord) deriving newtype (IsString, Hashable) newtype CRep = CRep Text deriving stock (Show, Eq, Ord) deriving newtype (IsString, Hashable) newtype HsRep = HsRep Text deriving stock (Show, Eq, Ord) deriving newtype (IsString, Hashable) stripModule :: HsRep -> Text stripModule (HsRep t) = T.takeWhileEnd (/= '.') t data CTensor = CTensor HsRep CRep deriving (Eq, Ord, Generic, Hashable, Show) data CReal = CReal HsRep CRep deriving (Eq, Ord, Generic, Hashable, Show) data CAccReal = CAccReal HsRep CRep deriving (Eq, Ord, Generic, Hashable, Show) data CStorage = CStorage HsRep CRep deriving (Eq, Ord, Generic, Hashable, Show) -- ========================================================================= -- class HasAlias t where alias :: t -> Text instance HasAlias CTensor where alias (CTensor t _) = _alias "CTensor" t instance HasAlias CReal where alias (CReal t _) = _alias "CReal" t instance HasAlias CAccReal where alias (CAccReal t _) = _alias "CAccReal" t instance HasAlias CStorage where alias (CStorage t _) = _alias "CStorage" t _alias :: Text -> HsRep -> Text _alias a (HsRep t) = T.intercalate " " ["type", a, "=", t]