{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Execute.Marshal
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Execute.Marshal
  where

import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.CodeGen.Environment           ( Gamma, Idx'(..) )
import Data.Array.Accelerate.LLVM.Execute.Environment
import Data.Array.Accelerate.LLVM.Execute.Async

import Data.DList                                               ( DList )
import qualified Data.DList                                     as DL
import qualified Data.IntMap                                    as IM


-- Marshalling arguments
-- ---------------------
class Async arch => Marshal arch where
  -- | The concrete type of a single kernel argument for the backend @arch@.
  --
  type ArgR arch

  -- | Marshal one buffer of scalar array data, whose element type is given
  -- by the 'SingleType', into zero or more kernel arguments.
  --
  marshalScalarData' :: SingleType e -> ScalarArrayData e -> Par arch (DList (ArgR arch))

  -- | Marshal a single 'Int' into a kernel argument. This is how each
  -- component of an array extent (and any plain 'Int' parameter) is passed.
  --
  marshalInt :: Int -> ArgR arch

-- | Convert function arguments into a form suitable for function calls.
--
-- Throughout this module, functions whose names end in a prime return a
-- 'DList' (so partial results can be appended cheaply); the unprimed
-- variants convert that result to an ordinary list.
--
-- 'marshalArrays' marshals every array of the tuple of arrays described by
-- the 'ArraysR' into a flat list of kernel arguments.
--
marshalArrays :: forall arch arrs. Marshal arch => ArraysR arrs -> arrs -> Par arch [ArgR arch]
marshalArrays repr arrs = DL.toList <$> marshalArrays' @arch repr arrs

-- | Marshal a tuple of arrays. 'marshalArray'' is applied to each leaf of the
-- tuple representation and the results are concatenated left to right.
--
marshalArrays' :: forall arch arrs. Marshal arch => ArraysR arrs -> arrs -> Par arch (DList (ArgR arch))
marshalArrays' = marshalTupR' @arch (marshalArray' @arch)

-- | Marshal a single array. The payload arguments come first (see
-- 'marshalArrayData''), followed by one 'marshalInt' argument per dimension
-- of the array's extent (see 'marshalShape'').
--
marshalArray' :: forall arch a. Marshal arch => ArrayR a -> a -> Par arch (DList (ArgR arch))
marshalArray' (ArrayR shr tp) (Array sh adata) = do
  payload <- marshalArrayData' @arch tp adata
  let extent = marshalShape' @arch shr sh
  -- Payload first, then extent.
  pure (payload `DL.append` extent)

-- | Marshal the payload of an array, following the structure of its element
-- type:
--
--   * unit contributes no arguments;
--   * a pair marshals both components and concatenates them, left then right;
--   * a scalar buffer is handed to the backend's 'marshalScalarData''.
--
marshalArrayData' :: forall arch t. Marshal arch => TypeR t -> ArrayData t -> Par arch (DList (ArgR arch))
marshalArrayData' TupRunit ()
  = pure DL.empty
marshalArrayData' (TupRpair t1 t2) (a1, a2)
  = DL.append <$> marshalArrayData' @arch t1 a1
              <*> marshalArrayData' @arch t2 a2
marshalArrayData' (TupRsingle t) ad
  -- Matching the dictionary exposes the SingleType and refines the
  -- ArrayData to the ScalarArrayData expected by marshalScalarData'.
  | ScalarArrayDict _ s <- scalarArrayDict t
  = marshalScalarData' @arch s ad

-- | List version of 'marshalEnv''.
--
marshalEnv :: forall arch aenv. Marshal arch => Gamma aenv -> ValR arch aenv -> Par arch [ArgR arch]
marshalEnv gamma aenv = DL.toList <$> marshalEnv' @arch gamma aenv

-- | Marshal the arrays selected from the environment.
--
-- 'Gamma' is an 'IM.IntMap' of @(label, 'Idx'')@ pairs. Entries are processed
-- in ascending key order (the order of 'IM.elems'). For each one, the
-- environment entry at its index is a future: it is awaited with 'get' and the
-- resulting array is marshalled with 'marshalArray''. The per-array results
-- are concatenated in that same order.
--
marshalEnv' :: forall arch aenv. Marshal arch => Gamma aenv -> ValR arch aenv -> Par arch (DList (ArgR arch))
marshalEnv' gamma aenv
  = fmap DL.concat
  $ mapM (\(_, Idx' repr idx) -> marshalArray' @arch repr =<< get (prj idx aenv))
         (IM.elems gamma)

-- | List version of 'marshalShape''.
--
marshalShape :: forall arch sh. Marshal arch => ShapeR sh -> sh -> [ArgR arch]
marshalShape shr sh = DL.toList (marshalShape' @arch shr sh)

-- | Marshal a shape as one 'marshalInt' argument per dimension. The arguments
-- appear in the order of the shape's snoc-list representation, so the last
-- component of the shape is the last argument.
--
marshalShape' :: forall arch sh. Marshal arch => ShapeR sh -> sh -> DList (ArgR arch)
marshalShape' ShapeRz          ()      = DL.empty
marshalShape' (ShapeRsnoc shr) (sh, n) = marshalShape' @arch shr sh `DL.snoc` marshalInt @arch n

-- | Describes how one kernel parameter of type @t@ is turned into kernel
-- arguments by 'marshalParam''.
data ParamR arch t where
  -- | An array: its payload followed by its extent.
  ParamRarray  :: ArrayR (Array sh e) -> ParamR arch (Array sh e)
  -- | An optional parameter; 'Nothing' contributes no arguments.
  ParamRmaybe  :: ParamR arch t        -> ParamR arch (Maybe t)
  -- | A future, awaited before its value is marshalled.
  ParamRfuture :: ParamR arch t        -> ParamR arch (FutureR arch t)
  -- | The environment arrays selected by the given 'Gamma'.
  ParamRenv    :: Gamma env            -> ParamR arch (ValR arch env)
  -- | A single 'Int'.
  ParamRint    ::                         ParamR arch Int
  -- | A shape, one 'Int' argument per dimension.
  ParamRshape  :: ShapeR sh            -> ParamR arch sh
  -- | Arguments that were already marshalled; passed through unchanged.
  ParamRargs   ::                         ParamR arch (DList (ArgR arch))

-- | A (possibly nested) tuple of parameter descriptions.
type ParamsR arch = TupR (ParamR arch)

-- | Marshal one kernel parameter according to its 'ParamR' description:
--
--   * arrays are marshalled with 'marshalArray'';
--   * 'Nothing' contributes no arguments, while 'Just' marshals its contents;
--   * futures are awaited with 'get' before their value is marshalled;
--   * environments are marshalled with 'marshalEnv'';
--   * 'Int's and shapes are passed via 'marshalInt' and 'marshalShape'';
--   * pre-marshalled arguments are returned unchanged.
--
marshalParam' :: forall arch a. Marshal arch => ParamR arch a -> a -> Par arch (DList (ArgR arch))
marshalParam' (ParamRarray repr)  arr      = marshalArray' @arch repr arr
marshalParam' (ParamRmaybe _)     Nothing  = pure DL.empty
marshalParam' (ParamRmaybe repr)  (Just x) = marshalParam' @arch repr x
marshalParam' (ParamRfuture repr) future   = marshalParam' @arch repr =<< get future
marshalParam' (ParamRenv gamma)   aenv     = marshalEnv' @arch gamma aenv
marshalParam' ParamRint           n        = pure (DL.singleton (marshalInt @arch n))
marshalParam' (ParamRshape shr)   sh       = pure (marshalShape' @arch shr sh)
marshalParam' ParamRargs          args     = pure args

-- | Marshal a tuple of parameters. 'marshalParam'' is applied to each leaf and
-- the results are concatenated left to right.
--
marshalParams' :: forall arch a. Marshal arch => ParamsR arch a -> a -> Par arch (DList (ArgR arch))
marshalParams' = marshalTupR' @arch (marshalParam' @arch)

-- | Generic traversal over a 'TupR'. The given function marshals each leaf;
-- unit contributes no arguments, and the two halves of a pair are marshalled
-- in order (left effects first) and concatenated.
--
{-# INLINE marshalTupR' #-}
marshalTupR'
    :: forall arch s a. Marshal arch
    => (forall b. s b -> b -> Par arch (DList (ArgR arch)))
    -> TupR s a
    -> a
    -> Par arch (DList (ArgR arch))
marshalTupR' _ TupRunit         ()       = pure DL.empty
marshalTupR' f (TupRsingle t)   x        = f t x
marshalTupR' f (TupRpair t1 t2) (x1, x2) =
  DL.append <$> marshalTupR' @arch f t1 x1
            <*> marshalTupR' @arch f t2 x2