{-# LANGUAGE QuasiQuotes #-}

-- | Generate the entry point packing/unpacking code.
module Futhark.CodeGen.Backends.GenericC.EntryPoints
  ( onEntryPoint,
  )
where

import Control.Monad.Reader
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.Backends.GenericC.Types (opaqueToCType, valueTypeToCType)
import Futhark.CodeGen.ImpCode
import Futhark.Manifest qualified as Manifest
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | The 'ValueType' of a value description.  Signedness and primitive
-- type are taken from the description; a scalar has rank zero, and an
-- array has one rank per entry in its shape.
valueDescToType :: ValueDesc -> ValueType
valueDescToType (ScalarValue pt signed _) =
  ValueType signed (Rank 0) pt
valueDescToType (ArrayValue _ _ pt signed shape) =
  ValueType signed (Rank (length shape)) pt

-- | Combine C expressions into a single right-nested @&&@ chain.  The
-- empty list becomes the C expression @true@, and a singleton list is
-- returned as-is (no redundant @&& true@ is appended).
allTrue :: [C.Exp] -> C.Exp
allTrue [] = [C.cexp|true|]
allTrue [x] = x
allTrue (x : xs) = [C.cexp|$exp:x && $exp:(allTrue xs)|]

-- | Generate the code that unpacks the arguments of an entry point.
--
-- For each external value (numbered from zero) this produces a C
-- parameter named @in\<N\>@ paired with an optional C expression that
-- checks the argument's array sizes ('Nothing' when there is nothing
-- to check).  The statements that copy the argument contents into the
-- variables named by the value descriptions are emitted while running
-- and returned as the second component (gathered by @collect'@).
prepareEntryInputs ::
  [ExternalValue] ->
  CompilerM op s ([(C.Param, Maybe C.Exp)], [C.BlockItem])
prepareEntryInputs args = collect' $ zipWithM prepare [(0 :: Int) ..] args
  where
    -- The variables bound directly by the arguments themselves.  A
    -- dimension whose size is one of these is only checked against
    -- the array shape, never assigned from it.
    arg_names = namesFromList $ concatMap evNames args

    evNames (OpaqueValue _ vds) = map vdName vds
    evNames (TransparentValue vd) = [vdName vd]

    vdName (ArrayValue v _ _ _ _) = v
    vdName (ScalarValue _ _ v) = v

    -- A transparent value is passed as @const ty in<N>@, where the
    -- type is whatever 'prepareValue' reports for the description.
    prepare pno (TransparentValue vd) = do
      let pname = "in" ++ show pno
      (ty, check) <- prepareValue Public [C.cexp|$id:pname|] vd
      pure
        ( [C.cparam|const $ty:ty $id:pname|],
          if null check then Nothing else Just $ allTrue check
        )
    -- An opaque value is passed as a pointer to its struct; each
    -- component is read from the corresponding tuple field.
    prepare pno (OpaqueValue desc vds) = do
      ty <- opaqueToCType desc
      let pname = "in" ++ show pno
          -- Scalar and array components are accessed the same way.
          field i = [C.cexp|$id:pname->$id:(tupleField i)|]
      checks <- map snd <$> zipWithM (prepareValue Private) (map field [0 ..]) vds
      pure
        ( [C.cparam|const $ty:ty *$id:pname|],
          if all null checks
            then Nothing
            else Just $ allTrue $ concat checks
        )

    -- Emit the statements that unpack one value from the C expression
    -- @src@, returning the C type of the value as seen by the caller
    -- and the size checks that must hold.  Scalars are converted with
    -- 'fromStorage' and need no checks.
    prepareValue _ src (ScalarValue pt signed name) = do
      let pt' = primAPIType signed pt
          src' = fromStorage pt $ C.toExp src mempty
      stm [C.cstm|$id:name = $exp:src';|]
      pure (pt', [])
    prepareValue pub src vd@(ArrayValue mem _ _ _ shape) = do
      ty <- valueTypeToCType pub $ valueDescToType vd

      stm [C.cstm|$exp:mem = $exp:src->mem;|]

      let -- A variable dimension that is not itself an argument is
          -- assigned from the array's shape; every dimension also
          -- yields an equality check against the shape.  The first
          -- clause falls through to the second when the guard fails.
          maybeCopyDim (Var d) i
            | d `notNameIn` arg_names =
                ( Just [C.cstm|$id:d = $exp:src->shape[$int:i];|],
                  [C.cexp|$id:d == $exp:src->shape[$int:i]|]
                )
          maybeCopyDim x i =
            ( Nothing,
              [C.cexp|$exp:x == $exp:src->shape[$int:i]|]
            )

      let (sets, checks) =
            unzip $ zipWith maybeCopyDim shape [0 :: Int ..]
      stms $ catMaybes sets

      pure ([C.cty|$ty:ty*|], checks)

-- | Generate the code that packs the results of an entry point.
--
-- For each external value (numbered from zero) this produces a C
-- out-parameter named @out\<N\>@.  The statements that store the
-- result variables through those parameters are emitted while running
-- and returned as the second component (gathered by @collect'@).
--
-- NOTE(review): the generated @malloc@ calls are placed inside the
-- argument of C @assert@.  If the generated code were ever compiled
-- with @NDEBUG@ defined, the allocation itself would disappear along
-- with the check - confirm that this cannot happen.
prepareEntryOutputs :: [ExternalValue] -> CompilerM op s ([C.Param], [C.BlockItem])
prepareEntryOutputs = collect' . zipWithM prepare [(0 :: Int) ..]
  where
    prepare pno (TransparentValue vd) = do
      let pname = "out" ++ show pno
      ty <- valueTypeToCType Public $ valueDescToType vd

      case vd of
        -- Arrays are returned as a freshly allocated struct, so the
        -- parameter is a pointer to a pointer.
        ArrayValue {} -> do
          stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]
          prepareValue [C.cexp|*$id:pname|] vd
          pure [C.cparam|$ty:ty **$id:pname|]
        -- Scalars are written directly through the pointer.
        ScalarValue {} -> do
          prepareValue [C.cexp|*$id:pname|] vd
          pure [C.cparam|$ty:ty *$id:pname|]
    prepare pno (OpaqueValue desc vds) = do
      let pname = "out" ++ show pno
      ty <- opaqueToCType desc
      vd_ts <- mapM (valueTypeToCType Private . valueDescToType) vds

      stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]

      forM_ (zip3 [0 ..] vd_ts vds) $ \(i, ct, vd) -> do
        let field = [C.cexp|((*$id:pname)->$id:(tupleField i))|]
        -- Array components get their own allocation before being
        -- filled in; scalar components are stored in place.
        case vd of
          ScalarValue {} -> pure ()
          ArrayValue {} -> do
            stm [C.cstm|assert(($exp:field = ($ty:ct*) malloc(sizeof($ty:ct))) != NULL);|]
        prepareValue field vd

      pure [C.cparam|$ty:ty **$id:pname|]

    -- Emit the statements that store one value into the C expression
    -- @dest@.  Scalars are converted with 'toStorage'; for arrays the
    -- memory block and every dimension size are copied.
    prepareValue dest (ScalarValue t _ name) =
      let name' = toStorage t $ C.toExp name mempty
       in stm [C.cstm|$exp:dest = $exp:name';|]
    prepareValue dest (ArrayValue mem _ _ _ shape) = do
      stm [C.cstm|$exp:dest->mem = $id:mem;|]

      let maybeCopyDim (Constant x) i =
            [C.cstm|$exp:dest->shape[$int:i] = $exp:x;|]
          maybeCopyDim (Var d) i =
            [C.cstm|$exp:dest->shape[$int:i] = $id:d;|]
      stms $ zipWith maybeCopyDim shape [0 :: Int ..]

-- | The base name of the C function generated for an entry point: the
-- entry point's name converted to text, passed through 'escapeName',
-- and prefixed with @entry_@.
entryName :: Name -> T.Text
entryName = ("entry_" <>) . escapeName . nameToText

-- | Generate the C wrapper function for an entry point.
--
-- Returns 'Nothing' for a function that is not an entry point.
-- Otherwise a declaration of the wrapper is added to the header (in
-- the 'EntryDecl' section), and the result is the wrapper's definition
-- together with the entry point's name and its manifest description.
--
-- The wrapper declares locals for the parameters of the underlying
-- function, unpacks its own arguments into them, checks the argument
-- sizes, calls the function named by @funName fname@, and on success
-- runs the @get_consts@ items followed by the code packing the
-- results.  All of this is wrapped by 'criticalSection'.  The wrapper
-- returns @ret@, which is zero unless a size check or the call failed.
onEntryPoint ::
  [C.BlockItem] ->
  Name ->
  Function op ->
  CompilerM op s (Maybe (C.Definition, (T.Text, Manifest.EntryPoint)))
onEntryPoint _ _ (Function Nothing _ _ _) = pure Nothing
onEntryPoint get_consts fname (Function (Just (EntryPoint ename results args)) outputs inputs _) = inNewFunction $ do
  -- Outputs of the underlying function are passed by address, inputs
  -- by value.
  let out_args = map (\p -> [C.cexp|&$id:(paramName p)|]) outputs
      in_args = map (\p -> [C.cexp|$id:(paramName p)|]) inputs

  inputdecls <- collect $ mapM_ stubParam inputs
  outputdecls <- collect $ mapM_ stubParam outputs
  decl_mem <- declAllocatedMem

  entry_point_function_name <- publicName $ entryName ename

  (inputs', unpack_entry_inputs) <- prepareEntryInputs $ map snd args
  let (entry_point_input_params, entry_point_input_checks) = unzip inputs'

  (entry_point_output_params, pack_entry_outputs) <-
    prepareEntryOutputs $ map snd results

  ctx_ty <- contextType

  headerDecl
    EntryDecl
    [C.cedecl|int $id:entry_point_function_name
                                     ($ty:ctx_ty *ctx,
                                      $params:entry_point_output_params,
                                      $params:entry_point_input_params);|]

  let checks = catMaybes entry_point_input_checks
      -- Only generate the size check when some argument needs one.
      check_input =
        if null checks
          then []
          else
            [C.citems|
         if (!($exp:(allTrue checks))) {
           ret = 1;
           set_error(ctx, msgprintf("Error: entry point arguments have invalid sizes.\n"));
         }|]

      critical =
        [C.citems|
         $items:decl_mem
         $items:unpack_entry_inputs
         $items:check_input
         if (ret == 0) {
           ret = $id:(funName fname)(ctx, $args:out_args, $args:in_args);
           if (ret == 0) {
             $items:get_consts

             $items:pack_entry_outputs
           }
         }
        |]

  ops <- asks envOperations

  let cdef =
        [C.cedecl|
       int $id:entry_point_function_name
           ($ty:ctx_ty *ctx,
            $params:entry_point_output_params,
            $params:entry_point_input_params) {
         $items:inputdecls
         $items:outputdecls

         int ret = 0;

         $items:(criticalSection ops critical)

         return ret;
       }|]

      manifest =
        Manifest.EntryPoint
          { Manifest.entryPointCFun = entry_point_function_name,
            -- Note that our convention about what is "input/output"
            -- and what is "results/args" is different between the
            -- manifest and ImpCode.
            Manifest.entryPointOutputs = map outputManifest results,
            Manifest.entryPointInputs = map inputManifest args
          }

  pure $ Just (cdef, (nameToText ename, manifest))
  where
    -- Declare a local variable standing in for a parameter of the
    -- underlying function.
    stubParam (MemParam name space) =
      declMem name space
    stubParam (ScalarParam name ty) = do
      let ty' = primTypeToCType ty
      decl [C.cdecl|$ty:ty' $id:name;|]

    -- The type name recorded in the manifest: the primitive type
    -- rendered by 'prettySigned', preceded by one @[]@ per array
    -- dimension, or simply the name of an opaque type.
    vdType (TransparentValue (ScalarValue pt signed _)) =
      prettySigned (signed == Unsigned) pt
    vdType (TransparentValue (ArrayValue _ _ pt signed shape)) =
      mconcat (replicate (length shape) "[]")
        <> prettySigned (signed == Unsigned) pt
    vdType (OpaqueValue name _) =
      nameToText name

    outputManifest (u, vd) =
      Manifest.Output
        { Manifest.outputType = vdType vd,
          Manifest.outputUnique = u == Unique
        }

    inputManifest ((v, u), vd) =
      Manifest.Input
        { Manifest.inputName = nameToText v,
          Manifest.inputType = vdType vd,
          Manifest.inputUnique = u == Unique
        }