{-# LANGUAGE CPP #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

module Language.Halide.Schedule
  ( Dim (..)
  , DimType (..)
  , ForType (..)
  , SplitContents (..)
  , FuseContents (..)
  , Split (..)
  , Bound (..)
  , StorageDim (..)
  , FusedPair (..)
  , FuseLoopLevel (..)
  , StageSchedule (..)
  , ReductionVariable (..)
  , PrefetchDirective (..)
  , getStageSchedule
  -- , getStageSchedule
  -- , getFusedPairs
  -- , getReductionVariables
  -- , getFuseLoopLevel
  -- , getDims
  -- , getSplits
  , AutoScheduler (..)
  , loadAutoScheduler
  , applyAutoScheduler
  , getHalideLibraryPath
  , applySplits
  , applyDims
  , applySchedule
  )
where

import Control.Monad (void)
import Data.Text (Text, pack, unpack)
import Data.Text qualified as T
import Data.Text.Encoding (encodeUtf8)
import Foreign.C.Types (CInt (..))
import Foreign.ForeignPtr
import Foreign.Marshal (allocaArray, peekArray, toBool)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Func
import Language.Halide.LoopLevel
import Language.Halide.Target
import Language.Halide.Type
import Language.Halide.Utils
import System.FilePath (takeDirectory)
import Prelude hiding (tail)

#if USE_DLOPEN
import qualified System.Posix.DynamicLinker as DL

-- | Load the shared library located at @path@ into the current process.
--
-- The library is opened with @dlopen@ using 'DL.RTLD_LAZY'. The handle
-- returned by 'DL.dlopen' is discarded, so the library is never explicitly
-- closed by this module.
loadLibrary :: Text -> IO ()
loadLibrary path = do
  _ <- DL.dlopen (unpack path) [DL.RTLD_LAZY]
  pure ()

#else
import qualified System.Win32.DLL as Win32

-- | Load the dynamic library located at @path@ into the current process.
--
-- This is the variant used when @USE_DLOPEN@ is not set; it delegates to
-- 'Win32.loadLibrary' and discards the returned module handle.
loadLibrary :: Text -> IO ()
loadLibrary path = Win32.loadLibrary (unpack path) >> pure ()
#endif

-- | Type of dimension that tells which transformations are legal on it.
--
-- The constructors correspond one-to-one to the enumerators of the C++ enum
-- @Halide::Internal::DimType@ (see the 'Enum' instance in this module).
data DimType
  = -- | Corresponds to @Halide::Internal::DimType::PureVar@.
    DimPureVar
  | -- | Corresponds to @Halide::Internal::DimType::PureRVar@.
    DimPureRVar
  | -- | Corresponds to @Halide::Internal::DimType::ImpureRVar@.
    DimImpureRVar
  deriving stock (Show, Eq)

-- | Specifies how loop values are traversed.
--
-- The constructors correspond one-to-one to the enumerators of the C++ enum
-- @Halide::Internal::ForType@ (see the 'Enum' instance in this module).
data ForType
  = -- | Corresponds to @Halide::Internal::ForType::Serial@.
    ForSerial
  | -- | Corresponds to @Halide::Internal::ForType::Parallel@.
    ForParallel
  | -- | Corresponds to @Halide::Internal::ForType::Vectorized@.
    ForVectorized
  | -- | Corresponds to @Halide::Internal::ForType::Unrolled@.
    ForUnrolled
  | -- | Corresponds to @Halide::Internal::ForType::Extern@.
    ForExtern
  | -- | Corresponds to @Halide::Internal::ForType::GPUBlock@.
    ForGPUBlock
  | -- | Corresponds to @Halide::Internal::ForType::GPUThread@.
    ForGPUThread
  | -- | Corresponds to @Halide::Internal::ForType::GPULane@.
    ForGPULane
  deriving stock (Show, Eq)

-- | Haskell view of a @Halide::Internal::Dim@, i.e. one entry of the @dims@
-- list of a 'StageSchedule'.
--
-- Values are produced by the 'Storable' instance's @peek@, which reads the
-- C++ members named in the field comments below.
data Dim = Dim
  { var :: !Text
  -- ^ Read from the C++ member @var@.
  , forType :: !ForType
  -- ^ Read from the C++ member @for_type@.
  , deviceApi :: !DeviceAPI
  -- ^ Read from the C++ member @device_api@.
  , dimType :: !DimType
  -- ^ Read from the C++ member @dim_type@.
  }
  deriving stock (Show, Eq)

-- | Payload of the 'FuseVars' constructor of 'Split': the three variable
-- names involved in a fuse.
--
-- NOTE(review): the field meanings below are inferred from the field names;
-- confirm against the code that marshals @Halide::Internal::Split@.
data FuseContents = FuseContents
  { fuseOuter :: !Text
  -- ^ Presumably the name of the outer variable being fused.
  , fuseInner :: !Text
  -- ^ Presumably the name of the inner variable being fused.
  , fuseNew :: !Text
  -- ^ Presumably the name of the resulting fused variable.
  }
  deriving stock (Show, Eq)

-- | Payload of the 'SplitVar' constructor of 'Split'.
--
-- The 'Storable' instance of 'Split' fills the fields from a
-- @Halide::Internal::Split@ in declaration order.
data SplitContents = SplitContents
  { splitOld :: !Text
  -- ^ Read from the C++ member @old_var@ (via @peekOld@).
  , splitOuter :: !Text
  -- ^ Read from the C++ member @outer@ (via @peekOuter@).
  , splitInner :: !Text
  -- ^ Read from the C++ member @inner@ (via @peekInner@).
  , splitFactor :: !(Expr Int32)
  -- ^ Read from the C++ member @factor@ (via @peekFactor@).
  , splitExact :: !Bool
  -- ^ Read from the C++ member @exact@.
  , splitTail :: !TailStrategy
  -- ^ Presumably the tail strategy of the split -- TODO confirm against the
  -- marshalling code.
  }
  deriving stock (Show)

-- | One entry of the @splits@ list of a 'StageSchedule'.
--
-- A value is either a split of a variable or a fusion of variables; the
-- payload records carry the details.
data Split
  = -- | A split, described by 'SplitContents'.
    SplitVar !SplitContents
  | -- | A fuse, described by 'FuseContents'.
    FuseVars !FuseContents
  deriving stock (Show)

-- | A bound associated with a variable.
--
-- NOTE(review): this looks like the Haskell counterpart of
-- @Halide::Internal::Bound@; the field meanings are inferred from their
-- names and types only -- confirm against the marshalling code.
data Bound = Bound
  { boundVar :: !Text
  -- ^ Presumably the name of the bounded variable.
  , boundMin :: !(Maybe (Expr Int32))
  -- ^ Optional minimum.
  , boundExtent :: !(Expr Int32)
  -- ^ Extent; unlike the other expression fields this one is mandatory.
  , boundModulus :: !(Maybe (Expr Int32))
  -- ^ Optional modulus.
  , boundRemainder :: !(Maybe (Expr Int32))
  -- ^ Optional remainder.
  }
  deriving stock (Show)

-- | Storage properties of a single dimension.
--
-- NOTE(review): this looks like the Haskell counterpart of
-- @Halide::Internal::StorageDim@; the field meanings are inferred from their
-- names and types only -- confirm against the marshalling code.
data StorageDim = StorageDim
  { storageVar :: !Text
  -- ^ Presumably the name of the variable this entry describes.
  , storageAlignment :: !(Maybe (Expr Int32))
  -- ^ Optional alignment expression.
  , storageBound :: !(Maybe (Expr Int32))
  -- ^ Optional bound expression.
  , storageFold :: !(Maybe (Expr Int32, Bool))
  -- ^ Optional fold information: an expression paired with a flag.
  -- TODO confirm what the flag denotes.
  }
  deriving stock (Show)

-- | Haskell view of a @Halide::Internal::FusedPair@.
--
-- The 'Storable' instance's @peek@ builds the value as
-- @FusedPair var_name (func_1, stage_1) (func_2, stage_2)@, where the names
-- refer to the members of the C++ struct.
data FusedPair = FusedPair !Text !(Text, Int) !(Text, Int)
  deriving stock (Show, Eq)

-- | Wrapper around the loop level stored in a stage schedule's fuse level.
--
-- @getFuseLoopLevel@ fills it from the @level@ member of the value returned
-- by the C++ method @StageSchedule::fuse_level()@.
data FuseLoopLevel = FuseLoopLevel !SomeLoopLevel
  deriving stock (Show, Eq)

-- | Haskell view of a @Halide::Internal::ReductionVariable@, i.e. one entry
-- of the @rvars@ list of a 'StageSchedule'.
data ReductionVariable = ReductionVariable
  { varName :: !Text
  -- ^ Read from the C++ member @var@.
  , minExpr :: !(Expr Int32)
  -- ^ Copy of the C++ member @min@.
  , extentExpr :: !(Expr Int32)
  -- ^ Copy of the C++ member @extent@.
  }
  deriving stock (Show)

-- | Strategy attached to a 'PrefetchDirective'.
--
-- The constructors correspond one-to-one to the enumerators of the C++ enum
-- @Halide::PrefetchBoundStrategy@ (see the 'Enum' instance in this module).
data PrefetchBoundStrategy
  = -- | Corresponds to @Halide::PrefetchBoundStrategy::Clamp@.
    PrefetchClamp
  | -- | Corresponds to @Halide::PrefetchBoundStrategy::GuardWithIf@.
    PrefetchGuardWithIf
  | -- | Corresponds to @Halide::PrefetchBoundStrategy::NonFaulting@.
    PrefetchNonFaulting
  deriving stock (Show, Eq)

-- | Haskell view of a @Halide::Internal::PrefetchDirective@.
--
-- The 'Storable' instance's @peek@ reads the C++ members named in the field
-- comments below.
data PrefetchDirective = PrefetchDirective
  { prefetchFunc :: !Text
  -- ^ Read from the C++ member @name@.
  , prefetchAt :: !Text
  -- ^ Read from the C++ member @at@.
  , prefetchFrom :: !Text
  -- ^ Read from the C++ member @from@.
  , prefetchOffset :: !(Expr Int32)
  -- ^ Copy of the C++ member @offset@.
  , prefetchStrategy :: !PrefetchBoundStrategy
  -- ^ Converted from the C++ member @strategy@.
  , prefetchParameter :: !(Maybe (ForeignPtr CxxParameter))
  -- ^ The 'Storable' instance currently always stores 'Nothing' here; the
  -- code that would read the C++ @param@ member is commented out.
  }
  deriving stock (Show)

-- | Haskell snapshot of a @Halide::Internal::StageSchedule@, as assembled by
-- @peekStageSchedule@.
data StageSchedule = StageSchedule
  { rvars :: ![ReductionVariable]
  -- ^ Contents of the C++ @rvars()@ vector.
  , splits :: ![Split]
  -- ^ Contents of the C++ @splits()@ vector.
  , dims :: ![Dim]
  -- ^ Contents of the C++ @dims()@ vector.
  , prefetches :: ![PrefetchDirective]
  -- ^ @peekStageSchedule@ currently always leaves this list empty.
  , fuseLevel :: !FuseLoopLevel
  -- ^ Copy of @fuse_level().level@.
  , fusedPairs :: ![FusedPair]
  -- ^ Contents of the C++ @fused_pairs()@ vector.
  , allowRaceConditions :: !Bool
  -- ^ Result of the C++ @allow_race_conditions()@ accessor.
  , atomic :: !Bool
  -- ^ Result of the C++ @atomic()@ accessor.
  , overrideAtomicAssociativityTest :: !Bool
  -- ^ Result of the C++ @override_atomic_associativity_test()@ accessor.
  }
  deriving stock (Show)

-- Template Haskell splice. Presumably it sets up the inline-c C++ context
-- (type table and Halide includes) that the @[CU.exp| ... |]@ and
-- @[CU.pure| ... |]@ quasi-quotes below rely on -- confirm in its definition.
-- It must stay above those quasi-quotes.
importHalide

-- Template Haskell splices, one per C++ element type that this module reads
-- through 'peekCxxVector'. Presumably each one generates the corresponding
-- 'HasCxxVector' instance -- confirm in the definition of
-- 'instanceHasCxxVector'.
-- NOTE(review): there is no splice for "Halide::Internal::PrefetchDirective";
-- this matches 'peekStageSchedule', which never reads the prefetch vector.
instanceHasCxxVector "Halide::Internal::Dim"
instanceHasCxxVector "Halide::Internal::Split"
instanceHasCxxVector "Halide::Internal::FusedPair"
instanceHasCxxVector "Halide::Internal::ReductionVariable"

-- | Conversion from the integer value of the C++ enum
-- @Halide::Internal::ForType@.
--
-- Only 'toEnum' is implemented. It compares its argument against the value of
-- each C++ enumerator (obtained via @static_cast<int>@) and calls 'error' for
-- any other integer. 'fromEnum' always calls 'error'.
instance Enum ForType where
  toEnum k
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Serial) } |] =
        ForSerial
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Parallel) } |] =
        ForParallel
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Vectorized) } |] =
        ForVectorized
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Unrolled) } |] =
        ForUnrolled
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::Extern) } |] =
        ForExtern
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::GPUBlock) } |] =
        ForGPUBlock
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::GPUThread) } |] =
        ForGPUThread
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::ForType::GPULane) } |] =
        ForGPULane
    | otherwise = error $ "invalid ForType: " <> show k
  fromEnum = error "Enum instance for ForType does not implement fromEnum"

