{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.CodeGen.Generate
  where

import Prelude                                                  hiding ( fromIntegral )

-- accelerate
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar

import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Loop
import Data.Array.Accelerate.LLVM.PTX.Target                    ( PTX )


-- Construct a new array by applying a function to each index. Each thread
-- processes multiple adjacent elements.
--
-- The generated kernel is named "generate". Its parameters are those of the
-- mutable output array (named "out") followed by those of the array
-- environment. The kernel body walks the linear indices [0, size) of the
-- output array; for every linear index it recovers the multidimensional
-- index, evaluates the generator function at that index, and stores the
-- result at the linear index.
--
-- NOTE(review): how the linear index range is divided between threads is
-- decided by 'imapFromTo' (imported from the PTX loop module) and is not
-- visible here -- confirm the "adjacent elements" wording above against it.
--
mkGenerate
    :: Gamma aenv                                         -- array environment
    -> ArrayR (Array sh e)                                -- representation of the output array
    -> IRFun1  PTX aenv (sh -> e)                         -- generator: index -> element
    -> CodeGen PTX      (IROpenAcc PTX aenv (Array sh e))
mkGenerate aenv repr@(ArrayR shr _) apply =
  let
      (arrOut, paramOut)  = mutableArray repr "out"
      paramEnv            = envParam aenv
      shOut               = irArrayShape arrOut           -- extent of the output array
  in
  makeOpenAcc "generate" (paramOut ++ paramEnv) $ do

    -- Linear iteration space covering the whole output array
    let start = liftInt 0
    end <- shapeSize shr shOut

    imapFromTo start end $ \i -> do
      ix <- indexOfInt shr shOut i                        -- convert to multidimensional index
      r  <- app1 apply ix                                 -- apply generator function
      writeArray TypeInt arrOut i r                       -- store result

    return_