{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      : Language.Halide.Trace
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Trace
  ( TraceEvent (..)
  , TraceEventCode (..)
  , TraceLoadStoreContents (..)
  , setCustomTrace
  , traceStores
  , traceLoads
  , collectIterationOrder
  )
where

import Control.Concurrent.MVar
import Control.Exception (SomeException, bracket, bracket_, throwIO, try)
import Data.ByteString (packCString)
import Data.Int (Int32)
import Data.Text (Text)
import Data.Text.Encoding (decodeUtf8)
import Foreign.Marshal (peekArray)
import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr)
import Foreign.Storable
import GHC.TypeLits
import Language.C.Inline qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Dimension
import Language.Halide.Func
import Language.Halide.LoopLevel
import Language.Halide.Type
import Prelude hiding (min, tail)

-- | Haskell counterpart of [@halide_trace_event_code_t@](https://halide-lang.org/docs/_halide_runtime_8h.html#a485130f12eb8bb5fa5a9478eeb6b0dfa).
--
-- Each constructor mirrors one C enumerator; the numeric values themselves
-- are obtained through the 'Enum' instance. The derived 'Ord' instance
-- follows declaration order.
data TraceEventCode
  = -- | Mirrors @halide_trace_load@.
    TraceLoad
  | -- | Mirrors @halide_trace_store@.
    TraceStore
  | -- | Mirrors @halide_trace_begin_realization@.
    TraceBeginRealization
  | -- | Mirrors @halide_trace_end_realization@.
    TraceEndRealization
  | -- | Mirrors @halide_trace_produce@.
    TraceProduce
  | -- | Mirrors @halide_trace_end_produce@.
    TraceEndProduce
  | -- | Mirrors @halide_trace_consume@.
    TraceConsume
  | -- | Mirrors @halide_trace_end_consume@.
    TraceEndConsume
  | -- | Mirrors @halide_trace_begin_pipeline@.
    TraceBeginPipeline
  | -- | Mirrors @halide_trace_end_pipeline@.
    TraceEndPipeline
  | -- | Mirrors @halide_trace_tag@.
    TraceTag
  deriving stock (Eq, Ord, Show)

-- | Payload attached to 'TraceLoad' and 'TraceStore' events.
data TraceLoadStoreContents = TraceLoadStoreContents
  { valuePtr :: !(Ptr ())
  -- ^ The event's @value@ pointer, copied verbatim from the C struct.
  -- NOTE(review): presumably only valid while the trace callback runs —
  -- confirm before dereferencing it afterwards.
  , valueType :: !HalideType
  -- ^ The event's @type@ field: the Halide type of the value behind
  -- 'valuePtr'.
  , coordinates :: ![Int]
  -- ^ The event's @coordinates@ array (its length is the event's
  -- @dimensions@ field).
  }
  deriving stock (Show)

-- | A decoded Halide trace event, as handed to a custom trace handler.
data TraceEvent = TraceEvent
  { funcName :: !Text
  -- ^ The event's @func@ string, decoded as UTF-8.
  , eventCode :: !TraceEventCode
  -- ^ What kind of event this is.
  , loadStoreContents :: !(Maybe TraceLoadStoreContents)
  -- ^ 'Just' for 'TraceLoad' and 'TraceStore' events, 'Nothing' otherwise.
  }
  deriving stock (Show)

-- Template Haskell splice imported from "Language.Halide.Context".
-- NOTE(review): presumably this sets up the inline-c context (Halide headers
-- and C++ <-> Haskell type mappings such as halide_trace_event_t ->
-- TraceEvent) that the CU/C quasi-quotes below rely on — confirm.
importHalide