-- | Conversion from the integer value of the C++ enum
-- @Halide::Internal::DimType@.
--
-- Only 'toEnum' is implemented. It compares its argument against the value of
-- each C++ enumerator (obtained via @static_cast<int>@) and calls 'error' for
-- any other integer. 'fromEnum' always calls 'error'.
instance Enum DimType where
  toEnum k
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::DimType::PureVar) } |] =
        DimPureVar
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::DimType::PureRVar) } |] =
        DimPureRVar
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::Internal::DimType::ImpureRVar) } |] =
        DimImpureRVar
    | otherwise = error $ "invalid DimType: " <> show k
  fromEnum = error "Enum instance for DimType does not implement fromEnum"

-- | Conversion from the integer value of the C++ enum
-- @Halide::PrefetchBoundStrategy@.
--
-- Only 'toEnum' is implemented. It compares its argument against the value of
-- each C++ enumerator (obtained via @static_cast<int>@) and calls 'error' for
-- any other integer. 'fromEnum' always calls 'error'.
instance Enum PrefetchBoundStrategy where
  toEnum k
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::PrefetchBoundStrategy::Clamp) } |] =
        PrefetchClamp
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::PrefetchBoundStrategy::GuardWithIf) } |] =
        PrefetchGuardWithIf
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::PrefetchBoundStrategy::NonFaulting) } |] =
        PrefetchNonFaulting
    | otherwise = error $ "invalid PrefetchBoundStrategy: " <> show k
  -- The message names this type; it previously said "ForType", a leftover
  -- from copying the ForType instance.
  fromEnum = error "Enum instance for PrefetchBoundStrategy does not implement fromEnum"

