{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
module Language.Halide.Schedule
( Dim (..)
, DimType (..)
, ForType (..)
, SplitContents (..)
, FuseContents (..)
, Split (..)
, Bound (..)
, StorageDim (..)
, FusedPair (..)
, FuseLoopLevel (..)
, StageSchedule (..)
, ReductionVariable (..)
, PrefetchDirective (..)
, getStageSchedule
, AutoScheduler (..)
, loadAutoScheduler
, applyAutoScheduler
, getHalideLibraryPath
, applySplits
, applyDims
, applySchedule
)
where
import Control.Monad (void)
import Data.Text (Text, pack, unpack)
import Data.Text qualified as T
import Data.Text.Encoding (encodeUtf8)
import Foreign.C.Types (CInt (..))
import Foreign.ForeignPtr
import Foreign.Marshal (allocaArray, peekArray, toBool)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Func
import Language.Halide.LoopLevel
import Language.Halide.Target
import Language.Halide.Type
import Language.Halide.Utils
import System.FilePath (takeDirectory)
import Prelude hiding (tail)
#if USE_DLOPEN
import qualified System.Posix.DynamicLinker as DL
-- | Load the shared library at @path@ with 'DL.dlopen', using 'DL.RTLD_LAZY'.
--
-- The handle returned by 'DL.dlopen' is thrown away, so the library is never
-- closed from here and stays loaded for the rest of the process.
-- NOTE(review): failures are whatever 'DL.dlopen' raises (presumably an
-- 'IOError') — confirm against the unix package documentation.
loadLibrary :: Text -> IO ()
loadLibrary path = void (DL.dlopen (unpack path) [DL.RTLD_LAZY])
#else
import qualified System.Win32.DLL as Win32
-- | Load the DLL at the given path with 'Win32.loadLibrary'.
--
-- The returned module handle is discarded, so the DLL is never freed from
-- here and remains loaded for the rest of the process.
loadLibrary :: Text -> IO ()
loadLibrary = void . Win32.loadLibrary . unpack
#endif
-- | Haskell mirror of @Halide::Internal::DimType@. Each constructor matches
-- the enumerator with the same suffix (@PureVar@, @PureRVar@, @ImpureRVar@);
-- the integer mapping lives in this type's 'Enum' instance.
data DimType
  = DimPureVar
  | DimPureRVar
  | DimImpureRVar
  deriving stock (Eq, Show)
-- | Haskell mirror of @Halide::Internal::ForType@, i.e. how a loop is
-- executed. Each constructor matches the enumerator with the same suffix
-- (@Serial@, @Parallel@, ...); the integer mapping lives in this type's
-- 'Enum' instance.
data ForType
  = ForSerial
  | ForParallel
  | ForVectorized
  | ForUnrolled
  | ForExtern
  | ForGPUBlock
  | ForGPUThread
  | ForGPULane
  deriving stock (Eq, Show)
-- | One loop dimension of a stage schedule, read from a
-- @Halide::Internal::Dim@ by the 'Storable' instance for 'Dim'.
data Dim = Dim
  { var :: !Text
  -- ^ Loop variable name (C++ field @var@).
  , forType :: !ForType
  -- ^ Loop execution kind (C++ field @for_type@).
  , deviceApi :: !DeviceAPI
  -- ^ Device API of the loop (C++ field @device_api@).
  , dimType :: !DimType
  -- ^ Kind of variable (C++ field @dim_type@).
  }
  deriving stock (Eq, Show)
-- | Payload of a fuse entry in a stage's split list ('FuseVars').
data FuseContents = FuseContents
  { fuseOuter :: !Text
  -- ^ Outer variable being fused (C++ @Split::outer@).
  , fuseInner :: !Text
  -- ^ Inner variable being fused (C++ @Split::inner@).
  , fuseNew :: !Text
  -- ^ Resulting variable; read from the C++ @Split::old_var@ field.
  }
  deriving stock (Eq, Show)
-- | Payload of a split entry in a stage's split list ('SplitVar').
-- Every field is copied from the @Halide::Internal::Split@ member named in
-- its comment.
data SplitContents = SplitContents
  { splitOld :: !Text
  -- ^ Variable that was split (@old_var@).
  , splitOuter :: !Text
  -- ^ Outer variable produced by the split (@outer@).
  , splitInner :: !Text
  -- ^ Inner variable produced by the split (@inner@).
  , splitFactor :: !(Expr Int32)
  -- ^ Split factor (@factor@).
  , splitExact :: !Bool
  -- ^ @exact@ flag.
  , splitTail :: !TailStrategy
  -- ^ Tail strategy (@tail@).
  }
  deriving stock (Show)
-- | One entry of a stage's split list: either a split of one variable into
-- two, or a fusion of two variables into one.
data Split
  = SplitVar !SplitContents
  | FuseVars !FuseContents
  deriving stock (Show)
-- | A bound attached to a loop variable.
-- NOTE(review): no code in this part of the module constructs 'Bound'; the
-- field meanings below are inferred from their names — confirm against the
-- Halide bound directive.
data Bound = Bound
  { boundVar :: !Text
  -- ^ Bounded variable name.
  , boundMin :: !(Maybe (Expr Int32))
  -- ^ Optional minimum.
  , boundExtent :: !(Expr Int32)
  -- ^ Extent.
  , boundModulus :: !(Maybe (Expr Int32))
  -- ^ Optional modulus.
  , boundRemainder :: !(Maybe (Expr Int32))
  -- ^ Optional remainder.
  }
  deriving stock (Show)
-- | Storage-layout settings for one dimension.
-- NOTE(review): nothing in this part of the module builds a 'StorageDim';
-- field meanings are inferred from their names — verify before relying on them.
data StorageDim = StorageDim
  { storageVar :: !Text
  -- ^ Dimension variable name.
  , storageAlignment :: !(Maybe (Expr Int32))
  -- ^ Optional alignment.
  , storageBound :: !(Maybe (Expr Int32))
  -- ^ Optional bound.
  , storageFold :: !(Maybe (Expr Int32, Bool))
  -- ^ Optional fold factor paired with a flag.
  }
  deriving stock (Show)
-- | Mirror of @Halide::Internal::FusedPair@. As filled in by its 'Storable'
-- instance, the fields are:
--
--   1. the variable name (@var_name@),
--   2. the first function name and stage index (@func_1@, @stage_1@),
--   3. the second function name and stage index (@func_2@, @stage_2@).
data FusedPair
  = FusedPair !Text !(Text, Int) !(Text, Int)
  deriving stock (Eq, Show)
-- | Loop level at which a stage is fused, taken from the C++
-- @StageSchedule::fuse_level().level@ by 'getFuseLoopLevel'.
data FuseLoopLevel
  = FuseLoopLevel !SomeLoopLevel
  deriving stock (Eq, Show)
-- | Mirror of @Halide::Internal::ReductionVariable@: a named reduction
-- variable together with its minimum and extent expressions.
data ReductionVariable = ReductionVariable
  { varName :: !Text
  -- ^ Variable name (C++ field @var@).
  , minExpr :: !(Expr Int32)
  -- ^ Minimum (C++ field @min@).
  , extentExpr :: !(Expr Int32)
  -- ^ Extent (C++ field @extent@).
  }
  deriving stock (Show)
-- | Haskell mirror of @Halide::PrefetchBoundStrategy@. Each constructor
-- matches the enumerator with the same suffix (@Clamp@, @GuardWithIf@,
-- @NonFaulting@); the integer mapping lives in this type's 'Enum' instance.
--
-- NOTE(review): this type is missing from the module export list, although it
-- appears in the exported 'PrefetchDirective' record.
data PrefetchBoundStrategy
  = PrefetchClamp
  | PrefetchGuardWithIf
  | PrefetchNonFaulting
  deriving stock (Eq, Show)
-- | Mirror of @Halide::Internal::PrefetchDirective@, as filled in by its
-- 'Storable' instance.
data PrefetchDirective = PrefetchDirective
  { prefetchFunc :: !Text
  -- ^ Name of the prefetched function (C++ @name@).
  , prefetchAt :: !Text
  -- ^ C++ @at@ variable.
  , prefetchFrom :: !Text
  -- ^ C++ @from@ variable.
  , prefetchOffset :: !(Expr Int32)
  -- ^ C++ @offset@.
  , prefetchStrategy :: !PrefetchBoundStrategy
  -- ^ C++ @strategy@.
  , prefetchParameter :: !(Maybe (ForeignPtr CxxParameter))
  -- ^ Optional parameter. The 'Storable' instance always sets this to 'Nothing'.
  }
  deriving stock (Show)
-- | Haskell snapshot of a @Halide::Internal::StageSchedule@, produced by
-- 'peekStageSchedule'.
data StageSchedule = StageSchedule
  { rvars :: ![ReductionVariable]
  -- ^ Reduction variables (C++ @rvars()@).
  , splits :: ![Split]
  -- ^ Splits and fusions, in C++ order (C++ @splits()@).
  , dims :: ![Dim]
  -- ^ Loop dimensions (C++ @dims()@).
  , prefetches :: ![PrefetchDirective]
  -- ^ Prefetch directives; 'peekStageSchedule' currently leaves this empty.
  , fuseLevel :: !FuseLoopLevel
  -- ^ Fuse level (C++ @fuse_level().level@).
  , fusedPairs :: ![FusedPair]
  -- ^ Fused pairs (C++ @fused_pairs()@).
  , allowRaceConditions :: !Bool
  -- ^ C++ @allow_race_conditions()@.
  , atomic :: !Bool
  -- ^ C++ @atomic()@.
  , overrideAtomicAssociativityTest :: !Bool
  -- ^ C++ @override_atomic_associativity_test()@.
  }
  deriving stock (Show)
-- Top-level Template Haskell splices used by the inline-C code in this module.
-- NOTE(review): 'importHalide' comes from one of the Language.Halide.*
-- imports; presumably it sets up the inline-C context (Halide headers and C++
-- type mappings) used by the quasi-quotes below — confirm at its definition.
importHalide
-- Presumably generate 'HasCxxVector' instances for the C++ element types whose
-- std::vector contents are read with 'peekCxxVector' (that function requires
-- both 'HasCxxVector' and 'Storable' for the element type).
-- NOTE(review): there is no instance for Halide::Internal::PrefetchDirective,
-- which matches 'peekStageSchedule' always leaving 'prefetches' empty.
instanceHasCxxVector "Halide::Internal::Dim"
instanceHasCxxVector "Halide::Internal::Split"
instanceHasCxxVector "Halide::Internal::FusedPair"
instanceHasCxxVector "Halide::Internal::ReductionVariable"
-- | Conversion between 'ForType' and the integer value of the corresponding
-- @Halide::Internal::ForType@ enumerator.
--
-- The integers come from the Halide headers through inline-C rather than
-- being hard-coded, so they follow Halide's own numbering. 'fromEnum' holds
-- the single constructor-to-integer table and 'toEnum' inverts it, so the two
-- directions cannot drift apart. (Previously 'fromEnum' always failed.)
instance Enum ForType where
  fromEnum :: ForType -> Int
  fromEnum t =
    fromIntegral $ case t of
      ForSerial -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Serial) } |]
      ForParallel -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Parallel) } |]
      ForVectorized -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Vectorized) } |]
      ForUnrolled -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Unrolled) } |]
      ForExtern -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Extern) } |]
      ForGPUBlock -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::GPUBlock) } |]
      ForGPUThread -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::GPUThread) } |]
      ForGPULane -> [CU.pure| int { static_cast<int>(Halide::Internal::ForType::GPULane) } |]

  -- | Inverse of 'fromEnum'. Fails with an "invalid ForType" error for an
  -- integer that matches no known enumerator. The comparison is done in 'Int',
  -- so out-of-range inputs can no longer wrap around to a valid 'CInt'.
  toEnum :: Int -> ForType
  toEnum k =
    case filter ((== k) . fromEnum) allForTypes of
      (t : _) -> t
      [] -> error $ "invalid ForType: " <> show k
    where
      allForTypes =
        [ ForSerial
        , ForParallel
        , ForVectorized
        , ForUnrolled
        , ForExtern
        , ForGPUBlock
        , ForGPUThread
        , ForGPULane
        ]
