{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
module Language.Halide.LoopLevel
( LoopLevel (..)
, LoopLevelTy (..)
, SomeLoopLevel (..)
, LoopAlignStrategy (..)
, CxxLoopLevel
, withCxxLoopLevel
, wrapCxxLoopLevel
)
where
import Control.Exception (bracket)
import Data.Text (Text)
import Foreign.ForeignPtr
import Foreign.Marshal (toBool)
import Foreign.Ptr (Ptr)
import GHC.Records (HasField (..))
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Cpp.Exception as C
import qualified Language.C.Inline.Unsafe as CU
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (min, tail)
-- | Opaque Haskell-side tag for the C++ type @Halide::LoopLevel@. Values of
-- this type are never constructed; it only appears behind a 'Ptr' or a
-- 'ForeignPtr' in this module.
data CxxLoopLevel

-- Template Haskell splice that sets up the inline-C++ context used by the
-- @C.*@ / @CU.*@ quasi-quotes below. NOTE(review): presumably it is also what
-- maps @Halide::LoopLevel@ to 'CxxLoopLevel', which would explain why it must
-- come after the declaration above — confirm before reordering.
importHalide
-- | Tags describing which kind of loop level a 'LoopLevel' is. They are used
-- (promoted) as the type index of 'LoopLevel'.
data LoopLevelTy
  = InlinedTy
  | RootTy
  | LockedTy
-- | A loop level, indexed by its kind.
--
-- The inlined and root levels are plain constructors carrying no data. Any
-- other level is held as a pointer to a C++ @Halide::LoopLevel@.
data LoopLevel (t :: LoopLevelTy) where
  -- | The inlined loop level.
  InlinedLoopLevel :: LoopLevel 'InlinedTy
  -- | The root loop level.
  RootLoopLevel :: LoopLevel 'RootTy
  -- | A level backed by a C++ object. 'wrapCxxLoopLevel' locks the object
  -- before producing this constructor, and only uses it for levels that are
  -- neither inlined nor root.
  LoopLevel :: !(ForeignPtr CxxLoopLevel) -> LoopLevel 'LockedTy
-- | A 'LoopLevel' with its index hidden. This is what you get back when the
-- kind of a level is only known at runtime (see 'wrapCxxLoopLevel').
data SomeLoopLevel where
  SomeLoopLevel :: LoopLevel t -> SomeLoopLevel

-- Every @LoopLevel t@ has a 'Show' instance, so the existential can derive one.
deriving stock instance Show SomeLoopLevel
-- | Two wrapped levels are equal when both are the inlined level, or both
-- are the root level, or both are C++-backed levels that compare equal
-- (compared with the @Eq (LoopLevel t)@ instance). Levels built from
-- different constructors are never equal.
instance Eq SomeLoopLevel where
  (==) :: SomeLoopLevel -> SomeLoopLevel -> Bool
  SomeLoopLevel lhs == SomeLoopLevel rhs =
    case (lhs, rhs) of
      (InlinedLoopLevel, InlinedLoopLevel) -> True
      (RootLoopLevel, RootLoopLevel) -> True
      -- Matching both sides on 'LoopLevel' refines both indices to
      -- 'LockedTy, so the two values can be compared directly.
      (l@(LoopLevel _), r@(LoopLevel _)) -> l == r
      _ -> False
-- | Equality is decided by C++ @operator==@ on the underlying
-- @Halide::LoopLevel@ objects. Inlined and root levels get temporary C++
-- objects (see 'withCxxLoopLevel').
instance Eq (LoopLevel t) where
  (==) :: LoopLevel t -> LoopLevel t -> Bool
  lhs == rhs =
    -- The C++ side receives both objects as const pointers and returns a bool,
    -- so the comparison is treated as a pure function here.
    unsafePerformIO $ do
      answer <- withCxxLoopLevel lhs $ \pa ->
        withCxxLoopLevel rhs $ \pb ->
          [CU.exp| bool { *$(const Halide::LoopLevel* pa) == *$(const Halide::LoopLevel* pb) } |]
      pure (toBool answer)
-- | Inlined and root levels are shown by their constructor names. A
-- C++-backed level is shown record-style, using its @func@ and @var@ fields.
instance Show (LoopLevel t) where
  showsPrec :: Int -> LoopLevel t -> ShowS
  showsPrec _ InlinedLoopLevel = showString "InlinedLoopLevel"
  showsPrec _ RootLoopLevel = showString "RootLoopLevel"
  showsPrec prec lvl@(LoopLevel _) =
    let funcName = lvl.func :: Text
        loopVar = lvl.var :: Expr Int32
     in showParen (prec > 10) $
          showString "LoopLevel {func = "
            . shows funcName
            . showString ", var = "
            . shows loopVar
            . showString "}"
-- | Haskell mirror of the C++ enum @Halide::LoopAlignStrategy@. The
-- conversion to and from the C++ integer values is done by the 'Enum'
-- instance below.
data LoopAlignStrategy
  = -- | Corresponds to @Halide::LoopAlignStrategy::AlignStart@.
    LoopAlignStart
  | -- | Corresponds to @Halide::LoopAlignStrategy::AlignEnd@.
    LoopAlignEnd
  | -- | Corresponds to @Halide::LoopAlignStrategy::NoAlign@.
    LoopNoAlign
  | -- | Corresponds to @Halide::LoopAlignStrategy::Auto@.
    LoopAlignAuto
  deriving stock (Show, Eq, Ord)
-- | Conversion to and from the integer values of the C++ enum
-- @Halide::LoopAlignStrategy@. The integers are queried from Halide itself,
-- so they follow the linked library rather than the order of the Haskell
-- constructors.
instance Enum LoopAlignStrategy where
  fromEnum :: LoopAlignStrategy -> Int
  fromEnum =
    fromIntegral . \case
      LoopAlignStart -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignStart) } |]
      LoopAlignEnd -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignEnd) } |]
      LoopNoAlign -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::NoAlign) } |]
      LoopAlignAuto -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::Auto) } |]

  -- Inverse of 'fromEnum'; calls 'error' for integers that match no strategy.
  -- The C int is widened to Int (lossless) rather than narrowing k to a C
  -- int: narrowing would wrap large inputs such as 2^32 onto a valid
  -- enumerator instead of reporting them as invalid.
  toEnum :: Int -> LoopAlignStrategy
  toEnum k
    | k == fromIntegral [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignStart) } |] = LoopAlignStart
    | k == fromIntegral [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignEnd) } |] = LoopAlignEnd
    | k == fromIntegral [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::NoAlign) } |] = LoopNoAlign
    | k == fromIntegral [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::Auto) } |] = LoopAlignAuto
    | otherwise = error $ "invalid LoopAlignStrategy: " <> show k

  -- The class defaults for 'enumFrom' / 'enumFromThen' enumerate Ints with
  -- no upper bound, so e.g. @[LoopNoAlign ..]@ would run past the last
  -- strategy and hit the 'error' in 'toEnum'. These definitions stop at the
  -- first/last strategy, mirroring GHC's boundedEnumFrom/boundedEnumFromThen.
  -- NOTE(review): assumes Halide numbers its enumerators contiguously, with
  -- AlignStart the smallest and Auto the largest (Halide's declaration
  -- order) — confirm if Halide adds new strategies.
  enumFrom :: LoopAlignStrategy -> [LoopAlignStrategy]
  enumFrom x = enumFromTo x LoopAlignAuto

  enumFromThen :: LoopAlignStrategy -> LoopAlignStrategy -> [LoopAlignStrategy]
  enumFromThen x y = enumFromThenTo x y bound
    where
      bound
        | fromEnum y >= fromEnum x = LoopAlignAuto
        | otherwise = LoopAlignStart