-- | Read-only marshalling of @Halide::Internal::FusedPair@.
--
-- 'sizeOf' and 'alignment' report the C++ @sizeof@ / @alignof@ of the struct.
-- 'peek' copies the members @func_1@, @func_2@, @stage_1@, @stage_2@ and
-- @var_name@ into a 'FusedPair'. 'poke' is not implemented and calls 'error'.
instance Storable FusedPair where
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::FusedPair) } |]
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::FusedPair) } |]
  peek p = do
    func1 <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::FusedPair* p)->func_1 } |]
    func2 <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::FusedPair* p)->func_2 } |]
    stage1 <-
      fromIntegral
        <$> [CU.exp| size_t { $(const Halide::Internal::FusedPair* p)->stage_1 } |]
    stage2 <-
      fromIntegral
        <$> [CU.exp| size_t { $(const Halide::Internal::FusedPair* p)->stage_2 } |]
    varName <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::FusedPair* p)->var_name } |]
    -- The variable name comes first, followed by one (function, stage) pair
    -- per fused function.
    pure $ FusedPair varName (func1, stage1) (func2, stage2)
  poke _ _ = error "Storable instance of FusedPair does not implement poke"

-- | Read-only marshalling of @Halide::Internal::ReductionVariable@.
--
-- 'sizeOf' and 'alignment' report the C++ @sizeof@ / @alignof@ of the struct.
-- 'peek' reads the @var@ string and copy-constructs the @min@ and @extent@
-- expressions into freshly constructed 'Expr's. 'poke' is not implemented and
-- calls 'error'.
instance Storable ReductionVariable where
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::ReductionVariable) } |]
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::ReductionVariable) } |]
  peek p = do
    varName <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::ReductionVariable* p)->var } |]
    -- Placement-new a copy of the C++ expression into the storage handed to
    -- us by 'cxxConstructExpr'.
    minExpr <-
      cxxConstructExpr $ \ptr ->
        [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
          $(const Halide::Internal::ReductionVariable* p)->min} } |]
    extentExpr <-
      cxxConstructExpr $ \ptr ->
        [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
          $(const Halide::Internal::ReductionVariable* p)->extent} } |]
    pure $ ReductionVariable varName minExpr extentExpr
  poke _ _ = error "Storable instance of ReductionVariable does not implement poke"

