{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

-- |
-- Module      : Language.Halide.LoopLevel
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.LoopLevel
  ( LoopLevel (..)
  , LoopLevelTy (..)
  , SomeLoopLevel (..)
  , LoopAlignStrategy (..)

    -- * Internal
  , CxxLoopLevel
  , withCxxLoopLevel
  , wrapCxxLoopLevel
  )
where

import Control.Exception (bracket, onException)
import Data.Text (Text)
import Foreign.ForeignPtr
import Foreign.Marshal (toBool)
import Foreign.Ptr (Ptr)
import GHC.Records (HasField (..))
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Cpp.Exception as C
import qualified Language.C.Inline.Unsafe as CU
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Type
import Language.Halide.Utils
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (min, tail)

-- | Haskell counterpart of @Halide::LoopLevel@.
--
-- An empty data type: it has no Haskell values and is only used as the
-- pointee type of 'Ptr' 'CxxLoopLevel' and 'ForeignPtr' 'CxxLoopLevel'.
data CxxLoopLevel

-- NOTE(review): 'importHalide' comes from "Language.Halide.Context";
-- presumably it sets up the inline-c C++ context (headers and the
-- @Halide::LoopLevel@ <-> 'CxxLoopLevel' type mapping used by the
-- quasi-quotes in this module) — confirm there. It must stay after the
-- declaration of 'CxxLoopLevel'.
importHalide

-- | Type-level tag that records which kind of 'LoopLevel' a value is.
--
-- 'InlinedTy' is the inlined level, 'RootTy' is the root level, and
-- 'LockedTy' is any other level, backed by a locked C++ @Halide::LoopLevel@.
data LoopLevelTy = InlinedTy | RootTy | LockedTy

-- | A reference to a site in a Halide statement at the top of the body of a particular for loop.
--
-- The index @t@ records which kind of level this is. The inlined and root
-- levels are plain constructors with no C++ object attached. Every other
-- level owns a C++ @Halide::LoopLevel@ (see 'wrapCxxLoopLevel').
data LoopLevel (t :: LoopLevelTy) where
  -- | The inlined level; corresponds to @Halide::LoopLevel::inlined()@.
  InlinedLoopLevel :: LoopLevel 'InlinedTy
  -- | The root level; corresponds to @Halide::LoopLevel::root()@.
  RootLoopLevel :: LoopLevel 'RootTy
  -- | A locked level whose C++ object is deleted by the 'ForeignPtr' finalizer.
  LoopLevel :: !(ForeignPtr CxxLoopLevel) -> LoopLevel 'LockedTy

-- | A 'LoopLevel' whose 'LoopLevelTy' index is not known statically, e.g. the
-- result of 'wrapCxxLoopLevel'. Pattern match on the wrapped constructor to
-- recover the index.
data SomeLoopLevel where
  SomeLoopLevel :: LoopLevel t -> SomeLoopLevel

-- Relies on the index-polymorphic @Show (LoopLevel t)@ instance in this module
-- to show the wrapped level.
deriving stock instance Show SomeLoopLevel

-- | Two wrapped levels are equal if they are the same special level (both
-- inlined or both root), or if both are locked levels that compare equal
-- under @Eq (LoopLevel 'LockedTy)@. Levels of different kinds are never equal.
instance Eq SomeLoopLevel where
  SomeLoopLevel lhs == SomeLoopLevel rhs =
    case (lhs, rhs) of
      (InlinedLoopLevel, InlinedLoopLevel) -> True
      (RootLoopLevel, RootLoopLevel) -> True
      (x@(LoopLevel _), y@(LoopLevel _)) -> x == y
      _ -> False

-- | Equality of loop levels of the same kind.
--
-- 'InlinedLoopLevel' and 'RootLoopLevel' are the only constructors of their
-- respective types. They are therefore equal to themselves, and no C++
-- objects need to be allocated to decide that. Locked levels are compared with
-- @Halide::LoopLevel::operator==@.
instance Eq (LoopLevel t) where
  InlinedLoopLevel == InlinedLoopLevel = True
  RootLoopLevel == RootLoopLevel = True
  level1 == level2 =
    -- unsafePerformIO is justified here: 'withCxxLoopLevel' keeps both C++
    -- objects alive for the whole call, and the comparison only reads them
    -- through const pointers, so the result depends only on the two inputs.
    toBool . unsafePerformIO $
      withCxxLoopLevel level1 $ \l1 ->
        withCxxLoopLevel level2 $ \l2 ->
          [CU.exp| bool { *$(const Halide::LoopLevel* l1) == *$(const Halide::LoopLevel* l2) } |]

-- | The special levels are shown by constructor name. A locked level is shown
-- as the record-like @LoopLevel {func = ..., var = ...}@ and is parenthesised
-- when the surrounding precedence exceeds 10.
instance Show (LoopLevel t) where
  showsPrec _ InlinedLoopLevel = showString "InlinedLoopLevel"
  showsPrec _ RootLoopLevel = showString "RootLoopLevel"
  showsPrec prec ll@(LoopLevel _) =
    showParen (prec > 10) . foldr (.) id $
      [ showString "LoopLevel {func = "
      , shows (ll.func :: Text)
      , showString ", var = "
      , shows (ll.var :: Expr Int32)
      , showString "}"
      ]

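-- Disabled alternative for the 'LoopLevel' case of 'showsPrec' above. It would
-- render the level with Halide's @LoopLevel::to_string()@ instead of showing
-- the func and var fields separately:
--
-- desc
--   where
--     desc = unpack . unsafePerformIO $
--       withCxxLoopLevel level $ \l ->
--         peekAndDeleteCxxString
--           =<< [C.throwBlock| std::string* {
--                 return handle_halide_exceptions([=](){
--                   return new std::string{$(const Halide::LoopLevel* l)->to_string()};
--                 });
--               } |]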

-- | How to handle loops of stages fused with 'computeWith' whose start or end
-- points are not aligned.
data LoopAlignStrategy
  = -- | Align the fused loops by shifting their starts.
    LoopAlignStart
  | -- | Align the fused loops by shifting their ends.
    LoopAlignEnd
  | -- | 'computeWith' makes no attempt to align the starts/ends of the fused loops.
    LoopNoAlign
  | -- | Automatic choice. By default, LoopAlignStrategy is set to 'LoopNoAlign'.
    LoopAlignAuto
  deriving stock (Eq, Ord, Show)

-- | Conversion to and from the integer values of the C++
-- @Halide::LoopAlignStrategy@ enumeration.
--
-- The codes are read from the C++ enumerators themselves (via
-- @static_cast<int>@) rather than hard-coded. 'toEnum' calls 'error' for a
-- code that matches no constructor.
--
-- 'enumFrom' and 'enumFromThen' are defined explicitly. The class defaults map
-- 'toEnum' over an unbounded list of 'Int's, so e.g. @[LoopAlignStart ..]@
-- would reach the 'error' in 'toEnum' right after the last constructor. Here
-- the enumeration stops at the largest code (or, when counting down, the
-- smallest). Like the default 'succ' and 'enumFromTo', this assumes the codes
-- are contiguous.
instance Enum LoopAlignStrategy where
  fromEnum :: LoopAlignStrategy -> Int
  fromEnum =
    fromIntegral . \case
      LoopAlignStart -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignStart) } |]
      LoopAlignEnd -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignEnd) } |]
      LoopNoAlign -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::NoAlign) } |]
      LoopAlignAuto -> [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::Auto) } |]

  toEnum :: Int -> LoopAlignStrategy
  toEnum k
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignStart) } |] = LoopAlignStart
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::AlignEnd) } |] = LoopAlignEnd
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::NoAlign) } |] = LoopNoAlign
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::LoopAlignStrategy::Auto) } |] = LoopAlignAuto
    | otherwise = error $ "invalid LoopAlignStrategy: " <> show k

  enumFrom :: LoopAlignStrategy -> [LoopAlignStrategy]
  enumFrom x = map toEnum [fromEnum x .. maximum loopAlignStrategyCodes]

  enumFromThen :: LoopAlignStrategy -> LoopAlignStrategy -> [LoopAlignStrategy]
  enumFromThen x y = map toEnum [fromEnum x, fromEnum y .. limit]
    where
      -- Counting up stops at the largest code, counting down at the smallest.
      limit
        | fromEnum y >= fromEnum x = maximum loopAlignStrategyCodes
        | otherwise = minimum loopAlignStrategyCodes

-- | The C++ integer codes of every 'LoopAlignStrategy' constructor. They bound
-- 'enumFrom' and 'enumFromThen'. The list is a non-empty literal, so
-- 'maximum'/'minimum' on it are total.
loopAlignStrategyCodes :: [Int]
loopAlignStrategyCodes =
  fromEnum <$> [LoopAlignStart, LoopAlignEnd, LoopNoAlign, LoopAlignAuto]

-- | Whether the given level is the inlined level ('InlinedLoopLevel').
isInlined :: LoopLevel t -> Bool
isInlined = \case
  InlinedLoopLevel -> True
  _ -> False

-- | Whether the given level is the root level ('RootLoopLevel').
isRoot :: LoopLevel t -> Bool
isRoot level =
  case level of
    RootLoopLevel -> True
    _ -> False

-- | The function name of a locked level, as returned by
-- @Halide::LoopLevel::func()@.
instance HasField "func" (LoopLevel 'LockedTy) Text where
  getField :: LoopLevel 'LockedTy -> Text
  getField level = unsafePerformIO . withCxxLoopLevel level $ \ptr -> do
    -- The name is copied into a heap-allocated std::string and handed to
    -- peekAndDeleteCxxString. NOTE(review): by its name that helper also frees
    -- the string — confirm in Language.Halide.Utils.
    cxxName <-
      [CU.exp| std::string* {
        new std::string{$(const Halide::LoopLevel* ptr)->func()} } |]
    peekAndDeleteCxxString cxxName

-- | The loop variable of a locked level, as returned by
-- @Halide::LoopLevel::var()@.
instance HasField "var" (LoopLevel 'LockedTy) (Expr Int32) where
  getField :: LoopLevel 'LockedTy -> Expr Int32
  getField level = unsafePerformIO . withCxxLoopLevel level $ \ptr -> do
    -- A heap-allocated copy of the VarOrRVar is handed to wrapCxxVarOrRVar.
    cxxVar <-
      [CU.exp| Halide::VarOrRVar* {
        new Halide::VarOrRVar{$(const Halide::LoopLevel* ptr)->var()} } |]
    wrapCxxVarOrRVar cxxVar

-- | Take ownership of a heap-allocated @Halide::LoopLevel@ and wrap it as a
-- 'SomeLoopLevel'.
--
-- The level is first locked and then classified:
--
-- * Inlined and root levels are represented by 'InlinedLoopLevel' and
--   'RootLoopLevel', so the C++ object is deleted immediately.
-- * Any other level becomes a 'LoopLevel' backed by a 'ForeignPtr' whose
--   finalizer deletes it.
--
-- Locking and classification run through @handle_halide_exceptions@ and may
-- throw. In that case the C++ object is deleted before the exception
-- propagates, so @p@ is not leaked.
wrapCxxLoopLevel :: Ptr CxxLoopLevel -> IO SomeLoopLevel
wrapCxxLoopLevel p = do
  -- We own p from here on. Until it is deleted or handed to a ForeignPtr,
  -- it must be freed if anything goes wrong.
  (inlined, root) <- lockAndClassify `onException` deleteLevel
  let special
        | inlined = Just (SomeLoopLevel InlinedLoopLevel)
        | root = Just (SomeLoopLevel RootLoopLevel)
        | otherwise = Nothing
  case special of
    Just level -> deleteLevel >> pure level
    Nothing -> SomeLoopLevel . LoopLevel <$> newForeignPtr deleter p
  where
    -- Lock the level, then ask whether it is the inlined or the root level.
    lockAndClassify = do
      [C.throwBlock| void { handle_halide_exceptions([=]() { $(Halide::LoopLevel* p)->lock(); }); } |]
      inlined <-
        toBool
          <$> [C.throwBlock| bool {
                return handle_halide_exceptions([=](){
                  return $(const Halide::LoopLevel* p)->is_inlined(); });
              } |]
      root <-
        toBool
          <$> [C.throwBlock| bool {
                return handle_halide_exceptions([=](){
                  return $(const Halide::LoopLevel* p)->is_root(); });
              } |]
      pure (inlined, root)
    deleteLevel = [CU.exp| void { delete $(Halide::LoopLevel *p) } |]
    deleter = [C.funPtr| void deleteLoopLevel(Halide::LoopLevel* p) { delete p; } |]

-- | Run an action with a raw pointer to the @Halide::LoopLevel@ behind a
-- 'LoopLevel'.
--
-- A locked level lends out the pointer held by its 'ForeignPtr'. The inlined
-- and root levels carry no C++ object, so a temporary one is allocated for the
-- duration of the action (@Halide::LoopLevel::inlined()@ or @root()@). 'bracket'
-- deletes it afterwards, including when the action throws. The pointer must
-- not be used after the action returns.
withCxxLoopLevel :: LoopLevel t -> (Ptr CxxLoopLevel -> IO a) -> IO a
withCxxLoopLevel level action =
  case level of
    LoopLevel fp -> withForeignPtr fp action
    _ -> bracket newTemporary deleteTemporary action
  where
    newTemporary
      | isInlined level =
          [CU.exp| Halide::LoopLevel* { new Halide::LoopLevel{Halide::LoopLevel::inlined()} } |]
      | isRoot level =
          [CU.exp| Halide::LoopLevel* { new Halide::LoopLevel{Halide::LoopLevel::root()} } |]
      | otherwise = error "this should never happen"
    deleteTemporary ptr = [CU.exp| void { delete $(Halide::LoopLevel *ptr) } |]