-- | Conversion between 'DimType' and the integer value of the corresponding
-- @Halide::Internal::DimType@ enumerator.
--
-- The integers come from the Halide headers through inline-C. 'fromEnum'
-- holds the only constructor-to-integer table and 'toEnum' inverts it.
-- (Previously 'fromEnum' always failed.)
instance Enum DimType where
  fromEnum :: DimType -> Int
  fromEnum t =
    fromIntegral $ case t of
      DimPureVar -> [CU.pure| int { static_cast<int>(Halide::Internal::DimType::PureVar) } |]
      DimPureRVar -> [CU.pure| int { static_cast<int>(Halide::Internal::DimType::PureRVar) } |]
      DimImpureRVar -> [CU.pure| int { static_cast<int>(Halide::Internal::DimType::ImpureRVar) } |]

  -- | Inverse of 'fromEnum'. Fails with an "invalid DimType" error for an
  -- integer that matches no known enumerator.
  toEnum :: Int -> DimType
  toEnum k =
    case filter ((== k) . fromEnum) allDimTypes of
      (t : _) -> t
      [] -> error $ "invalid DimType: " <> show k
    where
      allDimTypes = [DimPureVar, DimPureRVar, DimImpureRVar]
-- | Conversion between 'PrefetchBoundStrategy' and the integer value of the
-- corresponding @Halide::PrefetchBoundStrategy@ enumerator.
--
-- The integers come from the Halide headers through inline-C. 'fromEnum'
-- holds the only constructor-to-integer table and 'toEnum' inverts it.
-- Previously 'fromEnum' always failed, and its error message wrongly named
-- "ForType" (a copy-paste slip).
instance Enum PrefetchBoundStrategy where
  fromEnum :: PrefetchBoundStrategy -> Int
  fromEnum s =
    fromIntegral $ case s of
      PrefetchClamp -> [CU.pure| int { static_cast<int>(Halide::PrefetchBoundStrategy::Clamp) } |]
      PrefetchGuardWithIf -> [CU.pure| int { static_cast<int>(Halide::PrefetchBoundStrategy::GuardWithIf) } |]
      PrefetchNonFaulting -> [CU.pure| int { static_cast<int>(Halide::PrefetchBoundStrategy::NonFaulting) } |]

  -- | Inverse of 'fromEnum'. Fails with an "invalid PrefetchBoundStrategy"
  -- error for an integer that matches no known enumerator.
  toEnum :: Int -> PrefetchBoundStrategy
  toEnum k =
    case filter ((== k) . fromEnum) allStrategies of
      (s : _) -> s
      [] -> error $ "invalid PrefetchBoundStrategy: " <> show k
    where
      allStrategies = [PrefetchClamp, PrefetchGuardWithIf, PrefetchNonFaulting]
