{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- |
-- Module      : Language.Halide.Dimension
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Dimension
  ( Dimension (..)
  , setMin
  , setExtent
  , setStride
  , setEstimate

    -- * Internal
  , CxxDimension
  , wrapCxxDimension
  , withCxxDimension
  )
where

import Foreign.ForeignPtr
import Foreign.Ptr (Ptr)
import GHC.Records (HasField (..))
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Unsafe as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Type
import System.IO.Unsafe (unsafePerformIO)
import Prelude hiding (tail)

-- | Haskell counterpart of [@Halide::Internal::Dimension@](https://halide-lang.org/docs/class_halide_1_1_internal_1_1_dimension.html).
--
-- The type has no constructors; in this module it only appears as the
-- parameter of 'Ptr' and 'ForeignPtr', tagging pointers to the C++ object.
data CxxDimension

-- NOTE(review): Template Haskell splice provided by one of the Language.Halide
-- imports; presumably it sets up the inline-c context required by the
-- @[CU.exp| … |]@ and @[C.funPtr| … |]@ quasi-quotes below — confirm.
importHalide

-- | Information about a buffer's dimension, such as the min, extent, and stride.
--
-- A 'Dimension' wraps a 'ForeignPtr' to the underlying C++ object. Its
-- properties are read through the 'HasField' instances defined in this module
-- (@dim.min@, @dim.extent@, @dim.max@, @dim.stride@) and updated with
-- 'setMin', 'setExtent', 'setStride' and 'setEstimate'.
newtype Dimension = Dimension (ForeignPtr CxxDimension)

-- | Renders as @Dimension { min=…, extent=…, stride=… }@, wrapped in
-- parentheses when the surrounding precedence is greater than 10.
--
-- Only the @min@, @extent@ and @stride@ fields are printed; each one is
-- obtained through the corresponding 'HasField' instance.
instance Show Dimension where
  showsPrec :: Int -> Dimension -> ShowS
  showsPrec d dim =
    showParen (d > 10) $
      showString "Dimension { min="
        . shows dim.min
        . showString ", extent="
        . shows dim.extent
        . showString ", stride="
        . shows dim.stride
        . showString " }"

-- | Enables @dim.min@: the expression returned by the C++ @min()@ accessor of
-- the underlying @Halide::Internal::Dimension@.
instance HasField "min" Dimension (Expr Int32) where
  getField :: Dimension -> Expr Int32
  getField dim =
    -- NOTE(review): 'unsafePerformIO' lets the field read as a pure value, but
    -- 'setMin' mutates the same C++ object, so the result depends on when this
    -- thunk is forced relative to such calls — confirm this is acceptable.
    unsafePerformIO $
      withCxxDimension dim $ \d ->
        cxxConstructExpr $ \ptr ->
          -- Placement-new: construct a Halide::Expr from the value of min()
          -- directly in the storage that ptr points to.
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
            $(const Halide::Internal::Dimension* d)->min()} } |]

-- | Set the min in a given dimension to equal the given expression. Setting the mins to
-- zero may simplify some addressing math.
--
-- The underlying C++ object is modified in place via @set_min@; the very same
-- 'Dimension' handle that was passed in is returned.
--
-- For more info, see [Halide::Internal::Dimension::set_min](https://halide-lang.org/docs/class_halide_1_1_internal_1_1_dimension.html#a84acaf7733391fdaea4f4cec24a60de2).
setMin
  :: Expr Int32
  -- ^ new @min@ expression
  -> Dimension
  -- ^ dimension to modify
  -> IO Dimension
setMin expr dim = do
  asExpr expr $ \n ->
    withCxxDimension dim $ \d ->
      [CU.exp| void {
        $(Halide::Internal::Dimension* d)->set_min(*$(const Halide::Expr* n)) } |]
  pure dim

-- | Enables @dim.extent@: the expression returned by the C++ @extent()@
-- accessor of the underlying @Halide::Internal::Dimension@.
instance HasField "extent" Dimension (Expr Int32) where
  getField :: Dimension -> Expr Int32
  getField dim =
    -- NOTE(review): 'unsafePerformIO' lets the field read as a pure value, but
    -- 'setExtent' mutates the same C++ object, so the result depends on when
    -- this thunk is forced relative to such calls — confirm this is acceptable.
    unsafePerformIO $
      withCxxDimension dim $ \d ->
        cxxConstructExpr $ \ptr ->
          -- Placement-new: construct a Halide::Expr from the value of extent()
          -- directly in the storage that ptr points to.
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
            $(const Halide::Internal::Dimension* d)->extent()} } |]

