{-# LANGUAGE QuasiQuotes #-}

-- | C code generation for functions.
module Futhark.CodeGen.Backends.GenericC.Fun
  ( compileFun,
    module Futhark.CodeGen.Backends.GenericC.Monad,
    module Futhark.CodeGen.Backends.GenericC.Code,
  )
where

import Control.Monad.Reader
import Futhark.CodeGen.Backends.GenericC.Code
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.ImpCode
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | Emit the statements making up the body of a compiled function.
--
-- This happens in three phases:
--
--   1. Every output parameter gets a local declaration.  Memory outputs
--      are declared with 'declMem'; scalar outputs become a C variable of
--      type @primTypeToCType pt@.
--   2. The function's code is compiled.
--   3. Each output is written back through the corresponding output
--      pointer expression.  For memory this calls 'resetMem' and then
--      'setMem' on @*p@; for scalars it emits the assignment @*p = name;@.
--
-- Pointers and outputs are paired positionally (via 'zipWith'), so the
-- two lists are expected to line up element for element.
compileFunBody :: [C.Exp] -> [Param] -> Code op -> CompilerM op s ()
compileFunBody output_ptrs outputs code = do
  mapM_ declareLocal outputs
  compileCode code
  sequence_ $ zipWith storeThrough output_ptrs outputs
  where
    -- Introduce the local that the compiled code assigns the output to.
    declareLocal param =
      case param of
        ScalarParam name pt ->
          let ctp = primTypeToCType pt
           in decl [C.cdecl|$ty:ctp $id:name;|]
        MemParam name space ->
          declMem name space

    -- Copy the local back out through the caller-supplied pointer @p@.
    storeThrough p param =
      case param of
        ScalarParam name _ ->
          stm [C.cstm|*$exp:p = $id:name;|]
        MemParam name space -> do
          resetMem [C.cexp|*$exp:p|] space
          setMem [C.cexp|*$exp:p|] name space

-- | Compile a named ImpCode function into a C prototype together with
-- the matching C function definition.
--
-- The generated function is @static int@.  Its parameters are, in order:
-- the boilerplate parameters @extra@, one pointer parameter per output,
-- and one parameter per input.
--
-- Output pointers get fresh names: @out_<name>@ for scalars and
-- @<name>_p@ for memory.
--
-- The definition's body does the following, in order:
--
--   * casts each named @extra@ parameter to @void@, in case it is unused;
--   * declares @int err = 0@, which is also the return value;
--   * emits the cached-memory declarations, the allocated-memory
--     declarations, the @get_constants@ items and the compiled body;
--   * finishes with a @cleanup:@ label that runs the cached-memory and
--     allocated-memory release code before returning @err@.
--
-- All of this is compiled inside 'inNewFunction'.
compileFun :: [C.BlockItem] -> [C.Param] -> (Name, Function op) -> CompilerM op s (C.Definition, C.Func)
compileFun get_constants extra (fname, func@(Function _ outputs inputs body)) =
  inNewFunction $ do
    compiled_outs <- mapM compileOutput outputs
    let (outparams, out_ptrs) = unzip compiled_outs
    inparams <- mapM compileInput inputs

    cachingMemory (lexicalMemoryUsage func) $ \decl_cached free_cached -> do
      body' <- collect $ compileFunBody out_ptrs outputs body
      -- NOTE(review): these are queried only after the body has been
      -- collected, presumably so allocations made while compiling the
      -- body are included; confirm against CompilerM.
      decl_mem <- declAllocatedMem
      free_mem <- freeAllocatedMem

      let prototype =
            [C.cedecl|static int $id:(funName fname)($params:extra, $params:outparams, $params:inparams);|]
          definition =
            [C.cfun|static int $id:(funName fname)($params:extra, $params:outparams, $params:inparams) {
               $stms:ignores
               int err = 0;
               $items:decl_cached
               $items:decl_mem
               $items:get_constants
               $items:body'
              cleanup:
               {
               $stms:free_cached
               $items:free_mem
               }
               return err;
  }|]
      pure (prototype, definition)
  where
    -- Ignore all the boilerplate parameters, just in case we don't
    -- actually need to use them.  Unnamed parameters are skipped.
    ignores = concatMap voidCast extra

    voidCast (C.Param (Just p) _ _ _) = [[C.cstm|(void)$id:p;|]]
    voidCast _ = []

    -- An input becomes a by-value C parameter under its own name.
    compileInput param = do
      (ty, vname) <- case param of
        ScalarParam name bt ->
          pure (primTypeToCType bt, name)
        MemParam name space -> do
          mem_ty <- memToCType name space
          pure (mem_ty, name)
      pure [C.cparam|$ty:ty $id:vname|]

    -- An output becomes a pointer parameter with a fresh name, paired
    -- with the expression used to refer to that pointer in the body.
    compileOutput param = do
      (ty, p_name) <- case param of
        ScalarParam name bt -> do
          fresh <- newVName $ "out_" ++ baseString name
          pure (primTypeToCType bt, fresh)
        MemParam name space -> do
          mem_ty <- memToCType name space
          fresh <- newVName $ baseString name ++ "_p"
          pure (mem_ty, fresh)
      pure ([C.cparam|$ty:ty *$id:p_name|], [C.cexp|$id:p_name|])