-- | Read-only marshalling of @Halide::Internal::FusedPair@. Only 'peek' is
-- supported; 'poke' raises an error.
instance Storable FusedPair where
  sizeOf :: FusedPair -> Int
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::FusedPair) } |]

  alignment :: FusedPair -> Int
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::FusedPair) } |]

  -- Fields are read in the order func_1, func_2, stage_1, stage_2, var_name.
  peek :: Ptr FusedPair -> IO FusedPair
  peek p =
    assemble
      <$> (peekCxxString =<< [CU.exp| const std::string* { &$(const Halide::Internal::FusedPair* p)->func_1 } |])
      <*> (peekCxxString =<< [CU.exp| const std::string* { &$(const Halide::Internal::FusedPair* p)->func_2 } |])
      <*> (fromIntegral <$> [CU.exp| size_t { $(const Halide::Internal::FusedPair* p)->stage_1 } |])
      <*> (fromIntegral <$> [CU.exp| size_t { $(const Halide::Internal::FusedPair* p)->stage_2 } |])
      <*> (peekCxxString =<< [CU.exp| const std::string* { &$(const Halide::Internal::FusedPair* p)->var_name } |])
    where
      -- The variable name is read last but is the constructor's first field.
      assemble func1 func2 stage1 stage2 fusedVar =
        FusedPair fusedVar (func1, stage1) (func2, stage2)

  poke :: Ptr FusedPair -> FusedPair -> IO ()
  poke _ _ = error "Storable instance of FusedPair does not implement poke"
