{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.Array.Accelerate.LLVM.PTX.CodeGen.Transform
where
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Base
import Data.Array.Accelerate.LLVM.PTX.CodeGen.Loop
import Data.Array.Accelerate.LLVM.PTX.Target ( PTX )
-- | Build the 'IROpenAcc' labelled @"transform"@: every element of the output
-- array is produced by mapping its index back into the input array with @p@,
-- reading the element found there, and applying @f@ to it.
--
-- The generated parameter list is the output array's parameters, followed by
-- the input array's parameters, followed by the environment parameters
-- obtained from 'envParam'.
--
mkTransform
    :: Gamma aenv                       -- ^ array environment
    -> ArrayR (Array sh  a)             -- ^ representation of the input array
    -> ArrayR (Array sh' b)             -- ^ representation of the output array
    -> IRFun1 PTX aenv (sh' -> sh)      -- ^ index permutation: output index to input index
    -> IRFun1 PTX aenv (a -> b)         -- ^ function applied to each element read
    -> CodeGen PTX (IROpenAcc PTX aenv (Array sh' b))
mkTransform aenv repr@(ArrayR shr _) repr'@(ArrayR shr' _) p f =
  let
      (arrOut, paramOut)  = mutableArray repr' "out"
      (arrIn,  paramIn)   = mutableArray repr  "in"
      paramEnv            = envParam aenv
  in
  makeOpenAcc "transform" (paramOut ++ paramIn ++ paramEnv) $ do

    -- The iteration space is the linear range from zero up to the number of
    -- elements described by the output array's shape.
    let start = liftInt 0
    end   <- shapeSize shr' (irArrayShape arrOut)

    -- NOTE(review): 'imapFromTo' comes from the PTX Loop module; how the range
    -- is scheduled, and whether 'end' is exclusive, is defined there -- confirm.
    imapFromTo start end $ \i' -> do
      ix' <- indexOfInt shr' (irArrayShape arrOut) i'   -- linear output index -> output shape index
      ix  <- app1 p ix'                                 -- output shape index  -> input shape index
      i   <- intOfIndex shr (irArrayShape arrIn) ix     -- input shape index   -> linear input index
      a   <- readArray TypeInt arrIn i
      b   <- app1 f a
      writeArray TypeInt arrOut i' b

    return_