{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
module Language.Halide.Func
(
Func (..)
, FuncTy (..)
, Stage (..)
, Function
, Parameter
, buffer
, scalar
, define
, (!)
, realizeOnTarget
, realize
, Schedulable (..)
, TailStrategy (..)
, computeRoot
, getStage
, getLoopLevel
, getLoopLevelAtStage
, asUsed
, asUsedBy
, copyToDevice
, copyToHost
, storeAt
, computeAt
, dim
, estimate
, bound
, getArgs
, update
, hasUpdateDefinitions
, getUpdateStage
, prettyLoopNest
, asBufferParam
, withFunc
, withCxxFunc
, withBufferParam
, wrapCxxFunc
, CxxStage
, wrapCxxStage
, withCxxStage
)
where
import Control.Exception (bracket)
import Control.Monad (forM)
import Data.Constraint
import Data.Functor ((<&>))
import Data.IORef
import Data.Kind (Type)
import Data.Proxy
import Data.Text (Text)
import Data.Text.Encoding qualified as T
import Foreign.ForeignPtr
import Foreign.Marshal (toBool, with)
import Foreign.Ptr (Ptr, castPtr)
import GHC.Stack (HasCallStack)
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Dimension
import Language.Halide.Expr
import Language.Halide.LoopLevel
import Language.Halide.Target
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Unsafe.Coerce
import Prelude hiding (min, tail)
-- | Empty tag type standing in for the C++ @Halide::Stage@ class. It has no
-- constructors and only ever appears behind pointers (@Ptr CxxStage@,
-- @ForeignPtr CxxStage@) that inline C++ receives as @Halide::Stage*@.
data CxxStage
-- Top-level Template Haskell splice from the project's Halide context module.
-- NOTE(review): presumably it registers the C++ type mappings (including
-- CxxStage -> Halide::Stage) used by the quasiquotes below. That would explain
-- why it comes after the declaration above. A top-level splice also starts a
-- new declaration group, so confirm before moving it.
importHalide
-- | A Halide function or image parameter. The type is indexed by which
-- constructor built it ('FuncTy'), a type-level natural @n@ (presumably the
-- number of dimensions) and the element type @a@.
data Func (t :: FuncTy) (n :: Nat) (a :: Type) where
  -- | A defined function wrapping a pointer to the C++ function object.
  Func
    :: {-# UNPACK #-} !(ForeignPtr CxxFunc)
    -> Func 'FuncTy n a
  -- | An image parameter. The mutable cell holds the C++ image parameter.
  -- NOTE(review): 'Nothing' presumably means "not created yet" — confirm
  -- against the code that fills it in.
  Param
    :: IsHalideType a
    => {-# UNPACK #-} !(IORef (Maybe (ForeignPtr CxxImageParam)))
    -> Func 'ParamTy n (Expr a)
-- | Type-level tag telling the two 'Func' constructors apart: 'FuncTy' for
-- functions built with the 'Func' constructor and 'ParamTy' for image
-- parameters built with 'Param'.
data FuncTy
  = FuncTy
  | ParamTy
  deriving stock (Eq, Ord)
  deriving stock (Show)
-- | A 'Func' built with the 'Func' constructor, indexed by @n@ (presumably
-- its dimensionality) and producing values of type @'Expr' a@.
type Function n a = Func 'FuncTy n (Expr a)
-- | A 'Func' built with the 'Param' constructor (an image parameter),
-- indexed by @n@ and producing values of type @'Expr' a@.
type Parameter n a = Func 'ParamTy n (Expr a)
-- | Handle to a C++ @Halide::Stage@. The parameters @n@ and @a@ are phantom:
-- they do not occur in the representation and simply carry over the indices
-- of the 'Func' the stage was obtained from (see 'getStage').
newtype Stage (n :: Nat) (a :: Type) = Stage (ForeignPtr CxxStage)
-- | Tail strategy passed to 'split'. Each constructor maps to the
-- @Halide::TailStrategy@ enumerator with the same name minus the @Tail@
-- prefix; see the 'Enum' instance for the conversion.
--
-- The derived 'Ord' follows declaration order, so do not reorder the
-- constructors.
data TailStrategy
  = -- | @Halide::TailStrategy::RoundUp@
    TailRoundUp
  | -- | @Halide::TailStrategy::GuardWithIf@
    TailGuardWithIf
  | -- | @Halide::TailStrategy::Predicate@
    TailPredicate
  | -- | @Halide::TailStrategy::PredicateLoads@
    TailPredicateLoads
  | -- | @Halide::TailStrategy::PredicateStores@
    TailPredicateStores
  | -- | @Halide::TailStrategy::ShiftInwards@
    TailShiftInwards
  | -- | @Halide::TailStrategy::Auto@
    TailAuto
  deriving stock (Eq, Ord)
  deriving stock (Show)
-- | Scheduling directives shared by 'Func's and their 'Stage's. Each method
-- forwards to the C++ @Halide::Stage@ member of the same name. Methods that
-- return @f n a@ hand back the object they were given, so calls can be
-- chained.
class KnownNat n => Schedulable f (n :: Nat) (a :: Type) where
  -- | Vectorize the given dimension (Halide's @vectorize@).
  vectorize
    :: VarOrRVar -> f n a -> IO (f n a)
  -- | Unroll the given dimension (Halide's @unroll@).
  unroll
    :: VarOrRVar -> f n a -> IO (f n a)
  -- | Reorder dimensions; the list is passed to Halide's @reorder@ as-is.
  reorder
    :: [VarOrRVar] -> f n a -> IO (f n a)
  -- | @split tail old (outer, inner) factor@ splits @old@ into @outer@ and
  -- @inner@ by @factor@, using @tail@ as Halide's tail strategy.
  split
    :: TailStrategy
    -> VarOrRVar
    -> (VarOrRVar, VarOrRVar)
    -> Expr Int32
    -> f n a
    -> IO (f n a)
  -- | @fuse (a, b) fused@ calls Halide's @fuse(a, b, fused)@. Halide names
  -- those parameters @inner@, @outer@, @fused@, so the first tuple component
  -- is the inner dimension.
  fuse
    :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> f n a -> IO (f n a)
  -- | Mark the given dimension serial (Halide's @serial@).
  serial
    :: VarOrRVar -> f n a -> IO (f n a)
  -- | Mark the given dimension parallel (Halide's @parallel@).
  parallel
    :: VarOrRVar -> f n a -> IO (f n a)
  -- | Halide's @atomic@; the flag is passed straight through.
  atomic
    :: Bool -> f n a -> IO (f n a)
  -- | Create a specialization guarded by the condition and return the new
  -- stage so it can be scheduled separately.
  specialize
    :: Expr Bool -> f n a -> IO (Stage n a)
  -- | Halide's @specialize_fail@ with the given message.
  specializeFail
    :: Text -> f n a -> IO ()
  -- | Map 1 to 3 dimensions, given as an 'IndexType' tuple, to GPU blocks
  -- on the given device API.
  gpuBlocks
    :: (KnownNat k, 1 <= k, k <= 3)
    => DeviceAPI -> IndexType k -> f n a -> IO (f n a)
  -- | Like 'gpuBlocks', but for GPU threads.
  gpuThreads
    :: (KnownNat k, 1 <= k, k <= 3)
    => DeviceAPI -> IndexType k -> f n a -> IO (f n a)
  -- | Map the given dimension to GPU lanes on the given device API.
  gpuLanes
    :: DeviceAPI -> VarOrRVar -> f n a -> IO (f n a)
  -- | Halide's @compute_with@ against the given loop level, using the
  -- given alignment strategy.
  computeWith
    :: LoopAlignStrategy -> f n a -> LoopLevel t -> IO ()
-- | Evidence that @k <= l@ and @l <= m@ together imply @k <= m@. Use it with
-- type applications, e.g. @proveTransitivityOfLessThanEqual \@k \@3 \@10@.
--
-- GHC does not derive this implication by itself, so the evidence is
-- manufactured with 'unsafeCoerce'. We take a 'Dict' for the closed,
-- trivially true constraint @1 <= 2@ and coerce it to @Dict (k <= m)@.
--
-- Why this is acceptable:
--
-- * Both dictionaries have the same runtime representation; neither carries
--   runtime data.
-- * The premises are ordinary constraints, so the type checker guarantees
--   them at every call site. The conclusion is then genuinely true.
proveTransitivityOfLessThanEqual
  :: forall (k :: Nat) (l :: Nat) (m :: Nat)
   . (KnownNat k, KnownNat l, KnownNat m, k <= l, l <= m)
  => Dict (k <= m)
proveTransitivityOfLessThanEqual = unsafeCoerce (Dict :: Dict (1 <= 2))
-- | Scheduling applied directly to a C++ @Halide::Stage@.
--
-- * Every method calls the corresponding @Halide::Stage@ member inside
--   @handle_halide_exceptions@, via 'C.throwBlock'.
-- * Directives that return a stage hand back the same, now mutated, 'Stage'.
-- * The exception is 'specialize', which wraps a newly allocated C++ stage.
instance KnownNat n => Schedulable Stage n a where
  vectorize :: VarOrRVar -> Stage n a -> IO (Stage n a)
  vectorize var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->vectorize(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  unroll :: VarOrRVar -> Stage n a -> IO (Stage n a)
  unroll var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->unroll(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  reorder :: [VarOrRVar] -> Stage n a -> IO (Stage n a)
  reorder args stage = do
    withMany asVarOrRVar args $ \args' ->
      withCxxStage stage $ \stage' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=]() {
            $(Halide::Stage* stage')->reorder(
              *$(const std::vector<Halide::VarOrRVar>* args'));
          });
        } |]
    pure stage

  split
    :: TailStrategy
    -> VarOrRVar
    -> (VarOrRVar, VarOrRVar)
    -> Expr Int32
    -> Stage n a
    -> IO (Stage n a)
  split tail old (outer, inner) factor stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar old $ \old' ->
        asVarOrRVar outer $ \outer' ->
          asVarOrRVar inner $ \inner' ->
            asExpr factor $ \factor' ->
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  $(Halide::Stage* stage')->split(
                    *$(const Halide::VarOrRVar* old'),
                    *$(const Halide::VarOrRVar* outer'),
                    *$(const Halide::VarOrRVar* inner'),
                    *$(const Halide::Expr* factor'),
                    static_cast<Halide::TailStrategy>($(int t)));
                });
              } |]
    pure stage
    where
      -- Integer value of the C++ enumerator (see the Enum TailStrategy instance).
      t :: CInt
      t = fromIntegral . fromEnum $ tail

  fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> Stage n a -> IO (Stage n a)
  -- Halide's Stage::fuse is declared as fuse(inner, outer, fused). The tuple is
  -- forwarded in order, so its first component is the *inner* dimension; the
  -- binders are named accordingly. (They used to be (outer, inner), which
  -- contradicted the C++ parameter they were passed as.)
  fuse (inner, outer) fused stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar inner $ \inner' ->
        asVarOrRVar outer $ \outer' ->
          asVarOrRVar fused $ \fused' ->
            [C.throwBlock| void {
              handle_halide_exceptions([=](){
                $(Halide::Stage* stage')->fuse(
                  *$(const Halide::VarOrRVar* inner'),
                  *$(const Halide::VarOrRVar* outer'),
                  *$(const Halide::VarOrRVar* fused'));
              });
            } |]
    pure stage

  serial :: VarOrRVar -> Stage n a -> IO (Stage n a)
  serial var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->serial(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  parallel :: VarOrRVar -> Stage n a -> IO (Stage n a)
  parallel var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->parallel(*$(const Halide::VarOrRVar* var'));
          });
        } |]
    pure stage

  atomic :: Bool -> Stage n a -> IO (Stage n a)
  atomic (fromIntegral . fromEnum -> override) stage = do
    withCxxStage stage $ \stage' ->
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Stage* stage')->atomic($(bool override));
        });
      } |]
    pure stage

  -- Returns a freshly heap-allocated C++ Stage, handed to wrapCxxStage.
  specialize :: Expr Bool -> Stage n a -> IO (Stage n a)
  specialize cond stage = do
    withCxxStage stage $ \stage' ->
      asExpr cond $ \cond' ->
        wrapCxxStage
          =<< [C.throwBlock| Halide::Stage* {
                return handle_halide_exceptions([=](){
                  return new Halide::Stage{$(Halide::Stage* stage')->specialize(
                    *$(const Halide::Expr* cond'))};
                });
              } |]

  specializeFail :: Text -> Stage n a -> IO ()
  specializeFail (T.encodeUtf8 -> s) stage =
    withCxxStage stage $ \stage' ->
      [C.throwBlock| void {
        handle_halide_exceptions([=](){
          $(Halide::Stage* stage')->specialize_fail(
            std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)});
        });
      } |]

  gpuBlocks :: forall k. (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> Stage n a -> IO (Stage n a)
  gpuBlocks (fromIntegral . fromEnum -> api :: C.CInt) vars stage =
    -- k <= 3 and 3 <= 10 give k <= 10, which proveIndexTypeProperties requires.
    case proveTransitivityOfLessThanEqual @k @3 @10 of
      Dict -> case proveIndexTypeProperties @k of
        Sub Dict ->
          withCxxStage stage $ \stage' ->
            asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' -> do
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  auto const& vars = *$(const std::vector<Halide::VarOrRVar>* vars');
                  auto& stage = *$(Halide::Stage* stage');
                  auto const device = static_cast<Halide::DeviceAPI>($(int api));
                  switch (vars.size()) {
                    case 1: stage.gpu_blocks(vars.at(0), device);
                      break;
                    case 2: stage.gpu_blocks(vars.at(0), vars.at(1), device);
                      break;
                    case 3: stage.gpu_blocks(vars.at(0), vars.at(1), vars.at(2), device);
                      break;
                    default: throw std::runtime_error{"unexpected number of arguments in gpuBlocks"};
                  }
                });
              } |]
              pure stage

  gpuThreads :: forall k. (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> Stage n a -> IO (Stage n a)
  gpuThreads (fromIntegral . fromEnum -> api :: C.CInt) vars stage =
    -- Same derivation of k <= 10 as in gpuBlocks.
    case proveTransitivityOfLessThanEqual @k @3 @10 of
      Dict -> case proveIndexTypeProperties @k of
        Sub Dict ->
          withCxxStage stage $ \stage' ->
            asVectorOf @((~) (Expr Int32)) asVarOrRVar (fromTuple vars) $ \vars' -> do
              [C.throwBlock| void {
                handle_halide_exceptions([=](){
                  auto const& vars = *$(const std::vector<Halide::VarOrRVar>* vars');
                  auto& stage = *$(Halide::Stage* stage');
                  auto const device = static_cast<Halide::DeviceAPI>($(int api));
                  switch (vars.size()) {
                    case 1: stage.gpu_threads(vars.at(0), device);
                      break;
                    case 2: stage.gpu_threads(vars.at(0), vars.at(1), device);
                      break;
                    case 3: stage.gpu_threads(vars.at(0), vars.at(1), vars.at(2), device);
                      break;
                    default: throw std::runtime_error{"unexpected number of arguments in gpuThreads"};
                  }
                });
              } |]
              pure stage

  gpuLanes :: DeviceAPI -> VarOrRVar -> Stage n a -> IO (Stage n a)
  gpuLanes (fromIntegral . fromEnum -> api) var stage = do
    withCxxStage stage $ \stage' ->
      asVarOrRVar var $ \var' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Stage* stage')->gpu_lanes(
              *$(const Halide::VarOrRVar* var'),
              static_cast<Halide::DeviceAPI>($(int api)));
          });
        } |]
    pure stage

  computeWith :: LoopAlignStrategy -> Stage n a -> LoopLevel t -> IO ()
  computeWith (fromIntegral . fromEnum -> align) stage level = do
    withCxxStage stage $ \stage' ->
      withCxxLoopLevel level $ \level' ->
        [C.throwBlock| void {
          handle_halide_exceptions([=]() {
            $(Halide::Stage* stage')->compute_with(
              *$(const Halide::LoopLevel* level'),
              static_cast<Halide::LoopAlignStrategy>($(int align)));
          });
        } |]
-- | Lift a one-argument 'Stage' directive to a 'Func'.
--
-- The directive runs on the stage returned by 'getStage'; its result is
-- dropped and the original 'Func' is returned. NOTE(review): this relies on
-- the stage sharing the underlying C++ function with the 'Func', so the
-- effect is visible through it — confirm in 'getStage'.
viaStage1
  :: KnownNat n
  => (a -> Stage n b -> IO (Stage n b))
  -> a
  -> Func t n b
  -> IO (Func t n b)
viaStage1 directive x func = func <$ (getStage func >>= directive x)
-- | Two-argument counterpart of 'viaStage1'. It runs the directive on the
-- stage from 'getStage', discards the result and returns the original 'Func'.
viaStage2
  :: (KnownNat n)
  => (a1 -> a2 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> Func t n b
  -> IO (Func t n b)
viaStage2 directive x y func = func <$ (getStage func >>= directive x y)
-- | Four-argument counterpart of 'viaStage1'. It runs the directive on the
-- stage from 'getStage', discards the result and returns the original 'Func'.
viaStage4
  :: (KnownNat n)
  => (a1 -> a2 -> a3 -> a4 -> Stage n b -> IO (Stage n b))
  -> a1
  -> a2
  -> a3
  -> a4
  -> Func t n b
  -> IO (Func t n b)
viaStage4 directive w x y z func =
  func <$ (getStage func >>= directive w x y z)
-- | Scheduling a 'Func' means scheduling the stage that 'getStage' returns
-- for it.
--
-- * Directives that give back the scheduled object return the original
--   'Func' (via the viaStage helpers).
-- * 'specialize', 'specializeFail' and 'computeWith' return whatever the
--   'Stage' instance returns.
instance KnownNat n => Schedulable (Func t) n a where
  vectorize :: VarOrRVar -> Func t n a -> IO (Func t n a)
  vectorize = viaStage1 vectorize

  unroll :: VarOrRVar -> Func t n a -> IO (Func t n a)
  unroll = viaStage1 unroll

  reorder :: [VarOrRVar] -> Func t n a -> IO (Func t n a)
  reorder = viaStage1 reorder

  split
    :: TailStrategy
    -> VarOrRVar
    -> (VarOrRVar, VarOrRVar)
    -> Expr Int32
    -> Func t n a
    -> IO (Func t n a)
  split = viaStage4 split

  fuse :: (VarOrRVar, VarOrRVar) -> VarOrRVar -> Func t n a -> IO (Func t n a)
  fuse = viaStage2 fuse

  serial :: VarOrRVar -> Func t n a -> IO (Func t n a)
  serial = viaStage1 serial

  parallel :: VarOrRVar -> Func t n a -> IO (Func t n a)
  parallel = viaStage1 parallel

  atomic :: Bool -> Func t n a -> IO (Func t n a)
  atomic = viaStage1 atomic

  specialize :: Expr Bool -> Func t n a -> IO (Stage n a)
  specialize cond func = do
    stage <- getStage func
    specialize cond stage

  specializeFail :: Text -> Func t n a -> IO ()
  specializeFail msg func = do
    stage <- getStage func
    specializeFail msg stage

  gpuBlocks :: forall k. (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> Func t n a -> IO (Func t n a)
  gpuBlocks = viaStage2 gpuBlocks

  gpuThreads :: forall k. (KnownNat k, 1 <= k, k <= 3) => DeviceAPI -> IndexType k -> Func t n a -> IO (Func t n a)
  gpuThreads = viaStage2 gpuThreads

  gpuLanes :: DeviceAPI -> VarOrRVar -> Func t n a -> IO (Func t n a)
  gpuLanes = viaStage2 gpuLanes

  -- The loop level gets its own variable l; `t` is already the FuncTy index
  -- from the instance head.
  computeWith :: LoopAlignStrategy -> Func t n a -> LoopLevel l -> IO ()
  computeWith align func level = do
    stage <- getStage func
    computeWith align stage level
-- | Converts to and from the integer values of the C++
-- @Halide::TailStrategy@ enumerators. 'toEnum' calls 'error' for a value
-- that matches no enumerator.
instance Enum TailStrategy where
  fromEnum :: TailStrategy -> Int
  fromEnum =
    fromIntegral . \case
      TailRoundUp -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::RoundUp) } |]
      TailGuardWithIf -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::GuardWithIf) } |]
      TailPredicate -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Predicate) } |]
      TailPredicateLoads -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateLoads) } |]
      TailPredicateStores -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::PredicateStores) } |]
      TailShiftInwards -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::ShiftInwards) } |]
      TailAuto -> [CU.pure| int { static_cast<int>(Halide::TailStrategy::Auto) } |]

  toEnum :: Int -> TailStrategy
  toEnum k =
    -- Compare in Int, after widening each C value through fromEnum. Narrowing
    -- k to a C int instead (as this used to do) lets out-of-range inputs alias
    -- a valid enumerator wherever Int is wider than int; for example 2^32
    -- became 0 and decoded as TailRoundUp instead of failing.
    case lookup k [(fromEnum s, s) | s <- allStrategies] of
      Just s -> s
      Nothing -> error $ "invalid TailStrategy: " <> show k
    where
      -- Every constructor, in declaration order; fromEnum above is the single
      -- source of the C++ values.
      allStrategies =
        [ TailRoundUp
        , TailGuardWithIf
        , TailPredicate
        , TailPredicateLoads
        , TailPredicateStores
        , TailShiftInwards
        , TailAuto
        ]