-- | Read-only marshalling of @Halide::Internal::ReductionVariable@. The @min@
-- and @extent@ expressions are copied into freshly constructed
-- @Halide::Expr@s. 'poke' is not supported.
instance Storable ReductionVariable where
  sizeOf :: ReductionVariable -> Int
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::ReductionVariable) } |]

  alignment :: ReductionVariable -> Int
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::ReductionVariable) } |]

  peek :: Ptr ReductionVariable -> IO ReductionVariable
  peek p =
    ReductionVariable
      <$> (peekCxxString =<< [CU.exp| const std::string* { &$(const Halide::Internal::ReductionVariable* p)->var } |])
      <*> cxxConstructExpr
        ( \ptr ->
            [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
              $(const Halide::Internal::ReductionVariable* p)->min} } |]
        )
      <*> cxxConstructExpr
        ( \ptr ->
            [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
              $(const Halide::Internal::ReductionVariable* p)->extent} } |]
        )

  poke :: Ptr ReductionVariable -> ReductionVariable -> IO ()
  poke _ _ = error "Storable instance of ReductionVariable does not implement poke"
-- | Read-only marshalling of @Halide::Internal::PrefetchDirective@.
-- The parameter is not read back, so 'prefetchParameter' is always 'Nothing'.
-- 'poke' is not supported.
instance Storable PrefetchDirective where
  sizeOf :: PrefetchDirective -> Int
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::PrefetchDirective) } |]

  alignment :: PrefetchDirective -> Int
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::PrefetchDirective) } |]

  peek :: Ptr PrefetchDirective -> IO PrefetchDirective
  peek p =
    PrefetchDirective
      <$> (peekCxxString =<< [CU.exp| const std::string* { &$(const Halide::Internal::PrefetchDirective* p)->name } |])
      <*> (peekCxxString =<< [CU.exp| const std::string* { &$(const Halide::Internal::PrefetchDirective* p)->at } |])
      <*> (peekCxxString =<< [CU.exp| const std::string* { &$(const Halide::Internal::PrefetchDirective* p)->from } |])
      <*> cxxConstructExpr
        ( \ptr ->
            [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
              $(const Halide::Internal::PrefetchDirective* p)->offset} } |]
        )
      <*> (toEnum . fromIntegral <$> [CU.exp| int { static_cast<int>($(const Halide::Internal::PrefetchDirective* p)->strategy) } |])
      <*> pure Nothing

  poke :: Ptr PrefetchDirective -> PrefetchDirective -> IO ()
  poke _ _ = error "Storable instance for PrefetchDirective does not implement poke"
