{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.PTX.Execute.Marshal
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.PTX.Execute.Marshal (

  module Data.Array.Accelerate.LLVM.Execute.Marshal

) where

import Data.Array.Accelerate.LLVM.State
import Data.Array.Accelerate.LLVM.Execute.Marshal

import Data.Array.Accelerate.LLVM.PTX.Target
import Data.Array.Accelerate.LLVM.PTX.Execute.Async
import qualified Data.Array.Accelerate.LLVM.PTX.Array.Prim      as Prim

import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Array.Data

import qualified Foreign.CUDA.Driver                            as CUDA

import qualified Data.DList                                     as DL


-- Marshalling of kernel arguments for the PTX target. Every marshalled
-- argument is represented as a 'CUDA.FunParam'.
--
instance Marshal PTX where
  type ArgR PTX = CUDA.FunParam

  -- An 'Int' is handed to the kernel by value, wrapped in 'CUDA.VArg'.
  marshalInt = CUDA.VArg

  -- Array data of a single (scalar) element type is marshalled as exactly one
  -- argument: its device pointer, looked up with 'unsafeGetDevicePtr' and
  -- wrapped in 'CUDA.VArg'. The lookup runs in 'LLVM' and is lifted into
  -- 'Par' with 'liftPar'.
  --
  -- The match on 'SingleArrayDict' presumably brings into scope the evidence
  -- about the array representation of @e@ that the right-hand side needs in
  -- order to type check -- TODO confirm against 'singleArrayDict'.
  --
  marshalScalarData' t
    | SingleArrayDict <- singleArrayDict t
    = liftPar . fmap (DL.singleton . CUDA.VArg) . unsafeGetDevicePtr t

-- TODO FIXME !!!
--
-- We will probably need to change marshal to be a bracketed function, so that
-- the garbage collector does not try to evict the array in the middle of
-- a computation.
--
-- Look up the device pointer of the given array data via
-- 'Prim.withDevicePtr'. The continuation passed to 'Prim.withDevicePtr'
-- reports no event ('Nothing') and returns the pointer itself as its result,
-- so the pointer is used outside of that continuation; this is what makes the
-- function "unsafe" (see the note above).
--
-- Both arguments are forced to weak head normal form by the bang patterns.
--
unsafeGetDevicePtr
    :: SingleType e
    -> ArrayData e
    -> LLVM PTX (CUDA.DevicePtr (ScalarArrayDataR e))
unsafeGetDevicePtr !t !ad =
  Prim.withDevicePtr t ad (\p -> return (Nothing, p))