-- | Conversion between 'TraceEventCode' and the integer values of Halide's
-- @halide_trace_event_code_t@ enumerators.
--
-- 'fromEnum' returns the value of the matching C enumerator. 'toEnum' is its
-- inverse and calls 'error' for any integer that matches no constructor.
instance Enum TraceEventCode where
  fromEnum :: TraceEventCode -> Int
  fromEnum =
    fromIntegral . \case
      TraceLoad -> [CU.pure| int { halide_trace_load } |]
      TraceStore -> [CU.pure| int { halide_trace_store } |]
      TraceBeginRealization -> [CU.pure| int { halide_trace_begin_realization } |]
      TraceEndRealization -> [CU.pure| int { halide_trace_end_realization } |]
      TraceProduce -> [CU.pure| int { halide_trace_produce } |]
      TraceEndProduce -> [CU.pure| int { halide_trace_end_produce } |]
      TraceConsume -> [CU.pure| int { halide_trace_consume } |]
      TraceEndConsume -> [CU.pure| int { halide_trace_end_consume } |]
      TraceBeginPipeline -> [CU.pure| int { halide_trace_begin_pipeline } |]
      TraceEndPipeline -> [CU.pure| int { halide_trace_end_pipeline } |]
      TraceTag -> [CU.pure| int { halide_trace_tag } |]

  toEnum :: Int -> TraceEventCode
  toEnum k =
    -- Match in 'Int' space via 'fromEnum'. Narrowing @k@ to a C @int@ first
    -- (with 'fromIntegral') wraps out-of-range inputs when 'Int' is wider
    -- than @int@. For example, @2^32@ would then be accepted as the code
    -- whose value is 0 instead of being rejected.
    case lookup k [(fromEnum code, code) | code <- allCodes] of
      Just code -> code
      Nothing -> error $ "invalid TraceEventCode: " <> show k
    where
      -- Every constructor, in declaration order.
      allCodes =
        [ TraceLoad
        , TraceStore
        , TraceBeginRealization
        , TraceEndRealization
        , TraceProduce
        , TraceEndProduce
        , TraceConsume
        , TraceEndConsume
        , TraceBeginPipeline
        , TraceEndPipeline
        , TraceTag
        ]

-- | Read the load/store payload of a @halide_trace_event_t@: its @value@
-- pointer, its @type@, and its @coordinates@ array, whose length is the
-- event's @dimensions@ field.
--
-- NOTE(review): the @value@ pointer is copied, not the data it points to;
-- it is presumably only valid during the trace callback — confirm.
peekTraceLoadStoreContents :: Ptr TraceEvent -> IO TraceLoadStoreContents
peekTraceLoadStoreContents p = do
  rawValue <- [CU.exp| void* { $(const halide_trace_event_t* p)->value } |]
  typeField <- [CU.exp| const halide_type_t* { &$(const halide_trace_event_t* p)->type } |]
  elemType <- peek typeField
  dims <- [CU.exp| int { $(const halide_trace_event_t* p)->dimensions } |]
  coordArray <- [CU.exp| const int32_t* { $(const halide_trace_event_t* p)->coordinates } |]
  rawCoords <- peekArray (fromIntegral dims) coordArray
  let coords = map fromIntegral rawCoords
  pure (TraceLoadStoreContents rawValue elemType coords)

-- | Decode a @halide_trace_event_t@ into a 'TraceEvent'.
--
-- The @func@ string is decoded with 'decodeUtf8', which throws on invalid
-- UTF-8. The @event@ field is converted with 'toEnum', which calls 'error' on
-- unknown codes. The load/store payload is read only for 'TraceLoad' and
-- 'TraceStore' events.
peekTraceEvent :: Ptr TraceEvent -> IO TraceEvent
peekTraceEvent p = do
  nameCStr <- [CU.exp| const char* { $(const halide_trace_event_t* p)->func } |]
  name <- decodeUtf8 <$> packCString nameCStr
  rawCode <- [CU.exp| int { $(const halide_trace_event_t* p)->event } |]
  let code = toEnum (fromIntegral rawCode)
  payload <-
    if code `elem` [TraceLoad, TraceStore]
      then Just <$> peekTraceLoadStoreContents p
      else pure Nothing
  pure (TraceEvent name code payload)

-- | Wrap a Haskell trace handler into a C function pointer that can be
-- installed as Halide's @custom_trace@ JIT handler, and pass it to @use@.
--
-- The function pointer is released with 'freeHaskellFunPtr' when @use@
-- returns or throws, so it must not be retained past that point.
--
-- Each event is decoded with 'peekTraceEvent' and handed to @customTrace@.
-- The callback always returns @0@ to the caller.
--
-- Any exception raised while decoding or handling an event is caught inside
-- the callback. A Haskell exception escaping into the calling C code does
-- not reach the Haskell caller: GHC's runtime terminates the process
-- instead. The first such exception is remembered and rethrown after @use@
-- has returned normally.
withTrace
  :: (TraceEvent -> IO ())
  -> (FunPtr (Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32) -> IO a)
  -> IO a