-- | True exactly for the 'InlinedLoopLevel' constructor.
isInlined :: LoopLevel t -> Bool
isInlined = \case
  InlinedLoopLevel -> True
  _ -> False
-- | True exactly for the 'RootLoopLevel' constructor.
isRoot :: LoopLevel t -> Bool
isRoot = \case
  RootLoopLevel -> True
  _ -> False
-- | @level.func@: the string returned by C++ @LoopLevel::func()@, copied
-- into a 'Text'. The temporary @std::string@ is freed by
-- 'peekAndDeleteCxxString'.
instance HasField "func" (LoopLevel 'LockedTy) Text where
  getField :: LoopLevel 'LockedTy -> Text
  getField lvl = unsafePerformIO . withCxxLoopLevel lvl $ \ptr -> do
    str <-
      [CU.exp| std::string* {
        new std::string{$(const Halide::LoopLevel* ptr)->func()} } |]
    peekAndDeleteCxxString str
-- | @level.var@: the loop variable returned by C++ @LoopLevel::var()@. It is
-- copied into a fresh @Halide::VarOrRVar@ and handed to 'wrapCxxVarOrRVar'.
instance HasField "var" (LoopLevel 'LockedTy) (Expr Int32) where
  getField :: LoopLevel 'LockedTy -> Expr Int32
  getField lvl = unsafePerformIO . withCxxLoopLevel lvl $ \ptr -> do
    varPtr <-
      [CU.exp| Halide::VarOrRVar* {
        new Halide::VarOrRVar{$(const Halide::LoopLevel* ptr)->var()} } |]
    wrapCxxVarOrRVar varPtr
-- | Take ownership of a heap-allocated @Halide::LoopLevel@ and turn it into
-- a 'SomeLoopLevel'.
--
-- The level is locked first (C++ @lock()@). If it is then inlined or root,
-- the C++ object is deleted immediately and the matching static constructor
-- is returned. Otherwise the pointer is kept in a 'ForeignPtr' whose
-- finalizer deletes it.
--
-- The pointer is released with C++ @delete@, so it must come from @new@ and
-- must not be used by the caller afterwards. Halide errors raised by the C++
-- calls are rethrown in Haskell (via 'C.throwBlock'). Even in that case the
-- object is still freed, when its 'ForeignPtr' is garbage collected.
wrapCxxLoopLevel :: Ptr CxxLoopLevel -> IO SomeLoopLevel
wrapCxxLoopLevel p = do
  -- Hand the pointer to a ForeignPtr *before* calling into Halide. lock(),
  -- is_inlined() and is_root() can throw, and the original code leaked p
  -- when they did.
  fp <- newForeignPtr deleter p
  (inlined, root) <- withForeignPtr fp $ \ptr -> do
    [C.throwBlock| void { handle_halide_exceptions([=]() { $(Halide::LoopLevel* ptr)->lock(); }); } |]
    isInlinedLevel <-
      toBool
        <$> [C.throwBlock| bool {
              return handle_halide_exceptions([=](){
                return $(const Halide::LoopLevel* ptr)->is_inlined(); });
            } |]
    isRootLevel <-
      toBool
        <$> [C.throwBlock| bool {
              return handle_halide_exceptions([=](){
                return $(const Halide::LoopLevel* ptr)->is_root(); });
            } |]
    pure (isInlinedLevel, isRootLevel)
  -- Inlined and root levels are represented by data-less constructors, so
  -- the C++ object is no longer needed. finalizeForeignPtr runs the deleter
  -- right away instead of waiting for the GC.
  if inlined
    then finalizeForeignPtr fp >> pure (SomeLoopLevel InlinedLoopLevel)
    else
      if root
        then finalizeForeignPtr fp >> pure (SomeLoopLevel RootLoopLevel)
        else pure (SomeLoopLevel (LoopLevel fp))
  where
    deleter :: FinalizerPtr CxxLoopLevel
    deleter = [C.funPtr| void deleteLoopLevel(Halide::LoopLevel* p) { delete p; } |]
-- | Run an action with a @Halide::LoopLevel*@ for the given level.
--
-- For a C++-backed level the existing object is lent out, and it is kept
-- alive for the duration of the action. For the inlined and root levels, a
-- temporary C++ object is created with @new@ and deleted afterwards (also
-- when the action throws). The pointer must not escape the action in either
-- case.
withCxxLoopLevel :: LoopLevel t -> (Ptr CxxLoopLevel -> IO a) -> IO a
withCxxLoopLevel (LoopLevel fp) action = withForeignPtr fp action
withCxxLoopLevel level action = bracket acquire release action
  where
    acquire
      | isInlined level = [CU.exp| Halide::LoopLevel* { new Halide::LoopLevel{Halide::LoopLevel::inlined()} } |]
      | isRoot level = [CU.exp| Halide::LoopLevel* { new Halide::LoopLevel{Halide::LoopLevel::root()} } |]
      -- The first equation already handled the only other constructor.
      | otherwise = error "this should never happen"
    release ptr = [CU.exp| void { delete $(Halide::LoopLevel *ptr) } |]