{-# LANGUAGE CPP #-}
module Data.Array.Accelerate.Numeric.Sum.LLVM.PTX (
fadd, fsub, fmul,
) where
import Data.Array.Accelerate as A
import Data.Array.Accelerate.Type
#ifdef ACCELERATE_LLVM_PTX_BACKEND
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.PTX.Foreign as A
import qualified Data.Array.Accelerate.Numeric.Sum.LLVM.Prim as Prim
#endif
#ifdef ACCELERATE_LLVM_PTX_BACKEND
-- | Lift a binary Accelerate operation into a foreign scalar expression.
--
-- The PTX code generator works on a single tupled argument, so the curried
-- fallback is uncurried before being handed to 'foreignExp', and the
-- resulting unary expression is curried again for the caller.
wrap2 :: (Elt a, Elt b, Elt c)
      => String                       -- ^ name given to the foreign function
      -> IRFun1 PTX () ((a, b) -> c)  -- ^ PTX code generator (tupled argument)
      -> (Exp a -> Exp b -> Exp c)    -- ^ curried Accelerate fallback
      -> Exp a
      -> Exp b
      -> Exp c
wrap2 name ptxImpl fallback =
  A.curry $ foreignExp foreignFn $ A.uncurry fallback
  where
    foreignFn = ForeignExp name ptxImpl
#endif
-- | Floating-point addition.
--
-- When built with @ACCELERATE_LLVM_PTX_BACKEND@, the operation becomes a
-- foreign expression named @"fadd"@ whose code comes from 'Prim.fadd'. The
-- supplied function is kept as the 'foreignExp' fallback. Without that flag,
-- the supplied function is returned unchanged.
fadd :: (IsFloating a, Elt a)
     => (Exp a -> Exp a -> Exp a)  -- ^ fallback implementation
     -> Exp a
     -> Exp a
     -> Exp a
#ifdef ACCELERATE_LLVM_PTX_BACKEND
fadd fallback = wrap2 "fadd" (Prim.fadd floatingType) fallback
#else
fadd fallback = fallback
#endif
-- | Floating-point subtraction.
--
-- When built with @ACCELERATE_LLVM_PTX_BACKEND@, the operation becomes a
-- foreign expression named @"fsub"@ whose code comes from 'Prim.fsub'. The
-- supplied function is kept as the 'foreignExp' fallback. Without that flag,
-- the supplied function is returned unchanged.
fsub :: (IsFloating a, Elt a)
     => (Exp a -> Exp a -> Exp a)  -- ^ fallback implementation
     -> Exp a
     -> Exp a
     -> Exp a
#ifdef ACCELERATE_LLVM_PTX_BACKEND
fsub fallback = wrap2 "fsub" (Prim.fsub floatingType) fallback
#else
fsub fallback = fallback
#endif
-- | Floating-point multiplication.
--
-- When built with @ACCELERATE_LLVM_PTX_BACKEND@, the operation becomes a
-- foreign expression named @"fmul"@ whose code comes from 'Prim.fmul'. The
-- supplied function is kept as the 'foreignExp' fallback. Without that flag,
-- the supplied function is returned unchanged.
fmul :: (IsFloating a, Elt a)
     => (Exp a -> Exp a -> Exp a)  -- ^ fallback implementation
     -> Exp a
     -> Exp a
     -> Exp a
#ifdef ACCELERATE_LLVM_PTX_BACKEND
fmul fallback = wrap2 "fmul" (Prim.fmul floatingType) fallback
#else
fmul fallback = fallback
#endif