-- | @estimate var start extent func@ calls Halide's @Func::set_estimate@ with
-- @var@, @start@ (the minimum) and @extent@.
--
-- The C++ call now goes through 'C.throwBlock' and
-- @handle_halide_exceptions@, like the scheduling directives in this module.
-- A C++ exception is therefore caught and rethrown on the Haskell side.
-- Previously it could escape through an unsafe FFI call; presumably Halide
-- raises one when @var@ is not a pure variable of @func@.
estimate
  :: KnownNat n
  => Expr Int32
  -> Expr Int32
  -> Expr Int32
  -> Func t n a
  -> IO ()
estimate var start extent func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr start $ \minExpr ->
        asExpr extent $ \extentExpr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->set_estimate(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]
-- | @bound var start extent func@ calls Halide's @Func::bound@ with @var@,
-- @start@ (the minimum) and @extent@.
--
-- As in 'estimate', the C++ call now goes through 'C.throwBlock' and
-- @handle_halide_exceptions@, so a C++ exception is rethrown on the Haskell
-- side. Previously it could escape through an unsafe FFI call; presumably
-- Halide raises one when @var@ is not a pure variable of @func@.
bound
  :: KnownNat n
  => Expr Int32
  -> Expr Int32
  -> Expr Int32
  -> Func t n a
  -> IO ()