-- | Copy the reduction variables (C++ @rvars()@) out of a stage schedule.
getReductionVariables :: Ptr CxxStageSchedule -> IO [ReductionVariable]
getReductionVariables schedule =
  [CU.exp| const std::vector<Halide::Internal::ReductionVariable>* {
    &$(const Halide::Internal::StageSchedule* schedule)->rvars() } |]
    >>= peekCxxVector
-- | Copy the split/fuse list (C++ @splits()@) out of a stage schedule.
getSplits :: Ptr CxxStageSchedule -> IO [Split]
getSplits schedule =
  [CU.exp| const std::vector<Halide::Internal::Split>* {
    &$(const Halide::Internal::StageSchedule* schedule)->splits() } |]
    >>= peekCxxVector
-- | Copy the loop dimensions (C++ @dims()@) out of a stage schedule.
getDims :: Ptr CxxStageSchedule -> IO [Dim]
getDims schedule =
  [CU.exp| const std::vector<Halide::Internal::Dim>* {
    &$(const Halide::Internal::StageSchedule* schedule)->dims() } |]
    >>= peekCxxVector
-- | Read the fuse level of a stage schedule (C++ @fuse_level().level@).
--
-- A heap copy of the @Halide::LoopLevel@ is made with @new@ and handed to
-- 'wrapCxxLoopLevel'.
-- NOTE(review): presumably 'wrapCxxLoopLevel' takes ownership and frees the
-- copy — confirm at its definition.
getFuseLoopLevel :: Ptr CxxStageSchedule -> IO FuseLoopLevel
getFuseLoopLevel schedule = do
  levelCopy <-
    [CU.exp| Halide::LoopLevel* {
      new Halide::LoopLevel{$(const Halide::Internal::StageSchedule* schedule)->fuse_level().level}
    } |]
  FuseLoopLevel <$> wrapCxxLoopLevel levelCopy
-- | Copy the fused pairs (C++ @fused_pairs()@) out of a stage schedule.
getFusedPairs :: Ptr CxxStageSchedule -> IO [FusedPair]
getFusedPairs schedule =
  [CU.exp| const std::vector<Halide::Internal::FusedPair>* {
    &$(const Halide::Internal::StageSchedule* schedule)->fused_pairs() } |]
    >>= peekCxxVector
