{-# LANGUAGE TemplateHaskellQuotes #-}

-- |
-- Module      : Language.Halide.Context
-- Description : Helpers to setup inline-c for Halide
-- Copyright   : (c) Tom Westerhout, 2023
--
-- This module defines a Template Haskell function 'importHalide' that sets up everything you need
-- to call Halide functions from 'Language.C.Inline' and 'Language.C.Inline.Cpp' quasiquotes.
--
-- We also define two C++ helper functions (and forward-declare Halide's internal
-- @Halide::Internal::print_loop_nest@):
--
-- > template <class Func>
-- > auto handle_halide_exceptions(Func&& func);
-- >
-- > template <class T>
-- > auto to_string_via_iostream(T const& x) -> std::string*;
--
-- @handle_halide_exceptions@ can be used to catch various Halide exceptions and convert them to
-- [@std::runtime_error@](https://en.cppreference.com/w/cpp/error/runtime_error). It can be used
-- inside 'C.tryBlock' or 'C.catchBlock' to properly re-throw Halide errors.
--
-- @
-- [C.catchBlock| void {
--   handle_halide_exceptions([=]() {
--     Halide::Func f;
--     Halide::Var i;
--     f(i) = *$(Halide::Expr* e);
--     f.realize(Halide::Pipeline::RealizationArg{$(halide_buffer_t* b)});
--   });
-- } |]
-- @
--
-- @to_string_via_iostream@ is a helper that converts a variable into a string by relying on
-- [iostreams](https://en.cppreference.com/w/cpp/io). It returns a pointer to
-- [@std::string@](https://en.cppreference.com/w/cpp/string/basic_string) that it allocated using the @new@
-- keyword. To convert it to a Haskell string, use the 'Language.Halide.Utils.peekCxxString' and
-- 'Language.Halide.Utils.peekAndDeleteCxxString' functions.
module Language.Halide.Context
  ( importHalide
  )
where

import Language.C.Inline qualified as C
import Language.C.Inline.Cpp qualified as C
import Language.C.Types (CIdentifier)
import Language.Halide.Type
import Language.Haskell.TH (DecsQ, Q, TypeQ, lookupTypeName)
import Language.Haskell.TH qualified as TH

-- | One-stop function to include all the necessary machinery to call Halide functions via inline-c.
--
-- Put @importHalide@ somewhere at the beginning of the file and enjoy using the C++ interface of
-- Halide via inline-c quasiquotes.
--
-- The generated declarations, in order:
--
--   * set the inline-c context (C++ support plus the Halide type table);
--   * include the Halide, @cxxabi.h@ and @dlfcn.h@ headers;
--   * include the standard headers required by the helper definitions;
--   * emit the C++ helpers (@handle_halide_exceptions@, @to_string_via_iostream@).
importHalide :: DecsQ
importHalide =
  concat
    <$> sequence
      [ C.context =<< halideCxt
      , C.include "<Halide.h>"
      , -- , C.include "<HalideRuntimeOpenCL.h>"
        -- , C.include "<HalideRuntimeCuda.h>"
        C.include "<cxxabi.h>"
      , C.include "<dlfcn.h>"
      , -- The helpers emitted below use std::ostringstream and throw
        -- std::runtime_error; include their headers explicitly rather than
        -- relying on Halide.h to pull them in transitively.
        C.include "<sstream>"
      , C.include "<stdexcept>"
      , defineExceptionHandler
      ]

-- | The inline-c context used by 'importHalide'.
--
-- It merges inline-c's C++ context, its @ForeignPtr@ anti-quoter context and its
-- @ByteString@ anti-quoter context with the C++-to-Haskell type table built from
-- 'halideTypePairs'.
halideCxt :: Q C.Context
halideCxt = withBaseContexts . C.cppTypePairs <$> halideTypePairs
  where
    -- Layer the stock inline-c contexts on top of the Halide type table.
    withBaseContexts :: C.Context -> C.Context
    withBaseContexts halideTypes = C.cppCtx <> C.fptrCtx <> C.bsCtx <> halideTypes

-- | Pairs of C++ type names and the Haskell types inline-c should associate with them.
--
-- The result has two parts:
--
--   * The first part is always present. Its Haskell types are referenced directly through
--     type quotes.
--   * The second part is resolved by name with 'lookupTypeName'. A pair is emitted only when
--     that lookup succeeds. Names that do not resolve are skipped without any error.
halideTypePairs :: Q [(CIdentifier, TypeQ)]
halideTypePairs = (<>) <$> coreTypes <*> optionalTypes
  where
    -- Types referenced through type quotes; they must be in scope in this module.
    coreTypes :: Q [(CIdentifier, TypeQ)]
    coreTypes =
      pure
        [ ("Halide::Expr", [t|CxxExpr|])
        , ("Halide::Var", [t|CxxVar|])
        , ("Halide::RVar", [t|CxxRVar|])
        , ("Halide::VarOrRVar", [t|CxxVarOrRVar|])
        , ("Halide::Func", [t|CxxFunc|])
        , ("Halide::Internal::Parameter", [t|CxxParameter|])
        , ("Halide::ImageParam", [t|CxxImageParam|])
        , ("Halide::Callable", [t|CxxCallable|])
        , ("Halide::Target", [t|CxxTarget|])
        , ("Halide::JITUserContext", [t|CxxUserContext|])
        , ("std::vector", [t|CxxVector|])
        , ("std::string", [t|CxxString|])
        , ("halide_type_t", [t|HalideType|])
        ]

    -- Types referenced only by their (possibly qualified) Haskell name.
    optionalTypes :: Q [(CIdentifier, TypeQ)]
    optionalTypes =
      lookupAll
        [ ("Halide::Internal::Dim", "Dim")
        , ("Halide::Internal::Dimension", "CxxDimension")
        , ("Halide::Internal::FusedPair", "FusedPair")
        , ("Halide::Internal::PrefetchDirective", "PrefetchDirective")
        , ("Halide::Internal::ReductionVariable", "ReductionVariable")
        , ("Halide::Internal::Split", "Split")
        , ("Halide::Internal::StageSchedule", "CxxStageSchedule")
        , ("Halide::Argument", "CxxArgument")
        , ("Halide::Buffer", "CxxBuffer")
        , ("Halide::LoopLevel", "CxxLoopLevel")
        , ("Halide::Stage", "CxxStage")
        , ("Halide::Range", "CxxRange")
        , ("Halide::RDom", "CxxRDom")
        , ("halide_buffer_t", "Language.Halide.Buffer.RawHalideBuffer")
        , ("halide_device_interface_t", "HalideDeviceInterface")
        , ("halide_dimension_t", "HalideDimension")
        , ("halide_trace_event_t", "TraceEvent")
        ]

    -- Resolve a single Haskell type name; an unresolved name contributes no pair.
    lookupOne :: (CIdentifier, String) -> Q [(CIdentifier, TypeQ)]
    lookupOne (cName, hsName) =
      foldMap (\name -> [(cName, pure (TH.ConT name))]) <$> lookupTypeName hsName

    -- Resolve every entry in order, keeping only those that were found.
    lookupAll :: [(CIdentifier, String)] -> Q [(CIdentifier, TypeQ)]
    lookupAll entries = concat <$> traverse lookupOne entries

-- | Emits, via 'C.verbatim', the C++ snippets that back the helpers described in the module
-- header:
--
--   * @handle_halide_exceptions@: calls @func@ and rethrows @Halide::RuntimeError@,
--     @Halide::CompileError@, @Halide::InternalError@ and @Halide::Error@ as
--     @std::runtime_error@ carrying the same @what()@ message;
--   * @to_string_via_iostream@: streams a value into an @std::ostringstream@ and returns a
--     @new@-allocated @std::string@ that the caller must delete;
--   * a forward declaration of @Halide::Internal::print_loop_nest@.
defineExceptionHandler :: DecsQ
defineExceptionHandler = C.verbatim (concat [exceptionHandler, toStringViaIostream, loopNestDecl])
  where
    exceptionHandler :: String
    exceptionHandler =
      "\
      \template <class Func>                               \n\
      \auto handle_halide_exceptions(Func&& func) {        \n\
      \  try {                                             \n\
      \    return func();                                  \n\
      \  } catch(Halide::RuntimeError& e) {                \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  } catch(Halide::CompileError& e) {                \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  } catch(Halide::InternalError& e) {               \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  } catch(Halide::Error& e) {                       \n\
      \    throw std::runtime_error{e.what()};             \n\
      \  }                                                 \n\
      \}                                                   \n\
      \                                                    \n"

    toStringViaIostream :: String
    toStringViaIostream =
      "\
      \template <class T>                                               \n\
      \auto to_string_via_iostream(T const& x) -> std::string* {        \n\
      \  std::ostringstream stream;                                     \n\
      \  stream << x;                                                   \n\
      \  return new std::string{stream.str()};                          \n\
      \}                                                                \n\
      \\n"

    loopNestDecl :: String
    loopNestDecl =
      "\
      \namespace Halide { namespace Internal {\n\
      \  std::string print_loop_nest(const std::vector<Function> &);\n\
      \} }\n"