{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE Trustworthy #-}
module Futhark.CodeGen.Backends.SimpleRep
( tupleField
, funName
, defaultMemBlockType
, primTypeToCType
, signedPrimTypeToCType
, cIntOps
, cFloat32Ops, cFloat32Funs
, cFloat64Ops, cFloat64Funs
, cFloatConvOps
)
where
import qualified Language.C.Syntax as C
import qualified Language.C.Quote.C as C
import Futhark.CodeGen.ImpCode
import Futhark.Util.Pretty (prettyOneLine)
import Futhark.Util (zEncodeString)
-- | The signed C integer type (@intN_t@) with the same width as the
-- given integer type.
intTypeToCType :: IntType -> C.Type
intTypeToCType it =
  case it of
    Int8 -> [C.cty|typename int8_t|]
    Int16 -> [C.cty|typename int16_t|]
    Int32 -> [C.cty|typename int32_t|]
    Int64 -> [C.cty|typename int64_t|]
-- | The unsigned C integer type (@uintN_t@) with the same width as the
-- given integer type.
uintTypeToCType :: IntType -> C.Type
uintTypeToCType it =
  case it of
    Int8 -> [C.cty|typename uint8_t|]
    Int16 -> [C.cty|typename uint16_t|]
    Int32 -> [C.cty|typename uint32_t|]
    Int64 -> [C.cty|typename uint64_t|]
-- | The C floating-point type for a Futhark float type: @float@ for
-- 'Float32' and @double@ for 'Float64'.
floatTypeToCType :: FloatType -> C.Type
floatTypeToCType ft =
  case ft of
    Float32 -> [C.cty|float|]
    Float64 -> [C.cty|double|]
-- | The C type corresponding to a primitive type.
--
-- Integer types map to the signed C types (see 'intTypeToCType'), and
-- both 'Bool' and 'Cert' are represented as C @bool@.
primTypeToCType :: PrimType -> C.Type
primTypeToCType pt =
  case pt of
    IntType it -> intTypeToCType it
    FloatType ft -> floatTypeToCType ft
    Bool -> [C.cty|typename bool|]
    Cert -> [C.cty|typename bool|]
-- | Like 'primTypeToCType', but integer types honour the requested
-- signedness: 'TypeUnsigned' selects the @uintN_t@ type and
-- 'TypeDirect' the @intN_t@ type.  Every other combination falls back
-- to 'primTypeToCType'.
signedPrimTypeToCType :: Signedness -> PrimType -> C.Type
signedPrimTypeToCType signedness pt =
  case (signedness, pt) of
    (TypeUnsigned, IntType it) -> uintTypeToCType it
    (TypeDirect, IntType it) -> intTypeToCType it
    _ -> primTypeToCType pt
-- | @tupleField i@ is the C field name used for component @i@ of a
-- tuple: the letter @v@ followed by the decimal index.
tupleField :: Int -> String
tupleField i = 'v' : show i
-- | @funName f@ is the C identifier of the function generated for the
-- Futhark function @f@: the prefix @futrts_@ followed by the
-- z-encoded name.
funName :: Name -> String
funName fname = "futrts_" ++ zEncodeString (nameToString fname)
-- | 'funName' for a name given as a plain 'String'.
funName' :: String -> String
funName' s = funName (nameFromString s)
-- | The C type of a memory block in the default memory space: a plain
-- @char@ pointer.
defaultMemBlockType :: C.Type
defaultMemBlockType = [C.cty|char *|]
-- | C definitions of every integer primitive operation, instantiated at
-- each integer width, followed by 'cIntPrimFuns'.
--
-- Each generator below is applied to every 'IntType' in turn
-- (generator-major order).  Binary operations are named by appending
-- the bit width to a base name (e.g. @add8@, @sdiv64@).  Sign and zero
-- extensions are emitted as preprocessor macros named like
-- @sext_i8_i32@ / @zext_i8_i32@.
cIntOps :: [C.Definition]
cIntOps = [ gen t | gen <- generators, t <- allIntTypes ] ++ cIntPrimFuns
  where
    allIntTypes :: [IntType]
    allIntTypes = [minBound .. maxBound]

    -- The order of this list is the order of the emitted C definitions;
    -- it matters because later functions (e.g. sdiv_up, smod_safe) call
    -- earlier ones (sdiv, smod).
    generators :: [IntType -> C.Definition]
    generators =
      -- Add/sub/mul go through the unsigned C types, presumably so that
      -- overflow wraps instead of being signed overflow.
      -- NOTE(review): C still promotes 8/16-bit unsigned operands to int,
      -- so e.g. a 16-bit mul can overflow int — confirm this is acceptable.
      [ uintOp "add" [C.cexp|x + y|]
      , uintOp "sub" [C.cexp|x - y|]
      , uintOp "mul" [C.cexp|x * y|]
      , uintOp "udiv" [C.cexp|x / y|]
      , uintOp "udiv_up" [C.cexp|(x+y-1) / y|]
      , uintOp "umod" [C.cexp|x % y|]
      , uintOp "udiv_safe" [C.cexp|y == 0 ? 0 : x / y|]
      , uintOp "udiv_up_safe" [C.cexp|y == 0 ? 0 : (x+y-1) / y|]
      , uintOp "umod_safe" [C.cexp|y == 0 ? 0 : x % y|]
      , mkSDiv
        -- NOTE(review): sdiv(x+y-1, y) rounds up only for positive y;
        -- for y < 0 it is not the ceiling — confirm callers.
      , \t -> intOp "sdiv_up" [C.cexp|$id:(taggedI "sdiv" t)(x+y-1,y)|] t
      , mkSMod
      , \t -> intOp "sdiv_safe" [C.cexp|y == 0 ? 0 : $id:(taggedI "sdiv" t)(x,y)|] t
      , \t -> intOp "sdiv_up_safe" [C.cexp|$id:(taggedI "sdiv_safe" t)(x+y-1,y)|] t
      , \t -> intOp "smod_safe" [C.cexp|y == 0 ? 0 : $id:(taggedI "smod" t)(x,y)|] t
      , intOp "squot" [C.cexp|x / y|]
      , intOp "srem" [C.cexp|x % y|]
      , intOp "squot_safe" [C.cexp|y == 0 ? 0 : x / y|]
      , intOp "srem_safe" [C.cexp|y == 0 ? 0 : x % y|]
      , intOp "smin" [C.cexp|x < y ? x : y|]
      , uintOp "umin" [C.cexp|x < y ? x : y|]
      , intOp "smax" [C.cexp|x < y ? y : x|]
      , uintOp "umax" [C.cexp|x < y ? y : x|]
      , uintOp "shl" [C.cexp|x << y|]
      , uintOp "lshr" [C.cexp|x >> y|]
      , intOp "ashr" [C.cexp|x >> y|]
      , uintOp "and" [C.cexp|x & y|]
      , uintOp "or" [C.cexp|x | y|]
      , uintOp "xor" [C.cexp|x ^ y|]
      , uintCmp "ult" [C.cexp|x < y|]
      , uintCmp "ule" [C.cexp|x <= y|]
      , intCmp "slt" [C.cexp|x < y|]
      , intCmp "sle" [C.cexp|x <= y|]
      , mkPow
      , mkIToB
      , mkBToI
      ]
        ++ map mkSExt allIntTypes
        ++ map mkZExt allIntTypes

    -- Append the bit width of an integer type to a base name.
    taggedI :: String -> IntType -> String
    taggedI s t = s ++ width
      where
        width = case t of
          Int8 -> "8"
          Int16 -> "16"
          Int32 -> "32"
          Int64 -> "64"

    -- A two-argument C function over the given C type whose body is
    -- the single expression @e@ in terms of @x@ and @y@.
    binFun :: (IntType -> C.Type) -> String -> C.Exp -> IntType -> C.Definition
    binFun toCType s e t =
      [C.cedecl|static inline $ty:ct $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = toCType t

    -- As 'binFun', but the function returns a C bool.
    cmpFun :: (IntType -> C.Type) -> String -> C.Exp -> IntType -> C.Definition
    cmpFun toCType s e t =
      [C.cedecl|static inline typename bool $id:(taggedI s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = toCType t

    uintOp, intOp, uintCmp, intCmp :: String -> C.Exp -> IntType -> C.Definition
    uintOp = binFun uintTypeToCType
    intOp = binFun intTypeToCType
    uintCmp = cmpFun uintTypeToCType
    intCmp = cmpFun intTypeToCType

    -- Floor division: C's '/' truncates, so subtract one when there is a
    -- nonzero remainder whose sign differs from the divisor's.
    mkSDiv :: IntType -> C.Definition
    mkSDiv t =
      [C.cedecl|static inline $ty:ct $id:(taggedI "sdiv" t)($ty:ct x, $ty:ct y) {
                  $ty:ct q = x / y;
                  $ty:ct r = x % y;
                  return q -
                    (((r != 0) && ((r < 0) != (y < 0))) ? 1 : 0);
                }|]
      where
        ct = intTypeToCType t

    -- Modulo whose result takes the divisor's sign: add y when the C
    -- remainder is nonzero and x and y do not share a sign.
    mkSMod :: IntType -> C.Definition
    mkSMod t =
      [C.cedecl|static inline $ty:ct $id:(taggedI "smod" t)($ty:ct x, $ty:ct y) {
                  $ty:ct r = x % y;
                  return r +
                    ((r == 0 || (x > 0 && y > 0) || (x < 0 && y < 0)) ? 0 : y);
                }|]
      where
        ct = intTypeToCType t

    -- Integer power by repeated squaring over the bits of y.
    -- NOTE(review): for a negative y, `rem >>= 1` on a signed value is
    -- typically an arithmetic shift, so rem never reaches 0 and the loop
    -- does not terminate — confirm negative exponents cannot occur.
    mkPow :: IntType -> C.Definition
    mkPow t =
      [C.cedecl|static inline $ty:ct $id:(taggedI "pow" t)($ty:ct x, $ty:ct y) {
                  $ty:ct res = 1, rem = y;
                  while (rem != 0) {
                    if (rem & 1) {
                      res *= x;
                    }
                    rem >>= 1;
                    x *= x;
                  }
                  return res;
                }|]
      where
        ct = intTypeToCType t

    -- A one-argument function-like preprocessor macro.
    macro :: String -> C.Exp -> C.Definition
    macro name rhs =
      [C.cedecl|$esc:(concat ["#define ", name, "(x) (", prettyOneLine rhs, ")"])|]

    -- A width-changing cast macro named @<prefix>_<from>_<to>@ that casts
    -- through the C types chosen by @toCType@.
    extension :: String -> (IntType -> C.Type) -> IntType -> IntType -> C.Definition
    extension prefix toCType from_t to_t =
      macro name [C.cexp|($ty:to_ct)(($ty:from_ct)x)|]
      where
        name = concat [prefix, "_", pretty from_t, "_", pretty to_t]
        from_ct = toCType from_t
        to_ct = toCType to_t

    -- Sign extension casts through the signed types, zero extension
    -- through the unsigned ones.
    mkSExt, mkZExt :: IntType -> IntType -> C.Definition
    mkSExt = extension "sext" intTypeToCType
    mkZExt = extension "zext" uintTypeToCType

    mkBToI :: IntType -> C.Definition
    mkBToI to_t =
      [C.cedecl|static inline $ty:to_ct
                $id:name($ty:from_ct x) { return x; } |]
      where
        name = "btoi_bool_" ++ pretty to_t
        from_ct = primTypeToCType Bool
        to_ct = intTypeToCType to_t

    mkIToB :: IntType -> C.Definition
    mkIToB from_t =
      [C.cedecl|static inline $ty:to_ct
                $id:name($ty:from_ct x) { return x; } |]
      where
        name = "itob_" ++ pretty from_t ++ "_bool"
        to_ct = primTypeToCType Bool
        from_ct = intTypeToCType from_t
-- | Integer helper functions for all four widths:
--
--   * population count (@popcN@),
--   * the high half of an unsigned product (@mul_hiN@),
--   * high half of a product plus an addend (@mad_hiN@),
--   * count of leading zero bits (@clzN@).
--
-- Each group is emitted once per preprocessor branch: OpenCL builtins,
-- CUDA intrinsics, and a portable C fallback.  The CUDA popc/clz
-- versions use the @zext_*@ macros that 'cIntOps' emits before this
-- list.
--
-- The CUDA high-multiply functions use the documented unsigned
-- intrinsics @__umulhi@ / @__umul64hi@.
--
-- The portable popc/clz fallbacks work on the unsigned reinterpretation
-- of the argument.  This avoids signed overflow in @x - 1@ and signed
-- left shifts into the sign bit, both undefined behaviour in C.  The
-- results are unchanged wherever the old code was defined.
--
-- NOTE(review): the portable @mul_hi64@ relies on @__uint128_t@, a
-- GCC/Clang extension — confirm all targeted host compilers provide it.
cIntPrimFuns :: [C.Definition]
cIntPrimFuns =
  [C.cunit|
$esc:("#if defined(__OPENCL_VERSION__)")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
      return popcount(x);
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
      return popcount(x);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
      return __popc(zext_i8_i32(x));
   }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
      return __popc(zext_i16_i32(x));
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
      return __popc(x);
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
      return __popcll(x);
   }
$esc:("#else")
   static typename int32_t $id:(funName' "popc8") (typename int8_t x) {
     typename uint8_t y = x;
     int c = 0;
     for (; y; ++c) {
       y &= y - 1;
     }
     return c;
   }
   static typename int32_t $id:(funName' "popc16") (typename int16_t x) {
     typename uint16_t y = x;
     int c = 0;
     for (; y; ++c) {
       y &= y - 1;
     }
     return c;
   }
   static typename int32_t $id:(funName' "popc32") (typename int32_t x) {
     typename uint32_t y = x;
     int c = 0;
     for (; y; ++c) {
       y &= y - 1;
     }
     return c;
   }
   static typename int32_t $id:(funName' "popc64") (typename int64_t x) {
     typename uint64_t y = x;
     int c = 0;
     for (; y; ++c) {
       y &= y - 1;
     }
     return c;
   }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
      return mul_hi(a, b);
   }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
      return mul_hi(a, b);
   }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
      return mul_hi(a, b);
   }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
      return mul_hi(a, b);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
     typename uint16_t aa = a;
     typename uint16_t bb = b;
     return (aa * bb) >> 8;
   }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
     typename uint32_t aa = a;
     typename uint32_t bb = b;
     return (aa * bb) >> 16;
   }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
      return __umulhi(a, b);
   }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
      return __umul64hi(a, b);
   }