-- | Read-only marshalling of @Halide::Internal::PrefetchDirective@.
--
-- 'sizeOf' and 'alignment' report the C++ @sizeof@ / @alignof@ of the struct.
-- 'peek' reads the @name@, @at@, @from@, @offset@ and @strategy@ members; the
-- @param@ member is not read (that code is commented out below), so
-- 'prefetchParameter' is always 'Nothing'. 'poke' is not implemented and
-- calls 'error'.
instance Storable PrefetchDirective where
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::PrefetchDirective) } |]
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::PrefetchDirective) } |]
  peek p = do
    funcName' <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::PrefetchDirective* p)->name } |]
    atVar' <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::PrefetchDirective* p)->at } |]
    fromVar' <-
      peekCxxString
        =<< [CU.exp| const std::string* { &$(const Halide::Internal::PrefetchDirective* p)->from } |]
    -- Placement-new a copy of the C++ expression into the storage handed to
    -- us by 'cxxConstructExpr'.
    offset' <-
      cxxConstructExpr $ \ptr ->
        [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
          $(const Halide::Internal::PrefetchDirective* p)->offset} } |]
    strategy' <-
      toEnum . fromIntegral
        <$> [CU.exp| int { static_cast<int>($(const Halide::Internal::PrefetchDirective* p)->strategy) } |]
    -- isDefined <-
    --   toBool
    --     <$> [CU.exp| bool { $(const Halide::Internal::PrefetchDirective* p)->param.defined() } |]
    -- param' <-
    --   if isDefined
    --     then
    --       fmap Just $
    --         wrapCxxParameter
    --           =<< [CU.exp| Halide::Internal::Parameter* {
    --                 new Halide::Internal::Parameter{$(const Halide::Internal::PrefetchDirective* p)->param} } |]
    --     else pure Nothing
    pure $ PrefetchDirective funcName' atVar' fromVar' offset' strategy' Nothing
  poke _ _ = error "Storable instance for PrefetchDirective does not implement poke"

