{-# LANGUAGE TemplateHaskellQuotes #-}
module Language.Halide.Context
  ( importHalide
  )
where
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp qualified as C
import Language.C.Types (CIdentifier)
import Language.Halide.Type
import Language.Haskell.TH (DecsQ, Q, TypeQ, lookupTypeName)
import Language.Haskell.TH qualified as TH
-- | Template Haskell splice that prepares a module for inline C++ calls into
-- Halide.
--
-- It produces, in this order:
--
--   1. the inline-c context built by 'halideCxt' (C++ support, foreign
--      pointers, bytestrings and the Halide type mappings);
--   2. @#include@ directives for @<Halide.h>@, @<cxxabi.h>@ and @<dlfcn.h>@;
--   3. the C++ helper definitions emitted by 'defineExceptionHandler'.
--
-- NOTE(review): the helpers in 'defineExceptionHandler' do not use
-- @<cxxabi.h>@ or @<dlfcn.h>@; presumably other inline C++ snippets in the
-- importing modules rely on them — confirm before removing either include.
importHalide :: DecsQ
importHalide = do
  contextDecs <- halideCxt >>= C.context
  includeDecs <- traverse C.include ["<Halide.h>", "<cxxabi.h>", "<dlfcn.h>"]
  helperDecs <- defineExceptionHandler
  pure (contextDecs <> concat includeDecs <> helperDecs)
-- | The inline-c context used by every Halide inline C++ block.
--
-- It is the left-to-right combination of 'C.cppCtx', 'C.fptrCtx' and
-- 'C.bsCtx'. A context mapping C++ type names to Haskell types comes last,
-- built with 'C.cppTypePairs' from 'halideTypePairs'.
halideCxt :: Q C.Context
halideCxt = withBaseContexts . C.cppTypePairs <$> halideTypePairs
  where
    -- Prepend the generic C++/ForeignPtr/ByteString contexts to the
    -- Halide-specific type-pair context.
    withBaseContexts pairsCtx = C.cppCtx <> C.fptrCtx <> C.bsCtx <> pairsCtx
-- | Mapping from C++ type identifiers to the Haskell types that stand in for
-- them inside inline C++ code.
--
-- The result is the unconditional @core@ mappings followed by the @other@
-- mappings. Each @other@ entry is included only if 'lookupTypeName' resolves
-- its Haskell type name; unresolved names are silently dropped.
halideTypePairs :: Q [(CIdentifier, TypeQ)]
halideTypePairs = (<>) <$> core <*> other
  where
    -- Mappings whose Haskell types are referenced directly via type quotes,
    -- so they must be in scope wherever this module is compiled.
    core :: Q [(CIdentifier, TypeQ)]
    core =
      pure
        [ ("Halide::Expr", [t|CxxExpr|])
        , ("Halide::Var", [t|CxxVar|])
        , ("Halide::RVar", [t|CxxRVar|])
        , ("Halide::VarOrRVar", [t|CxxVarOrRVar|])
        , ("Halide::Func", [t|CxxFunc|])
        , ("Halide::Internal::Parameter", [t|CxxParameter|])
        , ("Halide::ImageParam", [t|CxxImageParam|])
        , ("Halide::Callable", [t|CxxCallable|])
        , ("Halide::Target", [t|CxxTarget|])
        , ("Halide::JITUserContext", [t|CxxUserContext|])
        , ("std::vector", [t|CxxVector|])
        , ("std::string", [t|CxxString|])
        , ("halide_type_t", [t|HalideType|])
        ]

    -- Mappings whose Haskell types are looked up by name. An entry is kept
    -- only when that name resolves, which lets modules that do not define
    -- (or import) a given type still use this context.
    other :: Q [(CIdentifier, TypeQ)]
    other =
      optionals
        [ ("Halide::Internal::Dim", "Dim")
        , ("Halide::Internal::Dimension", "CxxDimension")
        , ("Halide::Internal::FusedPair", "FusedPair")
        , ("Halide::Internal::PrefetchDirective", "PrefetchDirective")
        , ("Halide::Internal::ReductionVariable", "ReductionVariable")
        , ("Halide::Internal::Split", "Split")
        , ("Halide::Internal::StageSchedule", "CxxStageSchedule")
        , ("Halide::Argument", "CxxArgument")
        , ("Halide::Buffer", "CxxBuffer")
        , ("Halide::LoopLevel", "CxxLoopLevel")
        , ("Halide::Stage", "CxxStage")
        , ("Halide::Range", "CxxRange")
        , ("Halide::RDom", "CxxRDom")
        , ("halide_buffer_t", "Language.Halide.Buffer.RawHalideBuffer")
        , ("halide_device_interface_t", "HalideDeviceInterface")
        , ("halide_dimension_t", "HalideDimension")
        , ("halide_trace_event_t", "TraceEvent")
        ]

    -- Resolve one Haskell type name: a singleton list when it is found,
    -- an empty list otherwise.
    optional :: (CIdentifier, String) -> Q [(CIdentifier, TypeQ)]
    optional (cName, hsName) =
      foldMap (\tyName -> [(cName, pure (TH.ConT tyName))])
        <$> lookupTypeName hsName

    -- Resolve every pair in order, keeping only the ones that were found.
    optionals :: [(CIdentifier, String)] -> Q [(CIdentifier, TypeQ)]
    optionals pairs = concat <$> traverse optional pairs
-- | Emits verbatim C++ helper definitions shared by the inline C++ blocks.
--
-- * @handle_halide_exceptions(func)@ calls @func()@ and returns its result.
--   If it throws @Halide::RuntimeError@, @Halide::CompileError@,
--   @Halide::InternalError@ or @Halide::Error@, the exception is rethrown as a
--   @std::runtime_error@ carrying the same @what()@ message.
-- * @to_string_via_iostream(x)@ formats @x@ with @operator<<@ and returns a
--   pointer to a heap-allocated @std::string@ (allocated with @new@). The
--   caller is responsible for freeing it.
-- * A forward declaration of @Halide::Internal::print_loop_nest@.
--   NOTE(review): presumably declared here because the public Halide header
--   does not expose it — confirm against the Halide version in use.
defineExceptionHandler :: DecsQ
defineExceptionHandler = C.verbatim helperSource
  where
    -- C++ source passed to inline-c unchanged; line padding inside the
    -- literal is part of the emitted text.
    helperSource :: String
    helperSource =
      "\
      \template <class Func>                               \n\
      \auto handle_halide_exceptions(Func&& func) {        \n\
      \  try {                                             \n\
      \    return func();                                  \n\
      \  } catch(Halide::RuntimeError& e) {                \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  } catch(Halide::CompileError& e) {                \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  } catch(Halide::InternalError& e) {               \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  } catch(Halide::Error& e) {                       \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  }                                                 \n\
      \}                                                   \n\
      \                                                    \n\
      \template <class T>                                               \n\
      \auto to_string_via_iostream(T const& x) -> std::string* {        \n\
      \  std::ostringstream stream;                                     \n\
      \  stream << x;                                                   \n\
      \  return new std::string{stream.str()};                          \n\
      \}                                                                \n\
      \\n\
      \namespace Halide { namespace Internal {\n\
      \  std::string print_loop_nest(const std::vector<Function> &);\n\
      \} }\n\
      \"