-- | Set the extent in a given dimension to equal the given expression.
--
-- Halide will generate runtime errors for Buffers that fail this check.
--
-- The underlying C++ object is modified in place via @set_extent@; the very
-- same 'Dimension' handle that was passed in is returned.
--
-- For more info, see [Halide::Internal::Dimension::set_extent](https://halide-lang.org/docs/class_halide_1_1_internal_1_1_dimension.html#a54111d8439a065bdaca5b9ff9bcbd630).
setExtent
  :: Expr Int32
  -- ^ new @extent@ expression
  -> Dimension
  -- ^ dimension to modify
  -> IO Dimension
setExtent expr dim = do
  asExpr expr $ \n ->
    withCxxDimension dim $ \d ->
      [CU.exp| void {
        $(Halide::Internal::Dimension* d)->set_extent(*$(const Halide::Expr* n)) } |]
  pure dim

-- | Enables @dim.max@: the expression returned by the C++ @max()@ accessor of
-- the underlying @Halide::Internal::Dimension@.
instance HasField "max" Dimension (Expr Int32) where
  getField :: Dimension -> Expr Int32
  getField dim =
    -- NOTE(review): 'unsafePerformIO' lets the field read as a pure value, but
    -- the setters in this module mutate the same C++ object, so the result may
    -- depend on when this thunk is forced — confirm this is acceptable.
    unsafePerformIO $
      withCxxDimension dim $ \d ->
        cxxConstructExpr $ \ptr ->
          -- Placement-new: construct a Halide::Expr from the value of max()
          -- directly in the storage that ptr points to. The dimension is only
          -- read here, so it is taken through a const pointer like the @min@
          -- and @extent@ accessors do.
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
            $(const Halide::Internal::Dimension* d)->max()} } |]

-- | Enables @dim.stride@: the expression returned by the C++ @stride()@
-- accessor of the underlying @Halide::Internal::Dimension@.
instance HasField "stride" Dimension (Expr Int32) where
  getField :: Dimension -> Expr Int32
  getField dim =
    -- NOTE(review): 'unsafePerformIO' lets the field read as a pure value, but
    -- 'setStride' mutates the same C++ object, so the result depends on when
    -- this thunk is forced relative to such calls — confirm this is acceptable.
    unsafePerformIO $
      withCxxDimension dim $ \d ->
        cxxConstructExpr $ \ptr ->
          -- Placement-new: construct a Halide::Expr from the value of stride()
          -- directly in the storage that ptr points to. The dimension is only
          -- read here, so it is taken through a const pointer like the @min@
          -- and @extent@ accessors do.
          [CU.exp| void { new ($(Halide::Expr* ptr)) Halide::Expr{
            $(const Halide::Internal::Dimension* d)->stride()} } |]

-- | Set the stride in a given dimension to equal the given expression.
--
-- This is particularly useful to set when vectorizing. Known strides for the vectorized
-- dimensions generate better code.
--
-- The underlying C++ object is modified in place via @set_stride@; the very
-- same 'Dimension' handle that was passed in is returned.
--
-- For more info, see [Halide::Internal::Dimension::set_stride](https://halide-lang.org/docs/class_halide_1_1_internal_1_1_dimension.html#a94f4c432a89907e2cc2aa908b5012cf8).
setStride
  :: Expr Int32
  -- ^ new @stride@ expression
  -> Dimension
  -- ^ dimension to modify
  -> IO Dimension
setStride expr dim = do
  asExpr expr $ \n ->
    withCxxDimension dim $ \d ->
      [CU.exp| void {
        $(Halide::Internal::Dimension* d)->set_stride(*$(const Halide::Expr* n)) } |]
  pure dim

-- | Set estimates for autoschedulers.
--
-- Both expressions are forwarded to the C++ @set_estimate@ method of the
-- underlying object, which is modified in place; the very same 'Dimension'
-- handle that was passed in is returned.
setEstimate
  :: Expr Int32
  -- ^ @min@ estimate
  -> Expr Int32
  -- ^ @extent@ estimate
  -> Dimension
  -- ^ dimension to annotate
  -> IO Dimension
setEstimate minExpr extentExpr dim = do
  asExpr minExpr $ \m ->
    asExpr extentExpr $ \e ->
      withCxxDimension dim $ \d ->
        [CU.exp| void {
          $(Halide::Internal::Dimension* d)->set_estimate(*$(const Halide::Expr* m),
                                                          *$(const Halide::Expr* e)) } |]
  pure dim

-- | Wrap a raw pointer to a C++ @Halide::Internal::Dimension@ into a managed
-- 'Dimension'.
--
-- The resulting 'ForeignPtr' carries a finalizer that runs @delete@ on the
-- pointer. The pointer should therefore refer to an object allocated with
-- C++ @new@, and nothing else may delete it afterwards.
wrapCxxDimension :: Ptr CxxDimension -> IO Dimension
wrapCxxDimension = fmap Dimension . newForeignPtr deleter
  where
    -- C function, defined inline, used as the finalizer of the ForeignPtr.
    deleter = [C.funPtr| void deleteDimension(Halide::Internal::Dimension* p) { delete p; } |]

-- | Run an action with the raw pointer to the underlying C++ object.
--
-- This is 'withForeignPtr' on the wrapped 'ForeignPtr', so the object is kept
-- alive while the action runs; the pointer should not be used after the
-- action returns.
withCxxDimension :: Dimension -> (Ptr CxxDimension -> IO a) -> IO a
withCxxDimension (Dimension fp) = withForeignPtr fp