-- | Read the reduction variables of a stage schedule.
--
-- Takes the address of the vector returned by the C++ @rvars()@ accessor and
-- converts its elements to a Haskell list with 'peekCxxVector'.
getReductionVariables :: Ptr CxxStageSchedule -> IO [ReductionVariable]
getReductionVariables schedule =
  peekCxxVector
    =<< [CU.exp| const std::vector<Halide::Internal::ReductionVariable>* {
          &$(const Halide::Internal::StageSchedule* schedule)->rvars() } |]

-- | Read the splits of a stage schedule.
--
-- Takes the address of the vector returned by the C++ @splits()@ accessor and
-- converts its elements to a Haskell list with 'peekCxxVector'.
getSplits :: Ptr CxxStageSchedule -> IO [Split]
getSplits schedule =
  peekCxxVector
    =<< [CU.exp| const std::vector<Halide::Internal::Split>* {
            &$(const Halide::Internal::StageSchedule* schedule)->splits() } |]

-- | Read the dimensions of a stage schedule.
--
-- Takes the address of the vector returned by the C++ @dims()@ accessor and
-- converts its elements to a Haskell list with 'peekCxxVector'.
getDims :: Ptr CxxStageSchedule -> IO [Dim]
getDims schedule =
  peekCxxVector
    =<< [CU.exp| const std::vector<Halide::Internal::Dim>* {
          &$(const Halide::Internal::StageSchedule* schedule)->dims() } |]

-- | Read the fuse level of a stage schedule.
--
-- Heap-allocates (with C++ @new@) a copy of @fuse_level().level@, hands the
-- pointer to 'wrapCxxLoopLevel', and wraps the result in 'FuseLoopLevel'.
-- Presumably 'wrapCxxLoopLevel' takes ownership of the allocation -- confirm
-- in its definition.
getFuseLoopLevel :: Ptr CxxStageSchedule -> IO FuseLoopLevel
getFuseLoopLevel schedule =
  fmap FuseLoopLevel $
    wrapCxxLoopLevel
      =<< [CU.exp| Halide::LoopLevel* {
            new Halide::LoopLevel{$(const Halide::Internal::StageSchedule* schedule)->fuse_level().level}
          } |]

-- | Read the fused pairs of a stage schedule.
--
-- Takes the address of the vector returned by the C++ @fused_pairs()@
-- accessor and converts its elements to a Haskell list with 'peekCxxVector'.
getFusedPairs :: Ptr CxxStageSchedule -> IO [FusedPair]
getFusedPairs schedule =
  peekCxxVector
    =<< [CU.exp| const std::vector<Halide::Internal::FusedPair>* {
          &$(const Halide::Internal::StageSchedule* schedule)->fused_pairs() } |]

-- | Build a Haskell 'StageSchedule' snapshot from a C++
-- @Halide::Internal::StageSchedule@.
--
-- The reduction variables, splits, dims, fuse level and fused pairs are read
-- with the corresponding @get*@ helpers; the three boolean flags are read
-- directly through their C++ accessors. Prefetch directives are not read:
-- the @prefetches@ field is always the empty list.
peekStageSchedule :: Ptr CxxStageSchedule -> IO StageSchedule
peekStageSchedule schedule = do
  rvars' <- getReductionVariables schedule
  splits' <- getSplits schedule
  dims' <- getDims schedule
  -- Prefetch directives are currently not marshalled.
  let prefetches' = []
  fuseLevel' <- getFuseLoopLevel schedule
  fusedPairs' <- getFusedPairs schedule
  allowRaceConditions' <-
    toBool
      <$> [CU.exp| bool { $(const Halide::Internal::StageSchedule* schedule)->allow_race_conditions() } |]
  atomic' <-
    toBool
      <$> [CU.exp| bool { $(const Halide::Internal::StageSchedule* schedule)->atomic() } |]
  overrideAtomicAssociativityTest' <-
    toBool
      <$> [CU.exp| bool { $(const Halide::Internal::StageSchedule* schedule)->override_atomic_associativity_test() } |]
  pure $
    StageSchedule
      { rvars = rvars'
      , splits = splits'
      , dims = dims'
      , prefetches = prefetches'
      , fuseLevel = fuseLevel'
      , fusedPairs = fusedPairs'
      , allowRaceConditions = allowRaceConditions'
      , atomic = atomic'
      , overrideAtomicAssociativityTest = overrideAtomicAssociativityTest'
      }

