{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.Numeric.Sum.LLVM.Prim (
fadd, fsub, fmul,
) where
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.LLVM.CodeGen.Downcast ( downcast )
import Data.Array.Accelerate.LLVM.CodeGen.IR ( IR(..), Operands(..), IROP(..) )
import Data.Array.Accelerate.LLVM.CodeGen.Monad ( CodeGen, freshName, instr_ )
import Data.Array.Accelerate.LLVM.CodeGen.Sugar ( IROpenFun1(..) )
import qualified Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import qualified LLVM.AST.Type.Name as A
import qualified LLVM.AST.Type.Operand as A
import qualified LLVM.AST.Type.Representation as A
import LLVM.AST.Instruction
import LLVM.AST.Name
import LLVM.AST.Operand
import LLVM.AST.Type
-- | Floating-point addition of a pair of operands, lowered to a single LLVM
-- @fadd@ instruction by 'binop'.
--
-- The instruction is emitted with the flags in 'fmf', where no fast-math flag
-- is enabled. LLVM is therefore not allowed to relax its IEEE semantics, for
-- example by reassociating it.
-- NOTE(review): presumably this strictness is the reason this module exists
-- (e.g. compensated summation); confirm against its callers.
fadd :: FloatingType a -> IROpenFun1 arch env aenv ((a,a) -> a)
fadd t = IRFun1 (A.uncurry add)
  where
    add = binop FAdd t
-- | Floating-point subtraction (first component minus second) of a pair of
-- operands, lowered to a single LLVM @fsub@ instruction by 'binop'.
--
-- As with the other primitives here, no fast-math flag is set (see 'fmf').
-- LLVM therefore keeps strict IEEE semantics for this operation.
fsub :: FloatingType a -> IROpenFun1 arch env aenv ((a,a) -> a)
fsub t = IRFun1 (A.uncurry sub)
  where
    sub = binop FSub t
-- | Floating-point multiplication of a pair of operands, lowered to a single
-- LLVM @fmul@ instruction by 'binop'.
--
-- No fast-math flag is set (see 'fmf'), so LLVM keeps strict IEEE semantics
-- for this operation.
fmul :: FloatingType a -> IROpenFun1 arch env aenv ((a,a) -> a)
fmul t = IRFun1 (A.uncurry mul)
  where
    mul = binop FMul t
-- | Emit one binary floating-point LLVM instruction and return its result.
--
-- The function proceeds in four steps:
--
-- * Both operands are converted with 'op' and then 'downcast' to untyped
--   llvm-hs operands.
-- * The instruction built by @f@ is applied to them, together with the
--   fast-math flags 'fmf' and the metadata 'md'.
-- * The instruction is bound to a fresh local by 'instr'.
-- * The resulting reference is lifted back to 'IR' with 'upcast'.
binop
  :: (FastMathFlags -> Operand -> Operand -> InstructionMetadata -> Instruction)
     -- ^ llvm-hs instruction constructor, e.g. 'FAdd'
  -> FloatingType a
     -- ^ floating-point type of both operands and of the result
  -> IR a
     -- ^ left operand
  -> IR a
     -- ^ right operand
  -> CodeGen (IR a)
binop f t lhs rhs =
  let x = downcast (op t lhs)
      y = downcast (op t rhs)
  in  upcast t <$> instr (downcast t) (f fmf x y md)
-- | Metadata attached to every instruction emitted by 'binop'. It is empty,
-- so no metadata is attached.
md :: InstructionMetadata
md = mempty
-- | Fast-math flags attached to every instruction emitted by 'binop'.
--
-- No flag is enabled. The spelling of the empty flag set changed in
-- llvm-hs-pure 6.0.0: newer versions export the value 'noFastMathFlags', and
-- older ones provide the constructor @NoFastMathFlags@. CPP selects the
-- right one.
fmf :: FastMathFlags
fmf =
#if MIN_VERSION_llvm_hs_pure(6,0,0)
  noFastMathFlags
#else
  NoFastMathFlags
#endif
-- | Obtain a fresh name from 'freshName' and convert it with 'downcast' to the
-- untyped llvm-hs 'Name' form used by the instructions built in this module.
fresh :: CodeGen Name
fresh = fmap downcast freshName
-- | Bind an instruction to a fresh local name and return a reference to it.
--
-- The named instruction @v := ins@ is emitted through 'instr_'. The result is
-- a 'LocalReference' to that same name, annotated with the caller-supplied
-- type @ty@; the type is not checked against the instruction.
instr :: Type -> Instruction -> CodeGen Operand
instr ty ins =
  fresh >>= \v -> LocalReference ty v <$ instr_ (v := ins)
-- | Lift a raw llvm-hs operand back to a typed accelerate 'IR' value.
--
-- The operand must be a 'LocalReference' to an unnamed ('UnName') local whose
-- LLVM type agrees with the requested 'FloatingType':
--
-- * @Float@ and @CFloat@ pair with 'FloatFP'.
-- * @Double@ and @CDouble@ pair with 'DoubleFP'.
--
-- Any other operand, including a type mismatch, raises an internal error.
upcast :: FloatingType t -> Operand -> IR t
upcast t operand =
  case (t, operand) of
    (TypeFloat{},   LocalReference (FloatingPointType FloatFP)  (UnName n)) ->
      IR (OP_Float   (A.LocalReference A.type' (A.UnName n)))
    (TypeDouble{},  LocalReference (FloatingPointType DoubleFP) (UnName n)) ->
      IR (OP_Double  (A.LocalReference A.type' (A.UnName n)))
    (TypeCFloat{},  LocalReference (FloatingPointType FloatFP)  (UnName n)) ->
      IR (OP_CFloat  (A.LocalReference A.type' (A.UnName n)))
    (TypeCDouble{}, LocalReference (FloatingPointType DoubleFP) (UnName n)) ->
      IR (OP_CDouble (A.LocalReference A.type' (A.UnName n)))
    _ ->
      $internalError "upcast" "expected local reference"