bound var start extent func =
  withFunc func $ \f ->
    asVar var $ \i ->
      asExpr start $ \minExpr ->
        asExpr extent $ \extentExpr ->
          [C.throwBlock| void {
            handle_halide_exceptions([=](){
              $(Halide::Func* f)->bound(
                *$(Halide::Var* i), *$(Halide::Expr* minExpr), *$(Halide::Expr* extentExpr));
            });
          } |]
-- | The pure argument variables of the function, in the order Halide's
-- @Func::args()@ returns them. Each result wraps a freshly copy-constructed
-- C++ @Halide::Var@.
--
-- The temporary C++ vector holding the arguments is released by 'bracket'
-- even if an exception occurs while copying.
getArgs :: KnownNat n => Func t n a -> IO [Var]
getArgs func =
  withFunc func $ \func' -> do
    let allocate :: IO (Ptr (CxxVector CxxVar))
        allocate =
          [CU.exp| std::vector<Halide::Var>* {
            new std::vector<Halide::Var>{$(const Halide::Func* func')->args()} } |]
        destroy :: Ptr (CxxVector CxxVar) -> IO ()
        destroy v = [CU.exp| void { delete $(std::vector<Halide::Var>* v) } |]
    bracket allocate destroy $ \v -> do
      n <- [CU.exp| size_t { $(const std::vector<Halide::Var>* v)->size() } |]
      -- n is an unsigned CSize. The previous [0 .. n - 1] wrapped around to an
      -- enormous range when n == 0 (a function with no arguments); at(0) on the
      -- empty vector then threw inside an unsafe FFI call. Taking exactly n
      -- indices avoids the subtraction.
      forM (take (fromIntegral n) [0 ..]) $ \i ->
        fmap Var . cxxConstruct $ \ptr ->
          [CU.exp| void {
            new ($(Halide::Var* ptr)) Halide::Var{$(const std::vector<Halide::Var>* v)->at($(size_t i))} } |]
-- | Call Halide's @Func::compute_root@ on the underlying function.
--
-- The input 'Func' is returned unchanged so scheduling calls can be chained.
computeRoot :: KnownNat n => Func t n a -> IO (Func t n a)
computeRoot func = func <$ withFunc func scheduleAtRoot
  where
    scheduleAtRoot f =
      [C.throwBlock| void { handle_halide_exceptions([=](){ $(Halide::Func* f)->compute_root(); }); } |]
-- | @asUsedBy g f@ calls Halide's @g.in(f)@ and returns the resulting wrapper
-- 'Func' (a freshly heap-allocated @Halide::Func@ owned by the result).
asUsedBy
  :: (KnownNat n, KnownNat m)
  => Func t1 n a
  -> Func 'FuncTy m b
  -> IO (Func 'FuncTy n a)
asUsedBy g f =
  withFunc g $ \gPtr ->
    withFunc f $ \fPtr -> do
      wrapper <-
        [CU.exp| Halide::Func* {
          new Halide::Func{$(Halide::Func* gPtr)->in(*$(Halide::Func* fPtr))} } |]
      wrapCxxFunc wrapper
-- | Call Halide's @Func::in()@ (no arguments) and return the resulting wrapper
-- 'Func'.
asUsed :: KnownNat n => Func t n a -> IO (Func 'FuncTy n a)
asUsed f =
  withFunc f $ \fPtr -> do
    wrapper <- [CU.exp| Halide::Func* { new Halide::Func{$(Halide::Func* fPtr)->in()} } |]
    wrapCxxFunc wrapper
-- | Call Halide's @Func::copy_to_device@ with the given 'DeviceAPI', then return
-- the input 'Func' unchanged.
--
-- NOTE(review): the device API is passed as its 'fromEnum' value and
-- @static_cast@ to @Halide::DeviceAPI@. This assumes the Haskell constructor order
-- matches Halide's enum — confirm when either side changes.
copyToDevice :: KnownNat n => DeviceAPI -> Func t n a -> IO (Func t n a)
copyToDevice deviceApi func = do
  let api = fromIntegral (fromEnum deviceApi)
  withFunc func $ \f ->
    [C.throwBlock| void {
      handle_halide_exceptions([=](){
        $(Halide::Func* f)->copy_to_device(static_cast<Halide::DeviceAPI>($(int api)));
      });
    } |]
  pure func
-- | Shorthand for @'copyToDevice' 'DeviceHost'@.
copyToHost :: KnownNat n => Func t n a -> IO (Func t n a)
copyToHost func = copyToDevice DeviceHost func
-- | Allocate a new Halide @ImageParam@ with element type @a@ and @n@ dimensions.
--
-- If a name is given, it is passed to the @ImageParam@ constructor; otherwise the
-- unnamed constructor is used. The returned 'ForeignPtr' frees the object with a
-- C++ @delete@ finalizer.
mkBufferParameter
  :: forall n a. (KnownNat n, IsHalideType a) => Maybe Text -> IO (ForeignPtr CxxImageParam)
mkBufferParameter maybeName =
  with (halideTypeFor (Proxy @a)) $ \t -> do
    let d = fromIntegral (natVal (Proxy @n))
        deleter =
          [C.funPtr| void deleteImageParam(Halide::ImageParam* p) { delete p; } |]
    raw <- case maybeName of
      Nothing ->
        [CU.exp| Halide::ImageParam* {
          new Halide::ImageParam{Halide::Type{*$(halide_type_t* t)}, $(int d)} } |]
      Just name ->
        let s = T.encodeUtf8 name
         in [CU.exp| Halide::ImageParam* {
              new Halide::ImageParam{
                Halide::Type{*$(halide_type_t* t)},
                $(int d),
                std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}} } |]
    newForeignPtr deleter raw
-- | Return the @ImageParam@ cached in the 'IORef', creating it (with the given
-- optional name) and storing it on first use.
--
-- The name is only used when the parameter is created; it is ignored if a cached
-- parameter already exists.
--
-- NOTE(review): the read and the write are separate operations, so concurrent
-- first use could create two parameters — confirm callers are single-threaded.
getBufferParameter
  :: forall n a
   . (KnownNat n, IsHalideType a)
  => Maybe Text
  -> IORef (Maybe (ForeignPtr CxxImageParam))
  -> IO (ForeignPtr CxxImageParam)
getBufferParameter name r = do
  cached <- readIORef r
  maybe create pure cached
  where
    create = do
      fp <- mkBufferParameter @n @a name
      writeIORef r (Just fp)
      pure fp
-- | Run an action with a pointer to the @ImageParam@ behind a buffer parameter.
--
-- An unnamed @ImageParam@ is created and cached on first use.
withBufferParam
  :: forall n a b
   . (HasCallStack, KnownNat n, IsHalideType a)
  => Func 'ParamTy n (Expr a)
  -> (Ptr CxxImageParam -> IO b)
  -> IO b
withBufferParam (Param r) action = do
  fp <- getBufferParameter @n @a Nothing r
  withForeignPtr fp action
-- | Run an action with a pointer to the @Halide::Func@ behind any 'Func'.
--
-- For a 'Func' the stored pointer is used directly. For a 'Param', 'forceFunc'
-- builds a new @Halide::Func@ from the @ImageParam@ on every call.
withFunc :: KnownNat n => Func t n a -> (Ptr CxxFunc -> IO b) -> IO b
withFunc f action = case f of
  Func fp -> withForeignPtr fp action
  Param _ -> do
    forced <- forceFunc f
    withCxxFunc forced action
-- | Run an action with the raw @Halide::Func@ pointer of a defined 'Func',
-- keeping the owning 'ForeignPtr' alive for the duration of the action.
withCxxFunc :: KnownNat n => Func 'FuncTy n a -> (Ptr CxxFunc -> IO b) -> IO b
withCxxFunc (Func fp) action = withForeignPtr fp action
-- | Take ownership of a heap-allocated @Halide::Func@. The object is released
-- with C++ @delete@ by the 'ForeignPtr' finalizer.
wrapCxxFunc :: Ptr CxxFunc -> IO (Func 'FuncTy n a)
wrapCxxFunc ptr = Func <$> newForeignPtr finalizer ptr
  where
    finalizer = [C.funPtr| void deleteFunc(Halide::Func *x) { delete x; } |]
-- | Turn any 'Func' into a defined ('FuncTy') one.
--
-- A 'Func' is returned as-is. For a 'Param', the (lazily created) @ImageParam@ is
-- converted to a new @Halide::Func@ via @static_cast@, so each call allocates a
-- fresh wrapper.
forceFunc :: forall t n a. KnownNat n => Func t n (Expr a) -> IO (Func 'FuncTy n (Expr a))
forceFunc func = case func of
  Func _ -> pure func
  Param r -> do
    fp <- getBufferParameter @n @a Nothing r
    withForeignPtr fp $ \p -> do
      raw <-
        [CU.exp| Halide::Func* {
          new Halide::Func{static_cast<Halide::Func>(*$(Halide::ImageParam* p))} } |]
      wrapCxxFunc raw
-- | Values that can form the right-hand side of a 'Func' definition. This module
-- provides instances for a single 'Expr' and for 2- and 3-tuples of 'Expr's.
class IsFuncDefinition d where
  -- | Flatten a definition into the underlying Halide expressions, in order.
  definitionToExprList :: d -> [ForeignPtr CxxExpr]
  -- | Rebuild a definition from Halide expressions. Instances in this module check
  -- each element's type and call 'error' when the list length does not match.
  exprListToDefinition :: [ForeignPtr CxxExpr] -> d
-- | A single-valued definition: exactly one Halide expression.
instance IsHalideType a => IsFuncDefinition (Expr a) where
  definitionToExprList x = [exprToForeignPtr x]
  exprListToDefinition = \case
    [x1] -> unsafePerformIO $ do
      -- Verify that the C++ expression really has element type @a@.
      withForeignPtr x1 (checkType @a)
      pure (Expr x1)
    _ -> error "should never happen"
-- | A two-valued (Halide @Tuple@) definition.
instance (IsHalideType a1, IsHalideType a2) => IsFuncDefinition (Expr a1, Expr a2) where
  definitionToExprList (x1, x2) = [exprToForeignPtr x1, exprToForeignPtr x2]
  exprListToDefinition = \case
    [p1, p2] -> unsafePerformIO $ do
      -- Check each component's element type before wrapping it.
      withForeignPtr p1 (checkType @a1)
      withForeignPtr p2 (checkType @a2)
      pure (Expr p1, Expr p2)
    _ -> error "should never happen"
-- | A three-valued (Halide @Tuple@) definition.
instance (IsHalideType a1, IsHalideType a2, IsHalideType a3) => IsFuncDefinition (Expr a1, Expr a2, Expr a3) where
  definitionToExprList (x1, x2, x3) =
    [exprToForeignPtr x1, exprToForeignPtr x2, exprToForeignPtr x3]
  exprListToDefinition = \case
    [p1, p2, p3] -> unsafePerformIO $ do
      -- Check each component's element type before wrapping it.
      withForeignPtr p1 (checkType @a1)
      withForeignPtr p2 (checkType @a2)
      withForeignPtr p3 (checkType @a3)
      pure (Expr p1, Expr p2, Expr p3)
    _ -> error "should never happen"
-- | Create a new 'Func' called @name@ and give it the pure definition
-- @f(args) = definition@.
--
-- A definition with one expression is assigned directly. Several expressions are
-- packed into a Halide @Tuple@.
--
-- Halide validates the definition when it is assigned and reports problems by
-- throwing. The assignment therefore runs in a throwing (safe) FFI block wrapped
-- in @handle_halide_exceptions@, exactly like 'update', rather than an unsafe block
-- through which a C++ exception cannot be propagated.
define
  :: forall n d
   . (HasIndexType n, IsFuncDefinition d)
  => Text
  -> IndexType n
  -> d
  -> IO (Func 'FuncTy n d)
define name args definition =
  case proveIndexTypeProperties @n of
    Sub Dict ->
      asVectorOf @((~) (Expr Int32)) asVar (fromTuple args) $ \x -> do
        let s = T.encodeUtf8 name
        withMany withForeignPtr (definitionToExprList definition) $ \v ->
          wrapCxxFunc
            =<< [C.throwBlock| Halide::Func* {
                  return handle_halide_exceptions([=](){
                    Halide::Func f{std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)}};
                    auto const& args = *$(const std::vector<Halide::Var>* x);
                    auto const& def = *$(const std::vector<Halide::Expr>* v);
                    if (def.size() == 1) {
                      f(args) = def.at(0);
                    }
                    else {
                      f(args) = Halide::Tuple{def};
                    }
                    return new Halide::Func{f};
                  });
                } |]
-- | Add an update definition @func(args) = definition@ to an existing 'Func'.
--
-- As in 'define', a single expression is assigned directly and several are
-- packed into a Halide @Tuple@. Halide exceptions are routed through
-- @handle_halide_exceptions@.
update
  :: forall n d
   . (HasIndexType n, IsFuncDefinition d)
  => Func 'FuncTy n d
  -> IndexType n
  -> d
  -> IO ()
update func args definition =
  withFunc func $ \f ->
    withExprIndices @n args $ \index ->
      withMany withForeignPtr (definitionToExprList definition) $ \value ->
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            auto& f = *$(Halide::Func* f);
            auto const& index = *$(const std::vector<Halide::Expr>* index);
            auto const& value = *$(const std::vector<Halide::Expr>* value);
            if (value.size() == 1) {
              f(index) = value.at(0);
            }
            else {
              f(index) = Halide::Tuple{value};
            }
          });
        } |]
-- | Fixity of '!': precedence 9, non-associative. For example, @f ! x + 1@ parses
-- as @(f ! x) + 1@.
infix 9 !
-- | Convert an index tuple into a C++ @std::vector<Halide::Expr>@ and run the
-- action with a pointer to it.
withExprIndices :: forall n a. HasIndexType n => IndexType n -> (Ptr (CxxVector CxxExpr) -> IO a) -> IO a
withExprIndices indices action =
  case proveIndexTypeProperties @n of
    -- Opening the proof brings the IndexTypeProperties needed by asVectorOf into scope.
    Sub Dict -> asVectorOf @((~) (Expr Int32)) asExpr (fromTuple indices) action
-- | Apply a 'Func' to index expressions (Halide's @Func::operator()@) and return
-- one expression per output of the resulting @FuncRef@.
--
-- A single-output reference is converted with @static_cast<Halide::Expr>@;
-- otherwise each element @ref[i]@ is collected. An empty result (a reference with
-- zero outputs) is returned as an empty list.
indexFunc :: forall n a t. HasIndexType n => Func t n a -> IndexType n -> IO [ForeignPtr CxxExpr]
indexFunc func indices =
  withExprIndices indices $ \x ->
    withFunc func $ \f -> do
      let allocate =
            [CU.block| std::vector<Halide::Expr>* {
              Halide::FuncRef ref = $(Halide::Func* f)->operator()(*$(std::vector<Halide::Expr>* x));
              std::vector<Halide::Expr> v;
              if (ref.size() == 1) {
                v.push_back(static_cast<Halide::Expr>(ref));
              }
              else {
                for (auto i = size_t{0}; i < ref.size(); ++i) {
                  v.push_back(ref[i]);
                }
              }
              return new std::vector<Halide::Expr>{std::move(v)};
            } |]
      bracket allocate deleteCxxVector $ \v -> do
        size <- cxxVectorSize v
        -- Enumerate exactly `size` indices. The previous [0 .. size - 1] was
        -- computed in size_t and wrapped around to maxBound when size == 0.
        forM (take size [0 ..]) $ \i ->
          cxxConstruct $ \ptr ->
            [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
              $(const std::vector<Halide::Expr>* v)->at($(size_t i))} } |]
-- | Index a 'Func' with a tuple of index expressions, yielding its value
-- expression(s).
--
-- The FFI work is hidden behind 'unsafePerformIO'.
(!) :: (HasIndexType n, IsFuncDefinition a) => Func t n a -> IndexType n -> a
func ! args = unsafePerformIO (exprListToDefinition <$> indexFunc func args)
-- | Get dimension @k@ (0-based) of a buffer parameter, via Halide's
-- @ImageParam::dim@.
--
-- Calls 'error' when @k@ is outside @[0, n)@.
dim
  :: forall n a
   . (HasCallStack, KnownNat n)
  => Int
  -> Func 'ParamTy n (Expr a)
  -> IO Dimension
dim k func@(Param _)
  | k < 0 || k >= rank =
      error $
        "invalid dimension index: "
          <> show k
          <> "; Func is "
          <> show (natVal (Proxy @n))
          <> "-dimensional"
  | otherwise =
      withBufferParam func $ \f -> do
        let i = fromIntegral k
        wrapCxxDimension
          =<< [CU.exp| Halide::Internal::Dimension* {
                new Halide::Internal::Dimension{$(Halide::ImageParam* f)->dim($(int i))} } |]
  where
    rank = fromIntegral (natVal (Proxy @n)) :: Int
-- | Render the loop nest of a 'Func' as text using Halide's
-- @Internal::print_loop_nest@.
prettyLoopNest :: KnownNat n => Func t n r -> IO Text
prettyLoopNest func =
  withFunc func $ \f -> do
    rendered <-
      [C.throwBlock| std::string* {
        return handle_halide_exceptions([=]() {
          return new std::string{Halide::Internal::print_loop_nest(
            std::vector<Halide::Internal::Function>{$(Halide::Func* f)->function()})};
        });
      } |]
    peekAndDeleteCxxString rendered
-- | 'realizeOnTarget' specialised to 'hostTarget'.
realize
  :: forall n a t b
   . (KnownNat n, IsHalideType a)
  => Func t n (Expr a)
  -> [Int]
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
realize func shape action = realizeOnTarget hostTarget func shape action
-- | Allocate a buffer of the given shape (via 'allocaBuffer'), realize the 'Func'
-- into it for the given 'Target', and pass the buffer to the action.
realizeOnTarget
  :: forall n a t b
   . (KnownNat n, IsHalideType a)
  => Target
  -> Func t n (Expr a)
  -> [Int]
  -> (Ptr (HalideBuffer n a) -> IO b)
  -> IO b
realizeOnTarget target func shape action =
  withFunc func $ \func' ->
    withCxxTarget target $ \target' ->
      allocaBuffer target shape $ \buf -> do
        -- Halide's realize expects the underlying halide_buffer_t.
        let raw = castPtr buf
        [C.throwBlock| void {
          handle_halide_exceptions([=](){
            $(Halide::Func* func')->realize($(halide_buffer_t* raw), *$(const Halide::Target* target'));
          });
        } |]
        action buf