-- | Read-only marshalling of @Halide::Internal::Dim@.
--
-- 'sizeOf' and 'alignment' report the C++ @sizeof@ / @alignof@ of the struct.
-- 'peek' reads the @var@ string and converts the @for_type@, @device_api@ and
-- @dim_type@ enum members to their Haskell counterparts via 'toEnum' (which
-- calls 'error' on values it does not recognise, at least for 'ForType' and
-- 'DimType'). 'poke' is not implemented and calls 'error'.
instance Storable Dim where
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::Dim) } |]
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::Dim) } |]
  peek p = do
    name <- peekCxxString =<< [CU.exp| const std::string* { &$(Halide::Internal::Dim* p)->var } |]
    forType' <-
      toEnum . fromIntegral
        <$> [CU.exp| int { static_cast<int>($(Halide::Internal::Dim* p)->for_type) } |]
    device <-
      toEnum . fromIntegral
        <$> [CU.exp| int { static_cast<int>($(Halide::Internal::Dim* p)->device_api) } |]
    dimType' <-
      toEnum . fromIntegral
        <$> [CU.exp| int { static_cast<int>($(Halide::Internal::Dim* p)->dim_type) } |]
    pure $ Dim name forType' device dimType'
  poke _ = error "Storable instance for Dim does not implement poke"

-- | Read the @old_var@ member of a @Halide::Internal::Split@ as 'Text'.
peekOld :: Ptr Split -> IO Text
peekOld p =
  peekCxxString
    =<< [CU.exp| const std::string* { &$(const Halide::Internal::Split* p)->old_var } |]

-- | Read the @outer@ member of a @Halide::Internal::Split@ as 'Text'.
peekOuter :: Ptr Split -> IO Text
peekOuter p =
  peekCxxString
    =<< [CU.exp| const std::string* { &$(const Halide::Internal::Split* p)->outer } |]

-- | Read the @inner@ member of a @Halide::Internal::Split@ as 'Text'.
peekInner :: Ptr Split -> IO Text
peekInner p =
  peekCxxString
    =<< [CU.exp| const std::string* { &$(const Halide::Internal::Split* p)->inner } |]

-- | Copy the @factor@ member of a @Halide::Internal::Split@ into a new
-- 'Expr'.
--
-- The copy is made by placement-new into the storage that 'cxxConstructExpr'
-- passes to the callback.
peekFactor :: Ptr Split -> IO (Expr Int32)
peekFactor p =
  cxxConstructExpr $ \ptr ->
    [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
      $(const Halide::Internal::Split* p)->factor} } |]

-- | Marshalling of @Halide::Internal::Split@ objects into 'Split' values.
--
-- Only reading is supported: 'poke' raises an error. 'peek' decodes splits
-- and fuses; renames and purifies (and anything unrecognised) raise an error.
instance Storable Split where
  sizeOf _ = fromIntegral [CU.pure| size_t { sizeof(Halide::Internal::Split) } |]
  alignment _ = fromIntegral [CU.pure| size_t { alignof(Halide::Internal::Split) } |]
  peek p = do
    -- All four predicates are queried up front, before dispatching.
    isRename <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_rename() } |]
    isSplit <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_split() } |]
    isFuse <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_fuse() } |]
    isPurify <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->is_purify() } |]
    case () of
      _
        | isSplit -> SplitVar <$> readSplitContents
        | isFuse -> FuseVars <$> readFuseContents
        | isRename -> error "renames are not yet implemented"
        | isPurify -> error "purify is not yet implemented"
        | otherwise -> error "invalid split type"
    where
      -- Fields are read in constructor order: old, outer, inner, factor,
      -- exact, tail strategy.
      readSplitContents = do
        oldName <- peekOld p
        outerName <- peekOuter p
        innerName <- peekInner p
        factorExpr <- peekFactor p
        isExact <- toBool <$> [CU.exp| bool { $(const Halide::Internal::Split* p)->exact } |]
        tailStrategy <-
          toEnum . fromIntegral
            <$> [CU.exp| int { static_cast<int>($(const Halide::Internal::Split* p)->tail) } |]
        pure (SplitContents oldName outerName innerName factorExpr isExact tailStrategy)
      -- For a fuse the constructor receives outer, inner and then old_var.
      readFuseContents = do
        outerName <- peekOuter p
        innerName <- peekInner p
        oldName <- peekOld p
        pure (FuseContents outerName innerName oldName)
  poke _ = error "Storable instance for Split does not implement poke"