$esc:("#else")
   static typename uint8_t $id:(funName' "mul_hi8") (typename uint8_t a, typename uint8_t b) {
     typename uint16_t aa = a;
     typename uint16_t bb = b;
     return (aa * bb) >> 8;
   }
   static typename uint16_t $id:(funName' "mul_hi16") (typename uint16_t a, typename uint16_t b) {
     typename uint32_t aa = a;
     typename uint32_t bb = b;
     return (aa * bb) >> 16;
   }
   static typename uint32_t $id:(funName' "mul_hi32") (typename uint32_t a, typename uint32_t b) {
     typename uint64_t aa = a;
     typename uint64_t bb = b;
     return (aa * bb) >> 32;
   }
   static typename uint64_t $id:(funName' "mul_hi64") (typename uint64_t a, typename uint64_t b) {
     typename __uint128_t aa = a;
     typename __uint128_t bb = b;
     return (aa * bb) >> 64;
   }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename uint8_t $id:(funName' "mad_hi8") (typename uint8_t a, typename uint8_t b, typename uint8_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint16_t $id:(funName' "mad_hi16") (typename uint16_t a, typename uint16_t b, typename uint16_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint32_t $id:(funName' "mad_hi32") (typename uint32_t a, typename uint32_t b, typename uint32_t c) {
      return mad_hi(a, b, c);
   }
   static typename uint64_t $id:(funName' "mad_hi64") (typename uint64_t a, typename uint64_t b, typename uint64_t c) {
      return mad_hi(a, b, c);
   }
$esc:("#else")
   static typename uint8_t $id:(funName' "mad_hi8") (typename uint8_t a, typename uint8_t b, typename uint8_t c) {
     return $id:(funName' "mul_hi8")(a, b) + c;
   }
   static typename uint16_t $id:(funName' "mad_hi16") (typename uint16_t a, typename uint16_t b, typename uint16_t c) {
     return $id:(funName' "mul_hi16")(a, b) + c;
   }
   static typename uint32_t $id:(funName' "mad_hi32") (typename uint32_t a, typename uint32_t b, typename uint32_t c) {
     return $id:(funName' "mul_hi32")(a, b) + c;
   }
   static typename uint64_t $id:(funName' "mad_hi64") (typename uint64_t a, typename uint64_t b, typename uint64_t c) {
     return $id:(funName' "mul_hi64")(a, b) + c;
   }
$esc:("#endif")

$esc:("#if defined(__OPENCL_VERSION__)")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
      return clz(x);
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
      return clz(x);
   }
$esc:("#elif defined(__CUDA_ARCH__)")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
      return __clz(zext_i8_i32(x))-24;
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
      return __clz(zext_i16_i32(x))-16;
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
      return __clz(x);
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
      return __clzll(x);
   }
$esc:("#else")
   static typename int32_t $id:(funName' "clz8") (typename int8_t x) {
     typename uint8_t y = x;
     int n = 0;
     int bits = sizeof(y) * 8;
     for (int i = 0; i < bits; i++) {
       if (y >> (bits - 1)) break;
       n++;
       y <<= 1;
     }
     return n;
   }
   static typename int32_t $id:(funName' "clz16") (typename int16_t x) {
     typename uint16_t y = x;
     int n = 0;
     int bits = sizeof(y) * 8;
     for (int i = 0; i < bits; i++) {
       if (y >> (bits - 1)) break;
       n++;
       y <<= 1;
     }
     return n;
   }
   static typename int32_t $id:(funName' "clz32") (typename int32_t x) {
     typename uint32_t y = x;
     int n = 0;
     int bits = sizeof(y) * 8;
     for (int i = 0; i < bits; i++) {
       if (y >> (bits - 1)) break;
       n++;
       y <<= 1;
     }
     return n;
   }
   static typename int32_t $id:(funName' "clz64") (typename int64_t x) {
     typename uint64_t y = x;
     int n = 0;
     int bits = sizeof(y) * 8;
     for (int i = 0; i < bits; i++) {
       if (y >> (bits - 1)) break;
       n++;
       y <<= 1;
     }
     return n;
   }
$esc:("#endif")
|]
-- | Single-precision arithmetic, comparison, power and int<->float
-- conversion functions.
cFloat32Ops :: [C.Definition]
-- | The same set of functions as 'cFloat32Ops', for double precision.
cFloat64Ops :: [C.Definition]
-- | Conversions between every pair of float types (@fpconv_<from>_<to>@).
cFloatConvOps :: [C.Definition]
(cFloat32Ops, cFloat64Ops, cFloatConvOps) =
  ( opsAt Float32
  , opsAt Float64
  , [ mkFPConvFF "fpconv" from to | from <- allFloatTypes, to <- allFloatTypes ]
  )
  where
    allFloatTypes :: [FloatType]
    allFloatTypes = [minBound .. maxBound]

    allIntTypes :: [IntType]
    allIntTypes = [minBound .. maxBound]

    -- Instantiate every per-float-type generator at one float type.
    opsAt :: FloatType -> [C.Definition]
    opsAt ft = [ gen ft | gen <- mkOps ]

    -- Append the bit width of a float type to a base name.
    taggedF :: String -> FloatType -> String
    taggedF s ft = s ++ width
      where
        width = case ft of
          Float32 -> "32"
          Float64 -> "64"

    -- Generators applied at each float type, in emission order.  The
    -- int conversions are instantiated for every integer type: the
    -- signed/unsigned int is the source for sitofp/uitofp and the
    -- target for fptosi/fptoui.
    mkOps :: [FloatType -> C.Definition]
    mkOps =
      [ floatOp "fdiv" [C.cexp|x / y|]
      , floatOp "fadd" [C.cexp|x + y|]
      , floatOp "fsub" [C.cexp|x - y|]
      , floatOp "fmul" [C.cexp|x * y|]
      , floatOp "fmin" [C.cexp|fmin(x, y)|]
      , floatOp "fmax" [C.cexp|fmax(x, y)|]
      , mkFPow
      , floatCmp "cmplt" [C.cexp|x < y|]
      , floatCmp "cmple" [C.cexp|x <= y|]
      ]
        ++ [ mkFPConvIF "sitofp" it | it <- allIntTypes ]
        ++ [ mkFPConvUF "uitofp" it | it <- allIntTypes ]
        ++ [ \ft -> mkFPConvFI "fptosi" ft it | it <- allIntTypes ]
        ++ [ \ft -> mkFPConvFU "fptoui" ft it | it <- allIntTypes ]

    mkFPow :: FloatType -> C.Definition
    mkFPow Float32 =
      [C.cedecl|static inline float fpow32(float x, float y) { return pow(x, y); }|]
    mkFPow Float64 =
      [C.cedecl|static inline double fpow64(double x, double y) { return pow(x, y); }|]

    -- A one-argument conversion named @<s>_<from>_<to>@ whose body is a
    -- plain C cast; @from_f@ / @to_f@ pick the C types of each side.
    mkFPConv from_f to_f s from_t to_t =
      [C.cedecl|static inline $ty:to_ct
                $id:name($ty:from_ct x) { return ($ty:to_ct)x;} |]
      where
        name = concat [s, "_", pretty from_t, "_", pretty to_t]
        from_ct = from_f from_t
        to_ct = to_f to_t

    mkFPConvFF :: String -> FloatType -> FloatType -> C.Definition
    mkFPConvFF = mkFPConv floatTypeToCType floatTypeToCType

    mkFPConvFI :: String -> FloatType -> IntType -> C.Definition
    mkFPConvFI = mkFPConv floatTypeToCType intTypeToCType

    mkFPConvIF :: String -> IntType -> FloatType -> C.Definition
    mkFPConvIF = mkFPConv intTypeToCType floatTypeToCType

    mkFPConvFU :: String -> FloatType -> IntType -> C.Definition
    mkFPConvFU = mkFPConv floatTypeToCType uintTypeToCType

    mkFPConvUF :: String -> IntType -> FloatType -> C.Definition
    mkFPConvUF = mkFPConv uintTypeToCType floatTypeToCType

    -- A two-argument float function whose body is the expression @e@.
    floatOp :: String -> C.Exp -> FloatType -> C.Definition
    floatOp s e t =
      [C.cedecl|static inline $ty:ct $id:(taggedF s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = floatTypeToCType t

    -- As 'floatOp', but returning a C bool.
    floatCmp :: String -> C.Exp -> FloatType -> C.Definition
    floatCmp s e t =
      [C.cedecl|static inline typename bool $id:(taggedF s t)($ty:ct x, $ty:ct y) { return $exp:e; }|]
      where
        ct = floatTypeToCType t
-- | C definitions of the single-precision (@float@) math helpers called by
-- generated code.
--
-- Almost every helper is named through @funName'@ and is a thin wrapper
-- around a C math function.
--
-- The helpers whose implementation must differ between OpenCL and plain C
-- are emitted twice, inside an @#ifdef __OPENCL_VERSION__@ / @#else@ /
-- @#endif@ guard that is spliced in verbatim with @$esc@.  These are
-- @fmod32@, @round32@, @floor32@, @ceil32@, @lerp32@, @mad32@ and @fma32@.
--
-- * The OpenCL branch uses the overloaded OpenCL built-ins (@fmod@, @rint@,
--   @floor@, @ceil@, @mix@, @mad@, @fma@).
-- * The plain-C branch uses the @f@-suffixed C99 functions (@fmodf@,
--   @rintf@, @floorf@, @ceilf@, @fmaf@) or an explicit arithmetic expression
--   for @lerp32@ and @mad32@.
--
-- Notes:
--
-- * The unconditional wrappers (@log32@, @sqrt32@, @sin32@, ...) call the
--   unsuffixed C names.  In plain C, without a type-generic math header,
--   those functions take and return @double@.  The argument is therefore
--   widened, the computation is done in double precision, and the result is
--   narrowed back to @float@.
--
-- * @to_bits32@ and @from_bits32@ reinterpret the 32 bits of their argument
--   through a union.
--
-- * @fmod32@ is emitted under its literal name rather than through
--   @funName'@.  NOTE(review): presumably its call sites refer to it by that
--   exact name; confirm before renaming.
cFloat32Funs :: [C.Definition]
cFloat32Funs = [C.cunit|
    static inline float $id:(funName' "log32")(float x) {
      return log(x);
    }

    static inline float $id:(funName' "log2_32")(float x) {
      return log2(x);
    }

    static inline float $id:(funName' "log10_32")(float x) {
      return log10(x);
    }

    static inline float $id:(funName' "sqrt32")(float x) {
      return sqrt(x);
    }

    static inline float $id:(funName' "exp32")(float x) {
      return exp(x);
    }

    static inline float $id:(funName' "cos32")(float x) {
      return cos(x);
    }

    static inline float $id:(funName' "sin32")(float x) {
      return sin(x);
    }

    static inline float $id:(funName' "tan32")(float x) {
      return tan(x);
    }

    static inline float $id:(funName' "acos32")(float x) {
      return acos(x);
    }

    static inline float $id:(funName' "asin32")(float x) {
      return asin(x);
    }

    static inline float $id:(funName' "atan32")(float x) {
      return atan(x);
    }

    static inline float $id:(funName' "cosh32")(float x) {
      return cosh(x);
    }

    static inline float $id:(funName' "sinh32")(float x) {
      return sinh(x);
    }

    static inline float $id:(funName' "tanh32")(float x) {
      return tanh(x);
    }

    static inline float $id:(funName' "acosh32")(float x) {
      return acosh(x);
    }

    static inline float $id:(funName' "asinh32")(float x) {
      return asinh(x);
    }

    static inline float $id:(funName' "atanh32")(float x) {
      return atanh(x);
    }

    static inline float $id:(funName' "atan2_32")(float x, float y) {
      return atan2(x,y);
    }

    static inline float $id:(funName' "gamma32")(float x) {
      return tgamma(x);
    }

    static inline float $id:(funName' "lgamma32")(float x) {
      return lgamma(x);
    }

    static inline typename bool $id:(funName' "isnan32")(float x) {
      return isnan(x);
    }

    static inline typename bool $id:(funName' "isinf32")(float x) {
      return isinf(x);
    }

    static inline typename int32_t $id:(funName' "to_bits32")(float x) {
      union {
        float f;
        typename int32_t t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline float $id:(funName' "from_bits32")(typename int32_t x) {
      union {
        typename int32_t f;
        float t;
      } p;
      p.f = x;
      return p.t;
    }

    $esc:("#ifdef __OPENCL_VERSION__")
    static inline float fmod32(float x, float y) {
      return fmod(x, y);
    }

    static inline float $id:(funName' "round32")(float x) {
      return rint(x);
    }

    static inline float $id:(funName' "floor32")(float x) {
      return floor(x);
    }

    static inline float $id:(funName' "ceil32")(float x) {
      return ceil(x);
    }

    static inline float $id:(funName' "lerp32")(float v0, float v1, float t) {
      return mix(v0, v1, t);
    }

    static inline float $id:(funName' "mad32")(float a, float b, float c) {
      return mad(a,b,c);
    }

    static inline float $id:(funName' "fma32")(float a, float b, float c) {
      return fma(a,b,c);
    }
    $esc:("#else")
    static inline float fmod32(float x, float y) {
      return fmodf(x, y);
    }

    static inline float $id:(funName' "round32")(float x) {
      return rintf(x);
    }

    static inline float $id:(funName' "floor32")(float x) {
      return floorf(x);
    }

    static inline float $id:(funName' "ceil32")(float x) {
      return ceilf(x);
    }

    static inline float $id:(funName' "lerp32")(float v0, float v1, float t) {
      return v0 + (v1-v0)*t;
    }

    static inline float $id:(funName' "mad32")(float a, float b, float c) {
      return a*b+c;
    }

    static inline float $id:(funName' "fma32")(float a, float b, float c) {
      return fmaf(a,b,c);
    }
    $esc:("#endif")
  |]
-- | C definitions of the double-precision (@double@) math helpers called by
-- generated code.
--
-- Almost every helper is named through @funName'@ and forwards to the
-- corresponding C math function (@log@, @sqrt@, @tgamma@, @rint@, @fma@,
-- ...).  Those unsuffixed names already operate on @double@.
-- NOTE(review): presumably that is why @round64@, @ceil64@, @floor64@,
-- @fma64@ and @fmod64@ need no OpenCL/C split, unlike their 32-bit
-- counterparts.
--
-- Only @lerp64@ and @mad64@ differ per target.  Inside the
-- @#ifdef __OPENCL_VERSION__@ branch (spliced in verbatim with @$esc@) they
-- use the OpenCL built-ins @mix@ and @mad@.  Otherwise they use the plain
-- arithmetic expressions @v0 + (v1-v0)*t@ and @a*b+c@.
--
-- @to_bits64@ and @from_bits64@ reinterpret the 64 bits of their argument
-- through a union.
--
-- @fmod64@ is emitted under its literal name rather than through
-- @funName'@.  NOTE(review): presumably its call sites refer to it by that
-- exact name; confirm before renaming.
cFloat64Funs :: [C.Definition]
cFloat64Funs = [C.cunit|
    static inline double $id:(funName' "log64")(double x) {
      return log(x);
    }

    static inline double $id:(funName' "log2_64")(double x) {
      return log2(x);
    }

    static inline double $id:(funName' "log10_64")(double x) {
      return log10(x);
    }

    static inline double $id:(funName' "sqrt64")(double x) {
      return sqrt(x);
    }

    static inline double $id:(funName' "exp64")(double x) {
      return exp(x);
    }

    static inline double $id:(funName' "cos64")(double x) {
      return cos(x);
    }

    static inline double $id:(funName' "sin64")(double x) {
      return sin(x);
    }

    static inline double $id:(funName' "tan64")(double x) {
      return tan(x);
    }

    static inline double $id:(funName' "acos64")(double x) {
      return acos(x);
    }

    static inline double $id:(funName' "asin64")(double x) {
      return asin(x);
    }

    static inline double $id:(funName' "atan64")(double x) {
      return atan(x);
    }

    static inline double $id:(funName' "cosh64")(double x) {
      return cosh(x);
    }

    static inline double $id:(funName' "sinh64")(double x) {
      return sinh(x);
    }

    static inline double $id:(funName' "tanh64")(double x) {
      return tanh(x);
    }

    static inline double $id:(funName' "acosh64")(double x) {
      return acosh(x);
    }

    static inline double $id:(funName' "asinh64")(double x) {
      return asinh(x);
    }

    static inline double $id:(funName' "atanh64")(double x) {
      return atanh(x);
    }

    static inline double $id:(funName' "atan2_64")(double x, double y) {
      return atan2(x,y);
    }

    static inline double $id:(funName' "gamma64")(double x) {
      return tgamma(x);
    }

    static inline double $id:(funName' "lgamma64")(double x) {
      return lgamma(x);
    }

    static inline double $id:(funName' "fma64")(double a, double b, double c) {
      return fma(a,b,c);
    }

    static inline double $id:(funName' "round64")(double x) {
      return rint(x);
    }

    static inline double $id:(funName' "ceil64")(double x) {
      return ceil(x);
    }

    static inline double $id:(funName' "floor64")(double x) {
      return floor(x);
    }

    static inline typename bool $id:(funName' "isnan64")(double x) {
      return isnan(x);
    }

    static inline typename bool $id:(funName' "isinf64")(double x) {
      return isinf(x);
    }

    static inline typename int64_t $id:(funName' "to_bits64")(double x) {
      union {
        double f;
        typename int64_t t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline double $id:(funName' "from_bits64")(typename int64_t x) {
      union {
        typename int64_t f;
        double t;
      } p;
      p.f = x;
      return p.t;
    }

    static inline double fmod64(double x, double y) {
      return fmod(x, y);
    }

    $esc:("#ifdef __OPENCL_VERSION__")
    static inline double $id:(funName' "lerp64")(double v0, double v1, double t) {
      return mix(v0, v1, t);
    }

    static inline double $id:(funName' "mad64")(double a, double b, double c) {
      return mad(a,b,c);
    }
    $esc:("#else")
    static inline double $id:(funName' "lerp64")(double v0, double v1, double t) {
      return v0 + (v1-v0)*t;
    }

    static inline double $id:(funName' "mad64")(double a, double b, double c) {
      return a*b+c;
    }
    $esc:("#endif")
  |]