{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Language.Halide.Trace
( TraceEvent (..)
, TraceEventCode (..)
, TraceLoadStoreContents (..)
, setCustomTrace
, traceStores
, traceLoads
, collectIterationOrder
)
where
import Control.Concurrent.MVar
import Control.Exception (SomeException, bracket, bracket_, throwIO, try)
import Data.ByteString (packCString)
import Data.Int (Int32)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
import Foreign.Marshal (peekArray)
import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr)
import Foreign.Storable
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Dimension
import Language.Halide.Func
import Language.Halide.LoopLevel
import Language.Halide.Type
import Prelude hiding (min, tail)
-- | The kind of event delivered to a custom trace handler.
--
-- Every constructor corresponds to one enumerator of Halide's C
-- @halide_trace_event_code_t@; the numeric mapping lives in the 'Enum'
-- instance for this type.
data TraceEventCode
  = -- | @halide_trace_load@
    TraceLoad
  | -- | @halide_trace_store@
    TraceStore
  | -- | @halide_trace_begin_realization@
    TraceBeginRealization
  | -- | @halide_trace_end_realization@
    TraceEndRealization
  | -- | @halide_trace_produce@
    TraceProduce
  | -- | @halide_trace_end_produce@
    TraceEndProduce
  | -- | @halide_trace_consume@
    TraceConsume
  | -- | @halide_trace_end_consume@
    TraceEndConsume
  | -- | @halide_trace_begin_pipeline@
    TraceBeginPipeline
  | -- | @halide_trace_end_pipeline@
    TraceEndPipeline
  | -- | @halide_trace_tag@
    TraceTag
  deriving stock (Show, Eq, Ord)