-- | Build a Haskell 'StageSchedule' from a C++ @Halide::Internal::StageSchedule@.
--
-- The C++ accessors are queried in field order: rvars, splits, dims, fuse
-- level, fused pairs, then the three boolean flags.
-- Prefetch directives are not read, so 'prefetches' is always @[]@.
peekStageSchedule :: Ptr CxxStageSchedule -> IO StageSchedule
peekStageSchedule schedule = do
  reductionVars <- getReductionVariables schedule
  splitList <- getSplits schedule
  dimList <- getDims schedule
  fuseLvl <- getFuseLoopLevel schedule
  pairList <- getFusedPairs schedule
  racesAllowed <-
    toBool <$> [CU.exp| bool { $(const Halide::Internal::StageSchedule* schedule)->allow_race_conditions() } |]
  isAtomic <-
    toBool <$> [CU.exp| bool { $(const Halide::Internal::StageSchedule* schedule)->atomic() } |]
  overrideTest <-
    toBool <$> [CU.exp| bool { $(const Halide::Internal::StageSchedule* schedule)->override_atomic_associativity_test() } |]
  pure
    ( StageSchedule
        { rvars = reductionVars
        , splits = splitList
        , dims = dimList
        , prefetches = []
        , fuseLevel = fuseLvl
        , fusedPairs = pairList
        , allowRaceConditions = racesAllowed
        , atomic = isAtomic
        , overrideAtomicAssociativityTest = overrideTest
        }
    )
-- | Read-only marshalling of @Halide::Internal::Dim@. 'poke' is not supported.
--
-- All reads go through a @const Halide::Internal::Dim*@, consistent with the
-- other read-only 'Storable' instances in this module. 'poke' now takes both
-- of its arguments explicitly, matching its signature.
instance Storable Dim where
  sizeOf :: Dim -> Int
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::Dim) } |]

  alignment :: Dim -> Int
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::Dim) } |]

  peek :: Ptr Dim -> IO Dim
  peek p = do
    name <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::Dim* p)->var } |]
    -- The C++ enums are read as int and decoded with the Haskell 'Enum' instances.
    forType' <-
      toEnum . fromIntegral
        <$> [CU.exp| int { static_cast<int>($(const Halide::Internal::Dim* p)->for_type) } |]
    device <-
      toEnum . fromIntegral
        <$> [CU.exp| int { static_cast<int>($(const Halide::Internal::Dim* p)->device_api) } |]
    dimType' <-
      toEnum . fromIntegral
        <$> [CU.exp| int { static_cast<int>($(const Halide::Internal::Dim* p)->dim_type) } |]
    pure $ Dim name forType' device dimType'

  poke :: Ptr Dim -> Dim -> IO ()
  poke _ _ = error "Storable instance for Dim does not implement poke"
-- | Read the @old_var@ name of a C++ @Halide::Internal::Split@.
peekOld :: Ptr Split -> IO Text
peekOld p =
  [CU.exp| const std::string* { &$(const Halide::Internal::Split* p)->old_var } |]
    >>= peekCxxString
-- | Read the @outer@ variable name of a C++ @Halide::Internal::Split@.
peekOuter :: Ptr Split -> IO Text
peekOuter p =
  [CU.exp| const std::string* { &$(const Halide::Internal::Split* p)->outer } |]
    >>= peekCxxString
-- | Read the @inner@ variable name of a C++ @Halide::Internal::Split@.
peekInner :: Ptr Split -> IO Text
peekInner p =
  [CU.exp| const std::string* { &$(const Halide::Internal::Split* p)->inner } |]
    >>= peekCxxString
-- | Copy the @factor@ expression of a C++ @Halide::Internal::Split@ into a
-- new 'Expr' by placement-constructing a @Halide::Expr@ in the storage that
-- 'cxxConstructExpr' provides.
peekFactor :: Ptr Split -> IO (Expr Int32)
peekFactor p = cxxConstructExpr copyFactor
  where
    copyFactor ptr =
      [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
        $(const Halide::Internal::Split* p)->factor} } |]
-- | Read-only marshalling for @Halide::Internal::Split@.
--
-- 'peek' decodes split entries into 'SplitVar' and fuse entries into
-- 'FuseVars'. Renames, purifications and unrecognised kinds raise an
-- 'error'. 'poke' is not implemented and always raises an 'error'.
instance Storable Split where
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::Split) } |]
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::Split) } |]
  peek p = do
    renameFlag <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_rename() } |]
    splitFlag <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_split() } |]
    fuseFlag <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_fuse() } |]
    purifyFlag <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_purify() } |]
    -- Split is checked before fuse, then rename and purify; the first match wins.
    case () of
      _
        | splitFlag -> SplitVar <$> readSplitContents
        | fuseFlag -> FuseVars <$> readFuseContents
        | renameFlag -> error "renames are not yet implemented"
        | purifyFlag -> error "purify is not yet implemented"
        | otherwise -> error "invalid split type"
    where
      -- Fields are read in constructor order: old, outer, inner, factor,
      -- exact flag, tail strategy (converted from its C++ enum value via 'toEnum').
      readSplitContents =
        SplitContents
          <$> peekOld p
          <*> peekOuter p
          <*> peekInner p
          <*> peekFactor p
          <*> (toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->exact } |])
          <*> (toEnum . fromIntegral <$> [CU.exp| int { static_cast<int>($(const Halide::Internal::Split* p)->tail) } |])
      -- For fuse entries the @old_var@ field is read last.
      readFuseContents =
        FuseContents
          <$> peekOuter p
          <*> peekInner p
          <*> peekOld p
  poke _ _ = error "Storable instance for Split does not implement poke"