-- | Give a buffer parameter a name and return the same parameter.
--
-- NOTE(review): the name is only applied if the underlying @ImageParam@ has not
-- been created yet ('getBufferParameter' returns a cached one as-is). Otherwise it
-- is silently ignored, unlike 'scalar', which calls 'error' in that situation.
buffer :: forall n a. (KnownNat n, IsHalideType a) => Text -> Func 'ParamTy n (Expr a) -> Func 'ParamTy n (Expr a)
buffer name p@(Param r) =
  unsafePerformIO $
    getBufferParameter @n @a (Just name) r >> pure p
-- | Give a scalar parameter a name, creating its underlying Halide parameter.
--
-- Calls 'error' if the parameter was already created (named), or if the 'Expr' is
-- not a scalar parameter.
scalar :: forall a. IsHalideType a => Text -> Expr a -> Expr a
scalar name expr = case expr of
  ScalarParam r -> unsafePerformIO $ do
    existing <- readIORef r
    case existing of
      Just _ -> error "the name of this Expr has already been set"
      Nothing -> do
        fp <- mkScalarParameter @a (Just name)
        writeIORef r (Just fp)
        pure (ScalarParam r)
  _ -> error "cannot set the name of an expression that is not a parameter"
-- | Take ownership of a heap-allocated @Halide::Stage@. The object is released
-- with C++ @delete@ by the 'ForeignPtr' finalizer.
wrapCxxStage :: Ptr CxxStage -> IO (Stage n a)
wrapCxxStage ptr = Stage <$> newForeignPtr finalizer ptr
  where
    finalizer = [C.funPtr| void deleteStage(Halide::Stage* p) { delete p; } |]