withTrace customTrace use = do
  firstFailure <- newMVar (Nothing :: Maybe SomeException)
  let callback :: Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32
      callback _ p = do
        outcome <- try @SomeException (peekTraceEvent p >>= customTrace)
        case outcome of
          Right () -> pure ()
          -- Keep only the first failure that was observed.
          Left e -> modifyMVar_ firstFailure (pure . Just . maybe e id)
        pure 0
  result <- bracket (allocate callback) destroy use
  readMVar firstFailure >>= maybe (pure result) throwIO
  where
    allocate
      :: (Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32)
      -> IO (FunPtr (Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32))
    allocate = $(C.mkFunPtr [t|Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32|])
    destroy :: FunPtr b -> IO ()
    destroy = freeHaskellFunPtr

-- | Run @action@ with @customTrace@ installed as the JIT @custom_trace@
-- handler of @f@.
--
-- The handler is installed right before @action@ starts. Afterwards, or
-- when @action@ throws, the handler slot is reset to @nullptr@, which also
-- clears any handler that was present before the call.
setCustomTrace
  :: KnownNat n
  => (TraceEvent -> IO ())
  -- ^ Custom trace function
  -> Func t n a
  -- ^ For which func to enable it
  -> IO b
  -- ^ For the duration of which computation to enable it
  -> IO b
setCustomTrace customTrace f action =
  withTrace customTrace (\tracePtr -> bracket_ (install tracePtr) clear action)
  where
    -- Point the Func's JIT trace handler at the given C function pointer.
    install :: FunPtr (Ptr CxxUserContext -> Ptr TraceEvent -> IO Int32) -> IO ()
    install tracePtr =
      withFunc f $ \f' ->
        [CU.block| void {
          auto& func = *$(Halide::Func* f');
          func.jit_handlers().custom_trace = $(int32_t (*tracePtr)(Halide::JITUserContext*, const halide_trace_event_t*));
        } |]
    -- Reset the Func's JIT trace handler to nullptr.
    clear :: IO ()
    clear =
      withFunc f $ \f' ->
        [CU.block| void {
          auto& func = *$(Halide::Func* f');
          func.jit_handlers().custom_trace = nullptr;
        } |]

-- | Call Halide's @Func::trace_stores@ on @f@ (in place) and return the same
-- @f@, so that stores to it are reported to the trace handler (see
-- 'setCustomTrace').
traceStores :: KnownNat n => Func t n a -> IO (Func t n a)
traceStores f =
  f <$ withFunc f (\f' -> [CU.exp| void { $(Halide::Func* f')->trace_stores() } |])

-- | Call Halide's @Func::trace_loads@ on @f@ (in place) and return the same
-- @f@, so that loads from it are reported to the trace handler (see
-- 'setCustomTrace').
traceLoads :: KnownNat n => Func t n a -> IO (Func t n a)
traceLoads f =
  withFunc f (\f' -> [CU.exp| void { $(Halide::Func* f')->trace_loads() } |])
    >> pure f

-- | Run @action@ with a temporary trace handler on @f@ and return, in the
-- order they were reported, the coordinates of every event that carries a
-- load/store payload and whose code satisfies @cond@, paired with the result
-- of @action@.
--
-- Events without a payload are ignored. The accumulator is an 'MVar', so
-- concurrent callbacks are serialised.
-- NOTE(review): 'traceStores' / 'traceLoads' presumably need to be enabled on
-- @f@ beforehand for any such events to be emitted — confirm.
collectIterationOrder
  :: KnownNat n
  => (TraceEventCode -> Bool)
  -> Func t n a
  -> IO b
  -> IO ([[Int]], b)
collectIterationOrder cond f action = do
  -- Newest first; reversed once at the end.
  collected <- newMVar []
  let record :: TraceEvent -> IO ()
      record = \case
        TraceEvent _ code (Just (TraceLoadStoreContents _ _ coords))
          | cond code -> modifyMVar_ collected (pure . (coords :))
        _ -> pure ()
  setCustomTrace record f $
    action >>= \result -> (,result) . reverse <$> readMVar collected