{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.CodeGen.Permute (
IRPermuteFun(..),
llvmOfPermuteFun,
atomicCAS_rmw,
atomicCAS_cmp,
) where
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar hiding ( Foreign )
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.CodeGen.Type
import Data.Array.Accelerate.LLVM.Foreign
import LLVM.AST.Type.AddrSpace
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Atomic
import LLVM.AST.Type.Instruction.RMW as RMW
import LLVM.AST.Type.Instruction.Volatile
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import Control.Applicative
import Prelude
-- | Generated code for the combination function of a forward permutation.
--
-- The general two-argument combination function is always available in
-- 'combine'. When the function could additionally be recognised as something
-- expressible with a single atomic read-modify-write instruction, 'atomicRMW'
-- holds the operation to use together with a one-argument function
-- (presumably applied to the new value to compute the operand handed to the
-- atomic instruction -- confirm against the backends consuming this type).
--
data IRPermuteFun arch aenv t where
  IRPermuteFun
    :: { -- The combination function in its general, binary form.
         combine   :: IRFun2 arch aenv (e -> e -> e)
         -- Optional specialisation to an atomic read-modify-write operation.
       , atomicRMW :: Maybe (RMWOperation, IRFun1 arch aenv (e -> e))
       }
    -> IRPermuteFun arch aenv (e -> e -> e)

-- | Analyse the combination function of a forward permutation and generate
-- code for it.
--
-- The result always contains the general binary function ('combine'). In
-- addition, when the element type is a 'SingleTuple' and the function has one
-- of the shapes below, an atomic read-modify-write specialisation is returned
-- in 'atomicRMW':
--
--   * The body does not mention the old value (de Bruijn index 'ZeroIdx') at
--     all: the new value simply replaces the old one ('Exchange').
--
--   * The body is a direct application of a supported primitive to the old
--     value and some expression which itself does not mention the old value.
--
-- Only direct applications are recognised; the analysis does not look
-- underneath other expression forms (e.g. let-bindings).
--
llvmOfPermuteFun
    :: forall arch aenv e. Foreign arch
    => arch
    -> DelayedFun aenv (e -> e -> e)
    -> Gamma aenv
    -> IRPermuteFun arch aenv (e -> e -> e)
llvmOfPermuteFun arch fun aenv = IRPermuteFun{..}
  where
    combine   = llvmOfFun2 arch fun aenv

    atomicRMW
      -- The old value is unused, so the result depends only on the new value
      -- and can be written with an atomic exchange.
      | Lam (Lam (Body body)) <- fun
      , SingleTuple{}         <- eltType (undefined::e)
      , Just body'            <- strengthenE latest body
      , fun'                  <- llvmOfFun1 arch (Lam (Body body')) aenv
      = Just (Exchange, fun')

      -- The old value is combined, unmodified, with an expression that does
      -- not otherwise depend on it, using a supported primitive operation.
      | Lam (Lam (Body body)) <- fun
      , SingleTuple{}         <- eltType (undefined::e)
      , Just (rmw, x)         <- rmwOp body
      , Just x'               <- strengthenE latest x
      , fun'                  <- llvmOfFun1 arch (Lam (Body x')) aenv
      = Just (rmw, fun')

      | otherwise
      = Nothing

    -- Recognise a primitive application which maps onto an atomic RMW
    -- operation, returning that operation and the operand which is not the
    -- old value.
    --
    -- Subtraction is not commutative and the atomic instruction computes
    -- @old - operand@, so for 'PrimSub' only the form @old - x@ is accepted.
    -- The form @x - old@ must not be turned into an atomic subtraction, as
    -- that would compute @old - x@ instead; it is rejected here so that the
    -- general 'combine' function is used.
    --
    rmwOp :: DelayedOpenExp (((),e),e) aenv e -> Maybe (RMWOperation, DelayedOpenExp (((),e),e) aenv e)
    rmwOp (PrimApp f xs)
      | PrimAdd{}  <- f = (RMW.Add,) <$> extract xs
      | PrimSub{}  <- f = (RMW.Sub,) <$> extractRight xs
      | PrimMin{}  <- f = (RMW.Min,) <$> extract xs
      | PrimMax{}  <- f = (RMW.Max,) <$> extract xs
      | PrimBOr{}  <- f = (RMW.Or,)  <$> extract xs
      | PrimBAnd{} <- f = (RMW.And,) <$> extract xs
      | PrimBXor{} <- f = (RMW.Xor,) <$> extract xs
    rmwOp _             = Nothing

    -- Given the argument pair of a commutative binary primitive, return the
    -- component which is /not/ the old value ('ZeroIdx'), provided that the
    -- other component is exactly the old value.
    --
    extract :: DelayedOpenExp (((),e),e) aenv (e,e) -> Maybe (DelayedOpenExp (((),e),e) aenv e)
    extract (Tuple (SnocTup (SnocTup NilTup x) y))
      | Just Refl <- match x (Var ZeroIdx) = Just y
      | Just Refl <- match y (Var ZeroIdx) = Just x
    extract _
      = Nothing

    -- As 'extract', but only succeeds if the old value is the /first/
    -- component of the pair, returning the second. Used for operations where
    -- the argument order matters.
    --
    extractRight :: DelayedOpenExp (((),e),e) aenv (e,e) -> Maybe (DelayedOpenExp (((),e),e) aenv e)
    extractRight (Tuple (SnocTup (SnocTup NilTup x) y))
      | Just Refl <- match x (Var ZeroIdx) = Just y
    extractRight _
      = Nothing

    -- Environment weakening used with 'strengthenE': the innermost variable
    -- (the old value) has no counterpart in the smaller environment, so any
    -- expression referring to it presumably fails to strengthen.
    --
    latest :: (((),e),e) :?> ((),e)
    latest ZeroIdx      = Nothing
    latest (SuccIdx ix) = Just ix

-- | Emit an atomic read-modify-write of the value stored at the given
-- address, built from a compare-and-swap loop (see 'atomicCAS_rmw'').
--
-- The compare-and-swap is carried out at an integral type. Integral element
-- types are used as they are; for the remaining scalar types an unsigned
-- word type is selected here and handed to 'atomicCAS_rmw''.
--
atomicCAS_rmw
    :: ScalarType t
    -> (IR t -> CodeGen (IR t))
    -> Operand (Ptr t)
    -> CodeGen ()
atomicCAS_rmw t update addr =
  case t of
    NonNumScalarType TypeBool{}                   -> casLoop word8
    NonNumScalarType TypeChar{}                   -> casLoop word32
    NonNumScalarType TypeCChar{}                  -> casLoop word8
    NonNumScalarType TypeCSChar{}                 -> casLoop word8
    NonNumScalarType TypeCUChar{}                 -> casLoop word8
    NumScalarType (FloatingNumType TypeFloat{})   -> casLoop word32
    NumScalarType (FloatingNumType TypeDouble{})  -> casLoop word64
    NumScalarType (FloatingNumType TypeCFloat{})  -> casLoop word32
    NumScalarType (FloatingNumType TypeCDouble{}) -> casLoop word64
    NumScalarType (IntegralNumType i)             -> casLoop i
  where
    -- Run the compare-and-swap loop at the chosen integral type
    casLoop :: IntegralType i -> CodeGen ()
    casLoop i = atomicCAS_rmw' t i update addr

    word8  = integralType :: IntegralType Word8
    word32 = integralType :: IntegralType Word32
    word64 = integralType :: IntegralType Word64

-- | The compare-and-swap loop underlying 'atomicCAS_rmw'.
--
-- The value at the address is read at the integral type @i@, reinterpreted
-- at type @t@ and passed to the update function. The result is reinterpreted
-- back to @i@ and offered to a compare-and-exchange against the value that
-- was read. If the exchange did not take place, the loop is re-entered using
-- the value returned by the exchange instruction.
--
atomicCAS_rmw'
    :: ScalarType t
    -> IntegralType i
    -> (IR t -> CodeGen (IR t))
    -> Operand (Ptr t)
    -> CodeGen ()
atomicCAS_rmw' t i update addr | EltDict <- integralElt i = do
  let bitsTy = NumScalarType (IntegralNumType i)

  loop      <- newBlock "rmw.spin"
  finish    <- newBlock "rmw.exit"

  -- View the target address as a pointer to the integral type and read the
  -- value currently stored there
  bitsAddr  <- instr' (PtrCast (PtrPrimType (ScalarPrimType bitsTy) defaultAddrSpace) addr)
  firstBits <- instr' (Load bitsTy NonVolatile bitsAddr)
  seenBits  <- fresh
  entry     <- br loop

  -- Compute the replacement from the value last seen and attempt the swap
  setBlock loop
  seen      <- instr' (BitCast t (op i seenBits))
  new       <- update (ir t seen)
  newBits   <- instr' (BitCast bitsTy (op t new))
  result    <- instr' (CmpXchg i NonVolatile bitsAddr (op i seenBits) newBits (CrossThread, AcquireRelease) Monotonic)
  swapped   <- instr' (ExtractValue scalarType ZeroTupIdx result)
  foundBits <- instr' (ExtractValue bitsTy (SuccTupIdx ZeroTupIdx) result)

  -- Leave the loop once the swap has happened; otherwise go around again
  -- with the value reported by the exchange
  retry     <- cbr (ir scalarType swapped) finish loop
  _         <- phi' loop seenBits [(ir i firstBits,entry), (ir i foundBits,retry)]

  setBlock finish

-- | Emit a conditional atomic replacement of the value stored at the given
-- address, built from a compare-and-swap loop (see 'atomicCAS_cmp''). The
-- supplied comparison decides, given the new value and the value currently
-- stored, whether the exchange should be attempted.
--
-- The compare-and-swap is carried out at an integral type. Integral element
-- types are used as they are; for the remaining scalar types an unsigned
-- word type is selected here and handed to 'atomicCAS_cmp''.
--
atomicCAS_cmp
    :: ScalarType t
    -> (ScalarType t -> IR t -> IR t -> CodeGen (IR Bool))
    -> Operand (Ptr t)
    -> Operand t
    -> CodeGen ()
atomicCAS_cmp t cmp addr val =
  case t of
    NonNumScalarType TypeBool{}                   -> casLoop word8
    NonNumScalarType TypeChar{}                   -> casLoop word32
    NonNumScalarType TypeCChar{}                  -> casLoop word8
    NonNumScalarType TypeCSChar{}                 -> casLoop word8
    NonNumScalarType TypeCUChar{}                 -> casLoop word8
    NumScalarType (FloatingNumType TypeFloat{})   -> casLoop word32
    NumScalarType (FloatingNumType TypeDouble{})  -> casLoop word64
    NumScalarType (FloatingNumType TypeCFloat{})  -> casLoop word32
    NumScalarType (FloatingNumType TypeCDouble{}) -> casLoop word64
    NumScalarType (IntegralNumType i)             -> casLoop i
  where
    -- Run the compare-and-swap loop at the chosen integral type
    casLoop :: IntegralType i -> CodeGen ()
    casLoop i = atomicCAS_cmp' t i cmp addr val

    word8  = integralType :: IntegralType Word8
    word32 = integralType :: IntegralType Word32
    word64 = integralType :: IntegralType Word64

-- | The compare-and-swap loop underlying 'atomicCAS_cmp'.
--
-- The new value is compared (via the supplied function) against the value
-- currently known to be at the address. If the comparison is false nothing
-- is written. Otherwise a compare-and-exchange at the integral type @i@ is
-- attempted; should it not take place, the comparison is repeated against
-- the value returned by the exchange instruction.
--
atomicCAS_cmp'
    :: ScalarType t
    -> IntegralType i
    -> (ScalarType t -> IR t -> IR t -> CodeGen (IR Bool))
    -> Operand (Ptr t)
    -> Operand t
    -> CodeGen ()
atomicCAS_cmp' t i cmp addr val | EltDict <- scalarElt t = do
  let bitsTy = NumScalarType (IntegralNumType i)

  check       <- newBlock "cas.cmp"
  attempt     <- newBlock "cas.retry"
  finish      <- newBlock "cas.exit"

  -- The address and the new value, both viewed at the integral type
  bitsAddr    <- instr' (PtrCast (PtrPrimType (ScalarPrimType bitsTy) defaultAddrSpace) addr)
  valBits     <- instr' (BitCast bitsTy val)
  current     <- fresh

  -- Read the value presently stored at the address
  initial     <- instr' (Load t NonVolatile addr)
  entry       <- br check

  -- Decide whether the new value should replace the current one; if the
  -- comparison is false there is nothing to do
  setBlock check
  wanted      <- cmp t (ir t val) current
  _           <- cbr wanted attempt finish

  -- Attempt the exchange, expecting the memory to still hold the value the
  -- comparison was made against
  setBlock attempt
  currentBits <- instr' (BitCast bitsTy (op t current))
  result      <- instr' (CmpXchg i NonVolatile bitsAddr currentBits valBits (CrossThread, AcquireRelease) Monotonic)
  swapped     <- instr' (ExtractValue scalarType ZeroTupIdx result)
  foundBits   <- instr' (ExtractValue bitsTy (SuccTupIdx ZeroTupIdx) result)
  found       <- instr' (BitCast t foundBits)

  -- Finish if the swap happened; otherwise compare again with the value
  -- reported by the exchange
  retry       <- cbr (ir scalarType swapped) finish check
  _           <- phi' check current [(ir t initial,entry), (ir t found,retry)]

  setBlock finish