-- | Decode the schedule currently attached to a stage.
--
-- The C++ schedule is accessed through a pointer obtained from
-- @Halide::Stage::get_schedule()@. It is decoded while 'withCxxStage'
-- still holds the stage.
getStageSchedule :: Stage n a -> IO StageSchedule
getStageSchedule stage = withCxxStage stage $ \stage' -> do
  schedulePtr <-
    [CU.exp| const Halide::Internal::StageSchedule* {
      &$(const Halide::Stage* stage')->get_schedule() } |]
  peekStageSchedule schedulePtr
#if USE_DLOPEN
-- | Directory of the shared object that provides @Halide::load_plugin@, or
-- 'Nothing' if it cannot be determined.
--
-- The symbol name of @Halide::load_plugin@ is obtained with @dladdr@. It is
-- then resolved with @dlsym(RTLD_NEXT, ...)@, and a second @dladdr@ yields
-- the file name, whose directory part is returned.
getHalideLibraryPath :: IO (Maybe Text)
getHalideLibraryPath = do
  ptr <-
    [CU.block| std::string* {
      Dl_info info;
      if (dladdr((void const*)&Halide::load_plugin, &info) != 0 && info.dli_sname != nullptr) {
        auto symbol = dlsym(RTLD_NEXT, info.dli_sname);
        if (dladdr(symbol, &info) != 0 && info.dli_fname != nullptr) {
          return new std::string{info.dli_fname};
        }
      }
      return nullptr;
    } |]
  if ptr == nullPtr
    then pure Nothing
    else do
      libraryFile <- peekAndDeleteCxxString ptr
      pure . Just . pack . takeDirectory . unpack $ libraryFile
#else
-- | Without @dlopen@ support the library location is unknown, so this is
-- always 'Nothing'.
getHalideLibraryPath :: IO (Maybe Text)
getHalideLibraryPath = pure Nothing
#endif
-- | Autoschedulers that can be loaded with 'loadAutoScheduler' and run with
-- 'applyAutoScheduler'.
--
-- 'applyAutoScheduler' passes the derived 'Show' output (e.g. @"Adams2019"@)
-- to Halide as the autoscheduler name, so the constructor names are
-- significant.
data AutoScheduler = Adams2019 | Li2018 | Mullapudi2016
  deriving stock (Eq, Show)
-- | Load the shared library that implements the given autoscheduler.
--
-- If 'getHalideLibraryPath' finds the Halide library, the plugin is loaded
-- from that directory. Otherwise only the bare file name is passed to the
-- platform loader, which applies its own search rules.
--
-- File naming depends on the build:
--
--   * @dlopen@ builds: @lib\<name\>.so@ (unchanged from before).
--   * Win32 builds: @\<name\>.dll@. Previously the @.so@ name was also used
--     here, which @LoadLibrary@ cannot resolve to a Windows DLL.
--     NOTE(review): assumes the DLL carries no @lib@ prefix (MSVC/CMake
--     MODULE default); confirm against the installed Halide.
loadAutoScheduler :: AutoScheduler -> IO ()
loadAutoScheduler scheduler = do
  libDir <- getHalideLibraryPath
  loadLibrary $ maybe fileName (\dir -> dir <> "/" <> fileName) libDir
  where
    baseName :: Text
    baseName = case scheduler of
      Adams2019 -> "autoschedule_adams2019"
      Li2018 -> "autoschedule_li2018"
      Mullapudi2016 -> "autoschedule_mullapudi2016"

    fileName :: Text
#if USE_DLOPEN
    fileName = "lib" <> baseName <> ".so"
#else
    fileName = baseName <> ".dll"
#endif
-- | Run an autoscheduler on the pipeline rooted at @func@ for @target@ and
-- return the generated schedule source (@schedule_source@).
--
-- The scheduler is selected by its 'show'n name. Its library presumably has
-- to be loaded first with 'loadAutoScheduler'. NOTE(review): confirm.
-- C++ errors pass through @handle_halide_exceptions@ inside a
-- 'C.throwBlock'.
applyAutoScheduler :: KnownNat n => AutoScheduler -> Target -> Func t n a -> IO Text
applyAutoScheduler scheduler target func =
  withFunc func $ \f ->
    withCxxTarget target $ \t -> do
      schedulePtr <-
        [C.throwBlock| std::string* {
          return handle_halide_exceptions([=](){
            auto name = std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)};
            auto pipeline = Halide::Pipeline{*$(Halide::Func* f)};
            auto params = Halide::AutoschedulerParams{name};
            auto results = pipeline.apply_autoscheduler(*$(Halide::Target* t), params);
            return new std::string{std::move(results.schedule_source)};
          });
        } |]
      peekAndDeleteCxxString schedulePtr
  where
    -- UTF-8 bytes of the scheduler name, handed to C++ as pointer + length.
    s = encodeUtf8 (pack (show scheduler))
-- | Keep only the part of a name after its last @.@, e.g. @"f.s0.x"@ becomes
-- @"x"@. Names without a dot are returned unchanged, and a trailing dot
-- yields the empty text.
makeUnqualified :: Text -> Text
makeUnqualified = T.takeWhileEnd (/= '.')
-- | Replay one recorded 'Split' on a stage.
--
-- 'SplitVar' entries go through 'Language.Halide.Func.split' and 'FuseVars'
-- entries through 'Language.Halide.Func.fuse'. Every variable name is
-- passed through 'makeUnqualified' before 'mkVar' creates the variable.
-- The split's boolean (exact) field is not consulted.
applySplit :: KnownNat n => Split -> Stage n a -> IO ()
applySplit s stage = case s of
  SplitVar x -> do
    oldVar <- toVar x.splitOld
    outerVar <- toVar x.splitOuter
    innerVar <- toVar x.splitInner
    void $ Language.Halide.Func.split x.splitTail oldVar (outerVar, innerVar) x.splitFactor stage
  FuseVars x -> do
    newVar <- toVar x.fuseNew
    innerVar <- toVar x.fuseInner
    outerVar <- toVar x.fuseOuter
    void $ Language.Halide.Func.fuse (innerVar, outerVar) newVar stage
  where
    toVar :: Text -> IO (Expr Int32)
    toVar = mkVar . makeUnqualified
-- | Replay a list of splits on a stage, one after another in list order.
applySplits :: KnownNat n => [Split] -> Stage n a -> IO ()
applySplits xs stage = sequence_ [applySplit s stage | s <- xs]
-- | Apply the loop type recorded in a 'Dim' to that dim's variable on a
-- stage.
--
-- Serial loops need no action. GPU loop types use the dim's device API.
-- Extern loops are unsupported and raise an 'error'.
applyDim :: KnownNat n => Dim -> Stage n a -> IO ()
applyDim x stage = do
  v <- mkVar (makeUnqualified x.var)
  case x.forType of
    ForSerial -> pure ()
    ForParallel -> void $ parallel v stage
    ForVectorized -> void $ vectorize v stage
    ForUnrolled -> void $ unroll v stage
    ForExtern -> error "extern ForType is not yet supported by applyDim"
    ForGPUBlock -> void $ gpuBlocks x.deviceApi v stage
    ForGPUThread -> void $ gpuThreads x.deviceApi v stage
    ForGPULane -> void $ gpuLanes x.deviceApi v stage
-- | Apply every 'Dim' to a stage with 'applyDim'. Then call 'reorder' with
-- the dims' (unqualified) variables, in the same order as the list.
applyDims :: KnownNat n => [Dim] -> Stage n a -> IO ()
applyDims xs stage = do
  sequence_ [applyDim d stage | d <- xs]
  loopVars <- traverse (\d -> mkVar (makeUnqualified d.var)) xs
  void $ reorder loopVars stage
-- | Replay a 'StageSchedule' on a stage: all of its splits first, then its
-- dims (loop types and loop order).
applySchedule :: KnownNat n => StageSchedule -> Stage n a -> IO ()
applySchedule schedule stage =
  applySplits schedule.splits stage *> applyDims schedule.dims stage