{-# LANGUAGE TemplateHaskellQuotes #-}
module Language.Halide.Context
( importHalide
)
where
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp qualified as C
import Language.C.Types (CIdentifier)
import Language.Halide.Type
import Language.Haskell.TH (DecsQ, Q, TypeQ, lookupTypeName)
import Language.Haskell.TH qualified as TH
-- | Declarations that prepare a module for inline C++ calls into Halide.
--
-- The result is the concatenation, in this order, of the declarations
-- produced by:
--
-- 1. installing the inline-c context built by 'halideCxt';
-- 2. including the headers @Halide.h@, @cxxabi.h@ and @dlfcn.h@;
-- 3. emitting the C++ helpers from 'defineExceptionHandler'.
importHalide :: DecsQ
importHalide =
  concat
    <$> sequence
      [ C.context =<< halideCxt
      , C.include "<Halide.h>"
      , C.include "<cxxabi.h>"
      , C.include "<dlfcn.h>"
      , defineExceptionHandler
      ]
-- | The inline-c context installed by 'importHalide'.
--
-- Combines, left to right, 'C.cppCtx', 'C.fptrCtx', 'C.bsCtx' and the
-- C++-to-Haskell type table obtained by passing 'halideTypePairs' to
-- 'C.cppTypePairs'. The table is computed in 'Q' because some of its entries
-- are resolved by name at splice time.
halideCxt :: Q C.Context
halideCxt = do
  typePairs <- C.cppTypePairs <$> halideTypePairs
  pure (C.cppCtx <> C.fptrCtx <> C.bsCtx <> typePairs)
-- | Table mapping C++ type names to the Haskell types that represent them.
--
-- The result is the @core@ entries followed by the @other@ entries:
--
-- * @core@ entries refer to their Haskell type through a type quotation, so
--   every one of them is always present in the table.
-- * @other@ entries name their Haskell type as a 'String' and are resolved
--   with 'optionals'; an entry whose type name cannot be found when the
--   splice runs is left out of the table instead of causing an error.
halideTypePairs :: Q [(CIdentifier, TypeQ)]
halideTypePairs = concat <$> sequence [core, other]
  where
    -- Entries whose Haskell type is referenced directly.
    core :: Q [(CIdentifier, TypeQ)]
    core =
      pure
        [ ("Halide::Expr", [t|CxxExpr|])
        , ("Halide::Var", [t|CxxVar|])
        , ("Halide::RVar", [t|CxxRVar|])
        , ("Halide::VarOrRVar", [t|CxxVarOrRVar|])
        , ("Halide::Func", [t|CxxFunc|])
        , ("Halide::Internal::Parameter", [t|CxxParameter|])
        , ("Halide::ImageParam", [t|CxxImageParam|])
        , ("Halide::Callable", [t|CxxCallable|])
        , ("Halide::Target", [t|CxxTarget|])
        , ("Halide::JITUserContext", [t|CxxUserContext|])
        , ("std::vector", [t|CxxVector|])
        , ("std::string", [t|CxxString|])
        , ("halide_type_t", [t|HalideType|])
        ]

    -- Entries whose Haskell type is looked up by name and may be absent.
    -- NOTE(review): presumably these types live in modules that are not
    -- always in scope where the splice is used -- confirm.
    other :: Q [(CIdentifier, TypeQ)]
    other =
      optionals
        [ ("Halide::Internal::Dim", "Dim")
        , ("Halide::Internal::Dimension", "CxxDimension")
        , ("Halide::Internal::FusedPair", "FusedPair")
        , ("Halide::Internal::PrefetchDirective", "PrefetchDirective")
        , ("Halide::Internal::ReductionVariable", "ReductionVariable")
        , ("Halide::Internal::Split", "Split")
        , ("Halide::Internal::StageSchedule", "CxxStageSchedule")
        , ("Halide::Argument", "CxxArgument")
        , ("Halide::Buffer", "CxxBuffer")
        , ("Halide::LoopLevel", "CxxLoopLevel")
        , ("Halide::Stage", "CxxStage")
        , ("Halide::Range", "CxxRange")
        , ("Halide::RDom", "CxxRDom")
        , ("halide_buffer_t", "Language.Halide.Buffer.RawHalideBuffer")
        , ("halide_device_interface_t", "HalideDeviceInterface")
        , ("halide_dimension_t", "HalideDimension")
        , ("halide_trace_event_t", "TraceEvent")
        ]
-- | Resolve one optional type-table entry.
--
-- Given a C++ identifier and the name of a Haskell type, looks the type name
-- up with 'lookupTypeName'. If it is found, the result is a one-element list
-- pairing the C++ identifier with that type constructor; otherwise the result
-- is the empty list.
optional :: (CIdentifier, String) -> Q [(CIdentifier, TypeQ)]
optional (cName, hsName) = maybe [] entryFor <$> lookupTypeName hsName
  where
    -- Single table entry for a successfully resolved type name.
    entryFor name = [(cName, pure (TH.ConT name))]
-- | Resolve a list of optional type-table entries with 'optional'.
--
-- Entries are processed in order and their results concatenated, so entries
-- whose Haskell type name cannot be found simply do not appear in the output.
optionals :: [(CIdentifier, String)] -> Q [(CIdentifier, TypeQ)]
optionals pairs = concat <$> traverse optional pairs
-- | Emits, verbatim, the C++ helper definitions used alongside Halide:
--
-- * @handle_halide_exceptions(func)@ returns @func()@; if that throws
--   @Halide::RuntimeError@, @Halide::CompileError@, @Halide::InternalError@
--   or @Halide::Error@, the exception is rethrown as a @std::runtime_error@
--   carrying the same @what()@ message.
-- * @to_string_via_iostream(x)@ writes @x@ to a @std::ostringstream@ and
--   returns the text as a @std::string@ allocated with @new@; the helper
--   itself never frees it.
-- * A declaration of @Halide::Internal::print_loop_nest@ taking a
--   @std::vector\<Function\>@ and returning @std::string@.
--
-- The C++ text is a single string literal written with string gaps; the
-- leading whitespace before each @\\@ is not part of the emitted code.
defineExceptionHandler :: DecsQ
defineExceptionHandler =
  C.verbatim
    "\
    \template <class Func> \n\
    \auto handle_halide_exceptions(Func&& func) { \n\
    \ try { \n\
    \ return func(); \n\
    \ } catch(Halide::RuntimeError& e) { \n\
    \ throw std::runtime_error{e.what()}; \n\
    \ } catch(Halide::CompileError& e) { \n\
    \ throw std::runtime_error{e.what()}; \n\
    \ } catch(Halide::InternalError& e) { \n\
    \ throw std::runtime_error{e.what()}; \n\
    \ } catch(Halide::Error& e) { \n\
    \ throw std::runtime_error{e.what()}; \n\
    \ } \n\
    \} \n\
    \ \n\
    \template <class T> \n\
    \auto to_string_via_iostream(T const& x) -> std::string* { \n\
    \ std::ostringstream stream; \n\
    \ stream << x; \n\
    \ return new std::string{stream.str()}; \n\
    \} \n\
    \\n\
    \namespace Halide { namespace Internal {\n\
    \ std::string print_loop_nest(const std::vector<Function> &);\n\
    \} }\n\
    \"