-- | Run an action with the raw @Halide::Stage@ pointer, keeping the owning
-- 'ForeignPtr' alive for the duration of the action.
withCxxStage :: Stage n a -> (Ptr CxxStage -> IO b) -> IO b
withCxxStage (Stage fp) action = withForeignPtr fp action
-- | Get the 'Stage' corresponding to a 'Func' (@static_cast<Halide::Stage>@ of
-- the underlying @Halide::Func@).
getStage :: KnownNat n => Func t n a -> IO (Stage n a)
getStage func = do
  raw <-
    withFunc func $ \func' ->
      [CU.exp| Halide::Stage* { new Halide::Stage{static_cast<Halide::Stage>(*$(Halide::Func* func'))} } |]
  wrapCxxStage raw
-- | Whether the 'Func' has any update definitions (Halide's
-- @Func::has_update_definition@).
hasUpdateDefinitions :: KnownNat n => Func t n a -> IO Bool
hasUpdateDefinitions func =
  withFunc func $ \func' -> do
    flag <- [CU.exp| bool { $(const Halide::Func* func')->has_update_definition() } |]
    pure (toBool flag)
-- | Get the @k@-th update stage (0-based) of a 'Func' (Halide's
-- @Func::update(k)@).
--
-- A negative index is rejected with 'error' before reaching C++. Other invalid
-- indices are reported by Halide by throwing. The call therefore runs in a throwing
-- (safe) FFI block wrapped in @handle_halide_exceptions@ instead of an unsafe call,
-- so those errors surface as Haskell exceptions.
getUpdateStage :: KnownNat n => Int -> Func 'FuncTy n a -> IO (Stage n a)
getUpdateStage k func
  | k < 0 = error ("getUpdateStage: negative update index " <> show k)
  | otherwise = do
      raw <-
        withFunc func $ \func' -> do
          let k' = fromIntegral k
          [C.throwBlock| Halide::Stage* {
            return handle_halide_exceptions([=](){
              return new Halide::Stage{$(Halide::Func* func')->update($(int k'))};
            });
          } |]
      wrapCxxStage raw
-- | Build the Halide @LoopLevel@ for variable @var@ of the 'Func' at the given
-- stage index.
--
-- The result must be a locked 'LoopLevel'; any other kind triggers 'error'.
getLoopLevelAtStage
  :: KnownNat n
  => Func t n a
  -> Expr Int32
  -> Int
  -> IO (LoopLevel 'LockedTy)
getLoopLevelAtStage func var stageIndex =
  withFunc func $ \f ->
    asVarOrRVar var $ \i -> do
      let k = fromIntegral stageIndex
      SomeLoopLevel level <-
        wrapCxxLoopLevel
          =<< [C.throwBlock| Halide::LoopLevel* {
                return handle_halide_exceptions([=](){
                  return new Halide::LoopLevel{*$(const Halide::Func* f),
                    *$(const Halide::VarOrRVar* i),
                    $(int k)};
                });
              } |]
      case level of
        LoopLevel _ -> pure level
        _ ->
          error $
            "getLoopLevelAtStage: got "
              <> show level
              <> ", but expected a LoopLevel 'LockedTy"
-- | Build the loop level for the loop over the given variable of a function,
-- without naming a stage explicitly.
--
-- This is 'getLoopLevelAtStage' with stage index @-1@. That is the default
-- stage index of Halide's @LoopLevel@ constructor; presumably it selects the
-- last stage (confirm against the Halide @LoopLevel@ documentation).
getLoopLevel :: KnownNat n => Func t n a -> Expr Int32 -> IO (LoopLevel 'LockedTy)
getLoopLevel f i = getLoopLevelAtStage f i (-1)
-- | Set where the storage of the function is allocated, via the C++
-- @Halide::Func::store_at@.
--
-- The underlying C++ @Func@ is modified in place. The same 'Func' handle is
-- returned.
storeAt :: KnownNat n => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
storeAt func level = do
  withFunc func $ \f ->
    withCxxLoopLevel level $ \l ->
      [CU.exp| void { $(Halide::Func* f)->store_at(*$(const Halide::LoopLevel* l)) } |]
  pure func
-- | Set where the function is computed, via the C++
-- @Halide::Func::compute_at@.
--
-- The underlying C++ @Func@ is modified in place. The same 'Func' handle is
-- returned.
computeAt :: KnownNat n => Func 'FuncTy n a -> LoopLevel t -> IO (Func 'FuncTy n a)
computeAt func level = do
  withFunc func $ \f ->
    withCxxLoopLevel level $ \l ->
      [CU.exp| void { $(Halide::Func* f)->compute_at(*$(const Halide::LoopLevel* l)) } |]
  pure func
-- | Run an action with a buffer parameter 'Func' bound to the given buffer.
--
-- The function:
--
-- 1. creates an unnamed image parameter (@mkBufferParameter \@n \@a Nothing@);
-- 2. sets it to a @Halide::Buffer<>@ built from the buffer's raw
--    @halide_buffer_t@;
-- 3. passes it to @action@, wrapped in 'Param'.
--
-- NOTE(review): the raw buffer pointer comes from 'withHalideBuffer', so it is
-- presumably only valid inside this call. Halide's @Buffer@ likely aliases
-- rather than copies the data, so the parameter should not be used after
-- @action@ returns. Confirm.
asBufferParam
  :: forall n a t b
   . IsHalideBuffer t n a
  => t
  -> (Func 'ParamTy n (Expr a) -> IO b)
  -> IO b
asBufferParam arr action =
  withHalideBuffer @n @a arr $ \arr' -> do
    param <- mkBufferParameter @n @a Nothing
    withForeignPtr param $ \param' ->
      let buf :: Ptr RawHalideBuffer
          buf = castPtr arr'
       in [CU.block| void {
            $(Halide::ImageParam* param')->set(Halide::Buffer<>{*$(const halide_buffer_t* buf)});
          } |]
    action . Param =<< newIORef (Just param)