-- | Payload carried by 'TraceLoad' and 'TraceStore' events.
data TraceLoadStoreContents = TraceLoadStoreContents
  { valuePtr :: !(Ptr ())
  -- ^ The event's @value@ pointer, stored as-is: the value itself is not
  -- copied. NOTE(review): this points at Halide-owned memory and is
  -- presumably only valid while the trace callback runs; confirm before
  -- dereferencing it later.
  , valueType :: !HalideType
  -- ^ The element type of the loaded/stored value (the event's @type@).
  , coordinates :: ![Int]
  -- ^ A copy of the event's @coordinates@ array (of length @dimensions@).
  }
  deriving stock (Show)
-- | A decoded Halide trace event, as passed to the handler given to
-- 'setCustomTrace'.
data TraceEvent = TraceEvent
  { funcName :: !Text
  -- ^ The event's @func@ string, decoded as UTF-8.
  , eventCode :: !TraceEventCode
  -- ^ What happened (the event's @event@ field).
  , loadStoreContents :: !(Maybe TraceLoadStoreContents)
  -- ^ 'Just' the payload for 'TraceLoad' / 'TraceStore' events;
  -- 'Nothing' for every other event code.
  }
  deriving stock (Show)
-- Template Haskell splice from "Language.Halide.Context". It presumably sets
-- up the inline-C context (Halide headers and type mappings) that the
-- CU/C quasi-quotes below rely on; confirm against its definition.
importHalide
-- | Numeric conversions between 'TraceEventCode' and the values of Halide's
-- C @halide_trace_event_code_t@ enumerators.
--
-- 'fromEnum' is the single source of truth for the mapping. 'toEnum'
-- inverts it by checking every constructor, so the two directions cannot
-- drift apart.
instance Enum TraceEventCode where
  fromEnum :: TraceEventCode -> Int
  fromEnum =
    fromIntegral . \case
      TraceLoad -> [CU.pure| int { halide_trace_load } |]
      TraceStore -> [CU.pure| int { halide_trace_store } |]
      TraceBeginRealization -> [CU.pure| int { halide_trace_begin_realization } |]
      TraceEndRealization -> [CU.pure| int { halide_trace_end_realization } |]
      TraceProduce -> [CU.pure| int { halide_trace_produce } |]
      TraceEndProduce -> [CU.pure| int { halide_trace_end_produce } |]
      TraceConsume -> [CU.pure| int { halide_trace_consume } |]
      TraceEndConsume -> [CU.pure| int { halide_trace_end_consume } |]
      TraceBeginPipeline -> [CU.pure| int { halide_trace_begin_pipeline } |]
      TraceEndPipeline -> [CU.pure| int { halide_trace_end_pipeline } |]
      TraceTag -> [CU.pure| int { halide_trace_tag } |]

  -- Inverse of 'fromEnum'; calls 'error' for any value that is not one of
  -- the codes above.
  toEnum :: Int -> TraceEventCode
  toEnum k =
    -- Compare as 'Int' (the C value is widened by 'fromEnum') rather than
    -- narrowing @k@ to a C int. Narrowing let out-of-range inputs such as
    -- 2^32 wrap around and alias a valid code instead of being rejected.
    case filter ((== k) . fromEnum) allCodes of
      code : _ -> code
      [] -> error $ "invalid TraceEventCode: " <> show k
    where
      -- Listed explicitly: the default 'enumFrom' goes through 'toEnum',
      -- so @[TraceLoad ..]@ cannot be used here.
      allCodes =
        [ TraceLoad
        , TraceStore
        , TraceBeginRealization
        , TraceEndRealization
        , TraceProduce
        , TraceEndProduce
        , TraceConsume
        , TraceEndConsume
        , TraceBeginPipeline
        , TraceEndPipeline
        , TraceTag
        ]
-- | Decode the load/store payload of a raw @halide_trace_event_t@.
--
-- The coordinate array (@dimensions@ entries) is copied into a Haskell
-- list. The @value@ field is not copied; only its pointer is kept.
peekTraceLoadStoreContents :: Ptr TraceEvent -> IO TraceLoadStoreContents
peekTraceLoadStoreContents evPtr = do
  rawValue <- [CU.exp| void* { $(const halide_trace_event_t* evPtr)->value } |]
  elemType <-
    peek
      =<< [CU.exp| const halide_type_t* { &$(const halide_trace_event_t* evPtr)->type } |]
  count <- [CU.exp| int { $(const halide_trace_event_t* evPtr)->dimensions } |]
  coordsPtr <- [CU.exp| const int32_t* { $(const halide_trace_event_t* evPtr)->coordinates } |]
  rawCoords <- peekArray (fromIntegral count) coordsPtr
  pure (TraceLoadStoreContents rawValue elemType (map fromIntegral rawCoords))
-- | Decode a raw @halide_trace_event_t@ into a 'TraceEvent'.
--
-- The function name is copied and decoded as UTF-8, and the numeric event
-- code is converted with 'toEnum'. That conversion calls 'error' for
-- unknown codes. The load/store payload is read only for 'TraceLoad' and
-- 'TraceStore' events.
peekTraceEvent :: Ptr TraceEvent -> IO TraceEvent
peekTraceEvent evPtr = do
  namePtr <- [CU.exp| const char* { $(const halide_trace_event_t* evPtr)->func } |]
  name <- decodeUtf8 <$> packCString namePtr
  rawCode <- [CU.exp| int { $(const halide_trace_event_t* evPtr)->event } |]
  let code = toEnum (fromIntegral rawCode)
  payload <-
    if code `elem` [TraceLoad, TraceStore]
      then Just <$> peekTraceLoadStoreContents evPtr
      else pure Nothing
  pure (TraceEvent name code payload)
-- | Wrap a Haskell trace handler as a C function pointer with the signature
-- of Halide's @custom_trace@ JIT handler. Run the continuation with it, then
-- free the pointer, also when the continuation throws.
--
-- Each C-side invocation decodes the event with 'peekTraceEvent', passes it
-- to the handler and returns @0@ to Halide.
--
-- A Haskell exception cannot propagate through the C frames that invoke the
-- callback. If it escaped, GHC's RTS would report it as uncaught and
-- terminate the process. The callback therefore catches exceptions and
-- remembers the first one. Later events are still delivered. Once the
-- continuation has returned, the remembered exception is rethrown here.
withTrace
  :: (TraceEvent -> IO ())
  -> (FunPtr (Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32) -> IO a)
  -> IO a
withTrace customTrace k = do
  firstFailure <- newEmptyMVar
  let callback :: Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32
      callback _ p = do
        outcome <- try (peekTraceEvent p >>= customTrace)
        case outcome of
          Right () -> pure ()
          -- Keep only the first failure; tryPutMVar is a no-op afterwards.
          Left (e :: SomeException) -> () <$ tryPutMVar firstFailure e
        pure 0
  result <- bracket (allocate callback) freeHaskellFunPtr k
  tryReadMVar firstFailure >>= maybe (pure result) throwIO
  where
    allocate
      :: (Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32)
      -> IO (FunPtr (Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32))
    allocate = $(C.mkFunPtr [t|Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32|])
-- | Install @customTrace@ as the JIT @custom_trace@ handler of @f@ while
-- @action@ runs.
--
-- 'withTrace' turns the handler into a C function pointer. The pointer is
-- stored in @f@'s @jit_handlers()@ before @action@ and reset to @nullptr@
-- afterwards; 'bracket_' ensures the reset also happens when @action@
-- throws.
setCustomTrace
  :: KnownNat n
  => (TraceEvent -> IO ())
  -> Func t n a
  -> IO b
  -> IO b
setCustomTrace customTrace f action =
  withTrace customTrace $ \handler -> do
    let install =
          withFunc f $ \cxxFunc ->
            [CU.block| void {
              auto& func = *$(Halide::Func* cxxFunc);
              func.jit_handlers().custom_trace = $(int32_t (*handler)(Halide::JITUserContext*, const halide_trace_event_t*));
            } |]
        uninstall =
          withFunc f $ \cxxFunc ->
            [CU.block| void {
              auto& func = *$(Halide::Func* cxxFunc);
              func.jit_handlers().custom_trace = nullptr;
            } |]
    bracket_ install uninstall action
-- | Call Halide's @Func::trace_stores()@ on @f@, then hand back the same
-- 'Func' so the call can be chained.
traceStores :: KnownNat n => Func t n a -> IO (Func t n a)
traceStores f =
  f <$ withFunc f (\cxxFunc -> [CU.exp| void { $(Halide::Func* cxxFunc)->trace_stores() } |])
-- | Call Halide's @Func::trace_loads()@ on @f@, then hand back the same
-- 'Func' so the call can be chained.
traceLoads :: KnownNat n => Func t n a -> IO (Func t n a)
traceLoads f =
  f <$ withFunc f (\cxxFunc -> [CU.exp| void { $(Halide::Func* cxxFunc)->trace_loads() } |])
-- | Run @action@ with a trace handler installed on @f@. The handler records
-- the coordinates of every event that carries load/store contents and
-- whose 'TraceEventCode' satisfies @cond@.
--
-- Returns the recorded coordinates in the order the events were reported,
-- together with @action@'s result. Events without load/store contents are
-- ignored. They are presumably only emitted once tracing was enabled, e.g.
-- via 'traceStores' or 'traceLoads'.
collectIterationOrder
  :: KnownNat n
  => (TraceEventCode -> Bool)
  -> Func t n a
  -> IO b
  -> IO ([[Int]], b)
collectIterationOrder cond f action = do
  visited <- newMVar []
  let record :: TraceEvent -> IO ()
      record event = case event.loadStoreContents of
        Just contents
          | cond event.eventCode ->
              -- Prepend here (O(1)) and reverse once at the end.
              modifyMVar_ visited (pure . (contents.coordinates :))
        _ -> pure ()
  setCustomTrace record f $ do
    result <- action
    collected <- readMVar visited
    pure (reverse collected, result)