-- wrapCxxStageSchedule :: Ptr CxxStageSchedule -> IO StageSchedule
-- wrapCxxStageSchedule = fmap StageSchedule . newForeignPtr deleter
--   where
--     deleter =
--       [C.funPtr| void deleteSchedule(Halide::Internal::StageSchedule* p) {
--         std::cout << "deleting ..." << std::endl;
--         delete p; } |]

-- | Retrieve the schedule of a 'Stage' as a Haskell 'StageSchedule'.
--
-- The C++ schedule is obtained via @Halide::Stage::get_schedule()@ and then
-- decoded with 'peekStageSchedule'.
getStageSchedule :: Stage n a -> IO StageSchedule
getStageSchedule stage =
  withCxxStage stage $ \stage' -> do
    schedulePtr <-
      [CU.exp| const Halide::Internal::StageSchedule* {
            &$(const Halide::Stage* stage')->get_schedule() } |]
    peekStageSchedule schedulePtr

#if USE_DLOPEN

-- | Try to locate the directory containing the loaded Halide shared library.
--
-- The C++ side uses @dladdr@ on @Halide::load_plugin@ to obtain its symbol
-- name, resolves that symbol with @dlsym(RTLD_NEXT, ...)@ and asks @dladdr@
-- for the file it lives in. The directory part of that file name is returned,
-- or 'Nothing' when any of the lookups fails.
getHalideLibraryPath :: IO (Maybe Text)
getHalideLibraryPath = do
  ptr <-
    [CU.block| std::string* {
      Dl_info info;
      if (dladdr((void const*)&Halide::load_plugin, &info) != 0 && info.dli_sname != nullptr) {
        auto symbol = dlsym(RTLD_NEXT, info.dli_sname);
        if (dladdr(symbol, &info) != 0 && info.dli_fname != nullptr) {
          return new std::string{info.dli_fname};
        }
      }
      return nullptr;
    } |]
  if ptr /= nullPtr
    then do
      -- The heap-allocated std::string is released by peekAndDeleteCxxString.
      libraryFile <- peekAndDeleteCxxString ptr
      pure (Just (pack (takeDirectory (unpack libraryFile))))
    else pure Nothing

#else

-- | Fallback used when @USE_DLOPEN@ is not set: no lookup is attempted and
-- the result is always 'Nothing'.
getHalideLibraryPath :: IO (Maybe Text)
getHalideLibraryPath = pure Nothing

#endif

-- | Autoschedulers that can be loaded with 'loadAutoScheduler' and run with
-- 'applyAutoScheduler'.
--
-- Note that 'applyAutoScheduler' passes @show@ of the constructor to Halide as
-- the autoscheduler name, so the constructor names are part of the contract.
data AutoScheduler
  = -- | Loaded from the @autoschedule_adams2019@ library.
    Adams2019
  | -- | Loaded from the @autoschedule_li2018@ library.
    Li2018
  | -- | Loaded from the @autoschedule_mullapudi2016@ library.
    Mullapudi2016
  deriving stock (Show, Eq)

-- | Load the plugin library that implements the given autoscheduler.
--
-- The file name is @lib\<plugin\>.so@. When 'getHalideLibraryPath' yields a
-- directory, the file is looked up inside it; otherwise the bare file name is
-- handed to 'loadLibrary'.
loadAutoScheduler :: AutoScheduler -> IO ()
loadAutoScheduler scheduler = do
  libraryDir <- getHalideLibraryPath
  loadLibrary (maybe fileName (\dir -> dir <> "/" <> fileName) libraryDir)
  where
    fileName :: Text
    fileName = "lib" <> pluginName <> ".so"

    pluginName :: Text
    pluginName = case scheduler of
      Adams2019 -> "autoschedule_adams2019"
      Li2018 -> "autoschedule_li2018"
      Mullapudi2016 -> "autoschedule_mullapudi2016"

-- | Run an autoscheduler on the pipeline consisting of the given 'Func' and
-- return the schedule source that Halide reports.
--
-- The autoscheduler name passed to Halide is @show scheduler@, UTF-8 encoded.
-- The C++ code runs inside @handle_halide_exceptions@ within a @throwBlock@.
--
-- NOTE(review): presumably the corresponding plugin has to be loaded first
-- (see 'loadAutoScheduler') — confirm against callers.
applyAutoScheduler :: KnownNat n => AutoScheduler -> Target -> Func t n a -> IO Text
applyAutoScheduler scheduler target func =
  withFunc func $ \f ->
    withCxxTarget target $ \t -> do
      -- The block returns a heap-allocated std::string that is released by
      -- peekAndDeleteCxxString below.
      sourcePtr <-
        [C.throwBlock| std::string* {
              return handle_halide_exceptions([=](){
                auto name = std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)};
                auto pipeline = Halide::Pipeline{*$(Halide::Func* f)};
                auto params = Halide::AutoschedulerParams{name};
                auto results = pipeline.apply_autoscheduler(*$(Halide::Target* t), params);
                return new std::string{std::move(results.schedule_source)};
              });
            } |]
      peekAndDeleteCxxString sourcePtr
  where
    -- Autoscheduler name as bytes, referenced from the C++ block via bs-ptr/bs-len.
    s = encodeUtf8 (T.pack (show scheduler))

-- | Strip everything up to and including the last @.@ from a name.
--
-- A name without any @.@ is returned unchanged; a name ending in @.@ yields
-- the empty text.
makeUnqualified :: Text -> Text
makeUnqualified qualifiedName = T.takeWhileEnd (/= '.') qualifiedName

-- | Replay a single 'Split' on a 'Stage'.
--
-- Variable names stored in the split are first stripped of their qualifier
-- with 'makeUnqualified' and turned into fresh variables with 'mkVar'. The
-- updated stage returned by the scheduling call is discarded.
applySplit :: KnownNat n => Split -> Stage n a -> IO ()
applySplit directive stage =
  case directive of
    SplitVar x -> do
      oldVar <- unqualifiedVar x.splitOld
      outerVar <- unqualifiedVar x.splitOuter
      innerVar <- unqualifiedVar x.splitInner
      _ <-
        Language.Halide.Func.split
          x.splitTail
          oldVar
          (outerVar, innerVar)
          x.splitFactor
          stage
      pure ()
    FuseVars x -> do
      newVar <- unqualifiedVar x.fuseNew
      innerVar <- unqualifiedVar x.fuseInner
      outerVar <- unqualifiedVar x.fuseOuter
      _ <- Language.Halide.Func.fuse (innerVar, outerVar) newVar stage
      pure ()
  where
    -- Build a variable from the unqualified part of a stored name.
    unqualifiedVar = mkVar . makeUnqualified

-- | Replay a list of 'Split's on a 'Stage', in list order.
applySplits :: KnownNat n => [Split] -> Stage n a -> IO ()
applySplits xs stage = mapM_ (flip applySplit stage) xs

-- | Replay the loop type of a single 'Dim' on a 'Stage'.
--
-- The dimension's variable name is unqualified with 'makeUnqualified' and
-- the scheduling directive matching its 'ForType' is applied. Serial loops
-- need no directive; extern loops are not supported and raise an error.
-- The updated stage returned by the directive is discarded.
applyDim :: KnownNat n => Dim -> Stage n a -> IO ()
applyDim x stage = do
  loopVar <- mkVar (makeUnqualified x.var)
  _ <- case x.forType of
    ForSerial -> pure stage
    ForParallel -> parallel loopVar stage
    ForVectorized -> vectorize loopVar stage
    ForUnrolled -> unroll loopVar stage
    ForExtern -> error "extern ForType is not yet supported by applyDim"
    ForGPUBlock -> gpuBlocks x.deviceApi loopVar stage
    ForGPUThread -> gpuThreads x.deviceApi loopVar stage
    ForGPULane -> gpuLanes x.deviceApi loopVar stage
  pure ()

-- | Replay a list of 'Dim's on a 'Stage'.
--
-- First every dimension's loop type is applied with 'applyDim', then the
-- stage is reordered according to the order of the dimensions in the list.
applyDims :: KnownNat n => [Dim] -> Stage n a -> IO ()
applyDims xs stage = do
  mapM_ (flip applyDim stage) xs
  -- Rebuild the variables (unqualified) to express the loop order.
  vars <- traverse (\dim -> mkVar (makeUnqualified dim.var)) xs
  _ <- reorder vars stage
  pure ()

-- | Replay a 'StageSchedule' on a 'Stage': its splits are applied first,
-- followed by its dims.
applySchedule :: KnownNat n => StageSchedule -> Stage n a -> IO ()
applySchedule schedule stage =
  applySplits schedule.splits stage
    >> applyDims schedule.dims stage

-- data SplitContents = SplitContents
--   { old :: !Text
--   , outer :: !Text
--   , inner :: !Text
--   , factor :: !(Maybe Int)
--   , exact :: !Bool
--   , tail :: !TailStrategy
--   }
--   deriving stock (Show, Eq)
--
--
-- applySplits :: [Split] -> Stage n a -> IO ()
-- applySplits splits stage =

-- v <-
--   [CU.exp| Halide::Internal::Dim* {
--        $(Halide::Internal::StageSchedule* schedule)->dims().data() } |]
-- putStrLn $ "n = " <> show n
-- mapM (\i -> print i >> peekElemOff v i) [0 .. n - 1]