{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.LLVM.CodeGen.Arithmetic
where
import Prelude ( Eq, Num, Either(..), ($), (==), undefined, otherwise, flip, fromInteger )
import Control.Applicative
import Control.Monad
import Data.Bits ( finiteBitSize )
import Data.ByteString.Short ( ShortByteString )
import Data.Monoid
import Data.String
import Foreign.Storable ( sizeOf )
import Text.Printf
import qualified Data.Ord as Ord
import qualified Prelude as P
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Array.Sugar
import LLVM.AST.Type.Constant
import LLVM.AST.Type.Global
import LLVM.AST.Type.Instruction
import LLVM.AST.Type.Instruction.Compare
import LLVM.AST.Type.Name
import LLVM.AST.Type.Operand
import LLVM.AST.Type.Representation
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Type
-- | Addition of two operands of the same numeric type (integral or
-- floating; see 'NumType').
add :: NumType a -> IR a -> IR a -> CodeGen (IR a)
add t x y = binop Add t x y

-- | Subtraction, @x - y@, for any numeric type.
sub :: NumType a -> IR a -> IR a -> CodeGen (IR a)
sub t x y = binop Sub t x y

-- | Multiplication for any numeric type.
mul :: NumType a -> IR a -> IR a -> CodeGen (IR a)
mul t x y = binop Mul t x y
-- | Arithmetic negation, generated as a multiplication by the constant @-1@.
--
-- Multiplying (rather than computing @0 - x@) maps a floating-point @+0@
-- to @-0@, since IEEE multiplication by a negative number flips the sign.
negate :: NumType a -> IR a -> CodeGen (IR a)
negate t x
  | IntegralNumType i <- t
  , IntegralDict      <- integralDict i
  = mul t x (ir t (num t (P.negate 1)))
  | FloatingNumType f <- t
  , FloatingDict      <- floatingDict f
  = mul t x (ir t (num t (P.negate 1)))
-- | Absolute value.
--
-- * Floating types call @fabs@ through 'mathf'.
-- * Unsigned integers are returned unchanged.
-- * Signed integers call the external function @llabs@ when the type is
--   64 bits wide and @abs@ otherwise. The call is marked 'NoUnwind' and
--   'ReadNone'.
--
-- NOTE(review): for signed types narrower than 32 bits, @abs@ is called
-- with an operand of that narrower width. If this resolves to C's
-- @int abs(int)@, confirm that the calling convention makes this safe on
-- all targets.
abs :: forall a. NumType a -> IR a -> CodeGen (IR a)
abs n x =
  case n of
    FloatingNumType f -> mathf "fabs" f x
    IntegralNumType i
      | unsigned i -> return x
      | IntegralDict <- integralDict i ->
          let prim       = ScalarPrimType (NumScalarType n)
              callAbs fn = call (Lam prim (op n x) (Body (PrimType prim) fn)) [NoUnwind, ReadNone]
          in
          case finiteBitSize (undefined :: a) of
            64 -> callAbs "llabs"
            _  -> callAbs "abs"
-- | Sign of a numeric operand (@-1@, @0@ or @1@) in the operand's own type.
--
-- * Unsigned integers: the comparison @x /= 0@ is extended (via 'Ext') to
--   the integer type.
-- * Signed integers: branch-free. The comparison @x /= 0@ (0 or 1) is
--   OR-ed with an arithmetic right shift of @x@ by @bits - 1@. That shift
--   is 0 for a non-negative @x@ and all ones (@-1@) for a negative @x@.
-- * Floating point: computes @(x > 0) - (x < 0)@ after converting both
--   comparison results to floating point.
--
-- NOTE(review): in the floating case a @-0@ input produces @+0@ (both
-- comparisons are false), whereas Haskell's 'P.signum' returns @-0@.
-- The result for NaN depends on how 'Cmp' is lowered (ordered or
-- unordered). Confirm both are intended.
signum :: forall a. NumType a -> IR a -> CodeGen (IR a)
signum t x =
  case t of
    IntegralNumType i
      | IntegralDict <- integralDict i
      , unsigned i
      -> do z <- neq (NumScalarType t) x (ir t (num t 0))
            s <- instr (Ext boundedType (IntegralBoundedType i) (op scalarType z))
            return s
      -- signed: (x /= 0) .|. (x `shiftRA` (bits - 1))
      | IntegralDict <- integralDict i
      -> do let wsib = finiteBitSize (undefined::a)
            z <- neq (NumScalarType t) x (ir t (num t 0))
            l <- instr (Ext boundedType (IntegralBoundedType i) (op scalarType z))
            r <- shiftRA i x (ir integralType (integral integralType (wsib P.- 1)))
            s <- bor i l r
            return s
    -- floating: (x > 0) - (x < 0), each Bool converted with IntToFP
    FloatingNumType f
      | FloatingDict <- floatingDict f
      -> do
            l <- gt (NumScalarType t) x (ir f (floating f 0))
            r <- lt (NumScalarType t) x (ir f (floating f 0))
            u <- instr (IntToFP (Right nonNumType) f (op scalarType l))
            v <- instr (IntToFP (Right nonNumType) f (op scalarType r))
            s <- sub t u v
            return s
-- | Integer division truncated towards zero.
quot :: IntegralType a -> IR a -> IR a -> CodeGen (IR a)
quot t x y = binop Quot t x y

-- | Remainder matching 'quot', so it has the sign of the dividend.
rem :: IntegralType a -> IR a -> IR a -> CodeGen (IR a)
rem t x y = binop Rem t x y

-- | Both 'quot' and 'rem', emitted as two instructions (the quotient
-- first) and returned as an IR pair.
quotRem :: IntegralType a -> IR a -> IR a -> CodeGen (IR (a,a))
quotRem t x y = pair <$> quot t x y <*> rem t x y
-- | Integer division rounding towards negative infinity (the semantics of
-- Haskell's 'P.div').
--
-- For unsigned types this coincides with 'quot'. For signed types, when the
-- operands have strictly opposite signs, the numerator is moved one step
-- towards zero before the truncating division and the quotient is then
-- decremented:
--
-- > x positive, y negative:  quot (x - 1) y - 1
-- > x negative, y positive:  quot (x + 1) y - 1
--
-- In every other case the result is @quot x y@. The @if@s here go through the
-- module's rebindable 'ifThenElse', so they generate branches and a phi,
-- and the conditions are combined with the short-circuiting 'land'.
idiv :: IntegralType a -> IR a -> IR a -> CodeGen (IR a)
idiv i x y
  | unsigned i
  = quot i x y
  | IntegralDict <- integralDict i
  -- the Elt dictionary is needed by 'ifThenElse', which returns an IR a
  , EltDict      <- integralElt i
  , zero         <- ir i (integral i 0)
  , one          <- ir i (integral i 1)
  , n            <- IntegralNumType i
  , s            <- NumScalarType n
  = if gt s x zero `land` lt s y zero
      then do
        a <- sub n x one
        b <- quot i a y
        c <- sub n b one
        return c
      else
        if lt s x zero `land` gt s y zero
          then do
            a <- add n x one
            b <- quot i a y
            c <- sub n b one
            return c
          else
            quot i x y
-- | Integer modulus whose result takes the sign of the divisor (the
-- semantics of Haskell's 'P.mod').
--
-- For unsigned types this coincides with 'rem'. For signed types the
-- truncated remainder @r@ is computed first. When the operands have
-- strictly opposite signs and @r@ is non-zero, the divisor is added to
-- move the result onto the divisor's side of zero. If @r@ is zero, zero is
-- returned; otherwise @r@ is returned.
mod :: IntegralType a -> IR a -> IR a -> CodeGen (IR a)
mod i x y
  | unsigned i
  = rem i x y
  | IntegralDict <- integralDict i
  -- the Elt dictionary is needed by 'ifThenElse', which returns an IR a
  , EltDict      <- integralElt i
  , zero         <- ir i (integral i 0)
  , n            <- IntegralNumType i
  , s            <- NumScalarType n
  = do r <- rem i x y
       if (gt s x zero `land` lt s y zero) `lor` (lt s x zero `land` gt s y zero)
         then if neq s r zero
                then add n r y
                else return zero
         else return r
-- | Floored division and modulus together (the semantics of Haskell's
-- 'P.divMod').
--
-- For unsigned types this is 'quotRem'. For signed operands with strictly
-- opposite signs, one truncating 'quotRem' of an adjusted numerator is
-- performed, and both results are corrected from it:
--
-- > x positive, y negative:  (q, r) = quotRem (x - 1) y  ==>  (q - 1, r + y + 1)
-- > x negative, y positive:  (q, r) = quotRem (x + 1) y  ==>  (q - 1, r + y - 1)
--
-- Otherwise the result is @quotRem x y@. The @if@s desugar to
-- 'ifThenElse', so they generate branches and a phi.
divMod :: IntegralType a -> IR a -> IR a -> CodeGen (IR (a,a))
divMod i x y
  | unsigned i
  = quotRem i x y
  | IntegralDict <- integralDict i
  , EltDict      <- integralElt i
  , zero         <- ir i (integral i 0)
  , one          <- ir i (integral i 1)
  , n            <- IntegralNumType i
  , s            <- NumScalarType n
  = if gt s x zero `land` lt s y zero
      then do
        a <- sub n x one
        b <- quotRem i a y
        c <- sub n (fst b) one
        d <- add n (snd b) y
        e <- add n d one
        return $ pair c e
      else
        if lt s x zero `land` gt s y zero
          then do
            a <- add n x one
            b <- quotRem i a y
            c <- sub n (fst b) one
            d <- add n (snd b) y
            e <- sub n d one
            return $ pair c e
          else
            quotRem i x y
-- | Bitwise AND.
band :: IntegralType a -> IR a -> IR a -> CodeGen (IR a)
band t x y = binop BAnd t x y

-- | Bitwise OR.
bor :: IntegralType a -> IR a -> IR a -> CodeGen (IR a)
bor t x y = binop BOr t x y

-- | Bitwise exclusive OR.
xor :: IntegralType a -> IR a -> IR a -> CodeGen (IR a)
xor t x y = binop BXor t x y

-- | Bitwise complement, generated as an XOR with the all-ones constant
-- (@-1@ converted to the operand type).
complement :: IntegralType a -> IR a -> CodeGen (IR a)
complement t x
  | IntegralDict <- integralDict t
  = let allOnes = ir t (integral t (P.negate 1))
    in  xor t x allOnes
-- | Left shift by an 'Int' amount. The amount is first converted to the
-- operand's own integer type, because LLVM shifts take both operands at the
-- same type.
shiftL :: IntegralType a -> IR a -> IR Int -> CodeGen (IR a)
shiftL t x i = binop ShiftL t x =<< fromIntegral integralType (IntegralNumType t) i

-- | Right shift: arithmetic for signed types, logical for unsigned ones.
shiftR :: IntegralType a -> IR a -> IR Int -> CodeGen (IR a)
shiftR t x i
  | signed t  = shiftRA t x i
  | otherwise = shiftRL t x i

-- | Logical (zero-filling) right shift by an 'Int' amount.
shiftRL :: IntegralType a -> IR a -> IR Int -> CodeGen (IR a)
shiftRL t x i = binop ShiftRL t x =<< fromIntegral integralType (IntegralNumType t) i

-- | Arithmetic (sign-filling) right shift by an 'Int' amount.
shiftRA :: IntegralType a -> IR a -> IR Int -> CodeGen (IR a)
shiftRA t x i = binop ShiftRA t x =<< fromIntegral integralType (IntegralNumType t) i
-- | Rotate the bits of an integer left by @i@ positions, modulo the bit
-- width.
--
-- Computed as @(x shl (i .&. m)) .|. (x lshr ((-i) .&. m))@ with
-- @m = bits - 1@. Masking the right-shift amount as well matters when
-- @i .&. m@ is 0. The naive amount @bits - (i .&. m)@ would then shift by
-- the full bit width, and LLVM defines a shift by the bit width or more as
-- poison. Poison would make a rotation by a multiple of the width poison
-- instead of returning @x@. This relies on the bit width being a power of
-- two, which holds for the fixed-width integer types.
rotateL :: forall a. IntegralType a -> IR a -> IR Int -> CodeGen (IR a)
rotateL t x i
  | IntegralDict <- integralDict t
  = do let wsib = finiteBitSize (undefined::a)
           mask = ir integralType (integral integralType (wsib P.- 1))
       -- left-shift amount: i mod wsib, in [0, wsib-1]
       i1 <- band integralType i mask
       -- right-shift amount: (-i) mod wsib == (wsib - i1) mod wsib, also in
       -- [0, wsib-1]; it is 0 (not wsib) when i1 is 0, giving x .|. x == x
       ni <- negate numType i
       i2 <- band integralType ni mask
       --
       a  <- shiftL t x i1
       b  <- shiftRL t x i2
       bor t a b
-- | Rotate the bits of an integer right by @i@ positions, expressed as a
-- left rotation by the negated amount.
rotateR :: forall a. IntegralType a -> IR a -> IR Int -> CodeGen (IR a)
rotateR t x i = rotateL t x =<< negate numType i
-- | Number of set bits in an integer, as an 'Int'.
--
-- Calls the LLVM intrinsic @llvm.ctpop.iN@ (where @N@ is the operand's bit
-- width) and converts its result from the operand type to 'Int'.
popCount :: forall a. IntegralType a -> IR a -> CodeGen (IR Int)
popCount i x
  | IntegralDict <- integralDict i
  = do let width = finiteBitSize (undefined :: a)
           name  = fromString (printf "llvm.ctpop.i%d" width)
           prim  = ScalarPrimType (NumScalarType (IntegralNumType i))
       bits <- call (Lam prim (op i x) (Body (PrimType prim) name)) [NoUnwind, ReadNone]
       fromIntegral i numType bits

-- | Number of leading zero bits, via @llvm.ctlz.iN@, as an 'Int'.
--
-- The intrinsic's second (i1) argument is passed as 'False'. That flag says
-- whether a zero input yields poison, so 'False' makes a zero input
-- well-defined.
countLeadingZeros :: forall a. IntegralType a -> IR a -> CodeGen (IR Int)
countLeadingZeros i x
  | IntegralDict <- integralDict i
  = do let width    = finiteBitSize (undefined :: a)
           name     = fromString (printf "llvm.ctlz.i%d" width)
           prim     = ScalarPrimType (NumScalarType (IntegralNumType i))
           zeroFlag = nonnum nonNumType False
       bits <- call (Lam prim (op i x) (Lam primType zeroFlag (Body (PrimType prim) name))) [NoUnwind, ReadNone]
       fromIntegral i numType bits

-- | Number of trailing zero bits, via @llvm.cttz.iN@, as an 'Int'. The
-- second (i1) argument is 'False', as for 'countLeadingZeros'.
countTrailingZeros :: forall a. IntegralType a -> IR a -> CodeGen (IR Int)
countTrailingZeros i x
  | IntegralDict <- integralDict i
  = do let width    = finiteBitSize (undefined :: a)
           name     = fromString (printf "llvm.cttz.i%d" width)
           prim     = ScalarPrimType (NumScalarType (IntegralNumType i))
           zeroFlag = nonnum nonNumType False
       bits <- call (Lam prim (op i x) (Lam primType zeroFlag (Body (PrimType prim) name))) [NoUnwind, ReadNone]
       fromIntegral i numType bits
-- | Floating-point division, @x / y@.
fdiv :: FloatingType a -> IR a -> IR a -> CodeGen (IR a)
fdiv t x y = binop Div t x y

-- | Reciprocal, computed as the division @1 / x@.
recip :: FloatingType a -> IR a -> CodeGen (IR a)
recip t x
  | FloatingDict <- floatingDict t
  = let one = ir t (floating t 1)
    in  fdiv t one x
-- | Sine, via the math function @sin@ (see 'mathf' for how the name is
-- resolved per precision).
sin :: FloatingType a -> IR a -> CodeGen (IR a)
sin t x = mathf "sin" t x

-- | Cosine, via @cos@.
cos :: FloatingType a -> IR a -> CodeGen (IR a)
cos t x = mathf "cos" t x

-- | Tangent, via @tan@.
tan :: FloatingType a -> IR a -> CodeGen (IR a)
tan t x = mathf "tan" t x

-- | Hyperbolic sine, via @sinh@.
sinh :: FloatingType a -> IR a -> CodeGen (IR a)
sinh t x = mathf "sinh" t x

-- | Hyperbolic cosine, via @cosh@.
cosh :: FloatingType a -> IR a -> CodeGen (IR a)
cosh t x = mathf "cosh" t x

-- | Hyperbolic tangent, via @tanh@.
tanh :: FloatingType a -> IR a -> CodeGen (IR a)
tanh t x = mathf "tanh" t x

-- | Arc sine, via @asin@.
asin :: FloatingType a -> IR a -> CodeGen (IR a)
asin t x = mathf "asin" t x

-- | Arc cosine, via @acos@.
acos :: FloatingType a -> IR a -> CodeGen (IR a)
acos t x = mathf "acos" t x

-- | Arc tangent, via @atan@.
atan :: FloatingType a -> IR a -> CodeGen (IR a)
atan t x = mathf "atan" t x

-- | Inverse hyperbolic sine, via @asinh@.
asinh :: FloatingType a -> IR a -> CodeGen (IR a)
asinh t x = mathf "asinh" t x

-- | Inverse hyperbolic cosine, via @acosh@.
acosh :: FloatingType a -> IR a -> CodeGen (IR a)
acosh t x = mathf "acosh" t x

-- | Inverse hyperbolic tangent, via @atanh@.
atanh :: FloatingType a -> IR a -> CodeGen (IR a)
atanh t x = mathf "atanh" t x

-- | Two-argument arc tangent, via @atan2@ (see 'mathf2').
atan2 :: FloatingType a -> IR a -> IR a -> CodeGen (IR a)
atan2 t x y = mathf2 "atan2" t x y

-- | Exponential function, via @exp@.
exp :: FloatingType a -> IR a -> CodeGen (IR a)
exp t x = mathf "exp" t x

-- | Floating-point power, via @pow@.
fpow :: FloatingType a -> IR a -> IR a -> CodeGen (IR a)
fpow t x y = mathf2 "pow" t x y

-- | Square root, via @sqrt@.
sqrt :: FloatingType a -> IR a -> CodeGen (IR a)
sqrt t x = mathf "sqrt" t x

-- | Natural logarithm, via @log@.
log :: FloatingType a -> IR a -> CodeGen (IR a)
log t x = mathf "log" t x
-- | Logarithm of @y@ in base @x@.
--
-- When the base operand is a compile-time constant exactly equal to 2 or
-- 10, a single call to @log2@ or @log10@ on @y@ is emitted. Otherwise the
-- result is @log y / log x@, computed with two 'log' calls and an 'fdiv'.
logBase :: forall a. FloatingType a -> IR a -> IR a -> CodeGen (IR a)
logBase t x@(op t -> base) y | FloatingDict <- floatingDict t = logBase'
  where
    -- True only when both operands are scalar constants with equal values;
    -- any non-constant operand compares as not matching
    match :: Eq t => Operand t -> Operand t -> Bool
    match (ConstantOperand (ScalarConstant _ u))
          (ConstantOperand (ScalarConstant _ v)) = u == v
    match _ _                                    = False

    -- the (Num a, Eq a) context is discharged by the FloatingDict match above
    logBase' :: (Num a, Eq a) => CodeGen (IR a)
    logBase' | match base (floating t 2)  = mathf "log2"  t y
             | match base (floating t 10) = mathf "log10" t y
             | otherwise
             = do x' <- log t x
                  y' <- log t y
                  fdiv t y' x'
-- | Test whether a floating-point operand is NaN, by calling @isnanf@
-- (single precision) or @isnand@ (double precision) resolved through
-- 'intrinsic'. The call is marked 'NoUnwind' and 'ReadOnly'.
isNaN :: FloatingType a -> IR a -> CodeGen (IR Bool)
isNaN f x = do
  fn <- intrinsic $ case f of
                      TypeFloat{}   -> "isnanf"
                      TypeCFloat{}  -> "isnanf"
                      TypeDouble{}  -> "isnand"
                      TypeCDouble{} -> "isnand"
  let arg = ScalarPrimType (NumScalarType (FloatingNumType f))
  call (Lam arg (op f x) (Body type' fn)) [NoUnwind, ReadOnly]

-- | Test whether a floating-point operand is infinite, by calling @isinff@
-- or @isinfd@ resolved through 'intrinsic'.
isInfinite :: FloatingType a -> IR a -> CodeGen (IR Bool)
isInfinite f x = do
  fn <- intrinsic $ case f of
                      TypeFloat{}   -> "isinff"
                      TypeCFloat{}  -> "isinff"
                      TypeDouble{}  -> "isinfd"
                      TypeCDouble{} -> "isinfd"
  let arg = ScalarPrimType (NumScalarType (FloatingNumType f))
  call (Lam arg (op f x) (Body type' fn)) [NoUnwind, ReadOnly]
-- | Convert a floating-point value to an integer with the 'FPToInt'
-- instruction (presumably lowered to @fptosi@ or @fptoui@, which round
-- towards zero).
truncate :: FloatingType a -> IntegralType b -> IR a -> CodeGen (IR b)
truncate tf ti x = instr (FPToInt tf ti (op tf x))

-- | Round to an integer: apply the math function @round@, then 'truncate'.
--
-- NOTE(review): if @round@ resolves to C's @round@, halves are rounded away
-- from zero, whereas Haskell's 'P.round' rounds halves to even. Confirm
-- this difference is acceptable.
round :: FloatingType a -> IntegralType b -> IR a -> CodeGen (IR b)
round tf ti x = truncate tf ti =<< mathf "round" tf x

-- | Largest integer not greater than the input: @floor@, then 'truncate'.
floor :: FloatingType a -> IntegralType b -> IR a -> CodeGen (IR b)
floor tf ti x = truncate tf ti =<< mathf "floor" tf x

-- | Smallest integer not less than the input: @ceil@, then 'truncate'.
ceiling :: FloatingType a -> IntegralType b -> IR a -> CodeGen (IR b)
ceiling tf ti x = truncate tf ti =<< mathf "ceil" tf x
-- | Compare two scalar operands with the given predicate. The 'Ordering'
-- here is the comparison type from "LLVM.AST.Type.Instruction.Compare"
-- (which has @NE@, @LE@ and @GE@), not Prelude's.
cmp :: Ordering -> ScalarType a -> IR a -> IR a -> CodeGen (IR Bool)
cmp p dict x y = instr (Cmp dict p (op dict x) (op dict y))

-- | @x < y@
lt :: ScalarType a -> IR a -> IR a -> CodeGen (IR Bool)
lt t x y = cmp LT t x y

-- | @x > y@
gt :: ScalarType a -> IR a -> IR a -> CodeGen (IR Bool)
gt t x y = cmp GT t x y

-- | @x <= y@
lte :: ScalarType a -> IR a -> IR a -> CodeGen (IR Bool)
lte t x y = cmp LE t x y

-- | @x >= y@
gte :: ScalarType a -> IR a -> IR a -> CodeGen (IR Bool)
gte t x y = cmp GE t x y

-- | @x == y@
eq :: ScalarType a -> IR a -> IR a -> CodeGen (IR Bool)
eq t x y = cmp EQ t x y

-- | @x /= y@
neq :: ScalarType a -> IR a -> IR a -> CodeGen (IR Bool)
neq t x y = cmp NE t x y
-- | Maximum of two scalars. Floating-point types call @fmax@ via 'mathf2'.
-- All other scalar types compare with 'gte' and pick an operand with a
-- 'Select' instruction.
max :: ScalarType a -> IR a -> IR a -> CodeGen (IR a)
max t x y =
  case t of
    NumScalarType (FloatingNumType f) -> mathf2 "fmax" f x y
    _ -> do c <- gte t x y
            binop (flip Select (op scalarType c)) t x y

-- | Minimum of two scalars. Floating-point types call @fmin@; all other
-- types use 'lte' and a 'Select'.
min :: ScalarType a -> IR a -> IR a -> CodeGen (IR a)
min t x y =
  case t of
    NumScalarType (FloatingNumType f) -> mathf2 "fmin" f x y
    _ -> do c <- lte t x y
            binop (flip Select (op scalarType c)) t x y
-- | Short-circuiting logical AND. The code for the second operand is only
-- generated in the branch where the first is true; otherwise the constant
-- 'False' is produced. Implemented with 'ifThenElse', so it emits branches
-- and a phi.
land :: CodeGen (IR Bool) -> CodeGen (IR Bool) -> CodeGen (IR Bool)
land x y = ifThenElse x y (return (ir scalarType (scalar scalarType False)))

-- | Short-circuiting logical OR. The second operand is only generated in
-- the branch where the first is false; otherwise the constant 'True' is
-- produced.
lor :: CodeGen (IR Bool) -> CodeGen (IR Bool) -> CodeGen (IR Bool)
lor x y = ifThenElse x (return (ir scalarType (scalar scalarType True))) y

-- | Non-short-circuiting AND of two already-computed booleans, emitted as a
-- single 'LAnd' instruction.
land' :: IR Bool -> IR Bool -> CodeGen (IR Bool)
land' x y = instr (LAnd (op scalarType x) (op scalarType y))

-- | Non-short-circuiting OR of two already-computed booleans ('LOr').
lor' :: IR Bool -> IR Bool -> CodeGen (IR Bool)
lor' x y = instr (LOr (op scalarType x) (op scalarType y))

-- | Logical negation ('LNot').
lnot :: IR Bool -> CodeGen (IR Bool)
lnot x = instr (LNot (op scalarType x))
-- | Code point of a character, as an 'Int'.
--
-- A 'Char' is a 32-bit value here: on 32-bit targets it is bit-cast to
-- 'Int', which requires equal widths. On 64-bit targets it must therefore
-- be widened with 'Ext' (presumably a zero-extension, since 'Char' is not
-- a signed integral type). LLVM's @trunc@ requires a strictly narrower
-- destination and is invalid for 32 to 64 bits.
ord :: IR Char -> CodeGen (IR Int)
ord (op scalarType -> x) =
  case finiteBitSize (undefined :: Int) of
    32 -> instr (BitCast scalarType x)
    64 -> instr (Ext boundedType boundedType x)
    _  -> $internalError "ord" "I don't know what architecture I am"

-- | Character with the given code point.
--
-- On 64-bit targets the 'Int' is narrowed to the 32-bit 'Char' with
-- 'Trunc'. LLVM's @zext@ and @sext@ require a strictly wider destination,
-- so they are invalid for 64 to 32 bits.
--
-- NOTE(review): no range check is performed; values outside the valid
-- code point range are simply truncated.
chr :: IR Int -> CodeGen (IR Char)
chr (op integralType -> x) =
  case finiteBitSize (undefined :: Int) of
    32 -> instr (BitCast scalarType x)
    64 -> instr (Trunc boundedType boundedType x)
    _  -> $internalError "chr" "I don't know what architecture I am"

-- | Convert a boolean to an 'Int' (0 or 1) by extending it with 'Ext'.
boolToInt :: IR Bool -> CodeGen (IR Int)
boolToInt x = instr (Ext boundedType boundedType (op scalarType x))
-- | Convert an integral operand to any numeric type.
--
-- A floating-point target uses 'IntToFP'. Between integer types the value
-- is bit-cast when the widths are equal, truncated ('Trunc') when the
-- target is narrower, and extended ('Ext') when the target is wider.
fromIntegral :: forall a b. IntegralType a -> NumType b -> IR a -> CodeGen (IR b)
fromIntegral i1 n x =
  case n of
    IntegralNumType (i2 :: IntegralType b)
      | IntegralDict <- integralDict i1
      , IntegralDict <- integralDict i2
      -> case Ord.compare (finiteBitSize (undefined::a)) (finiteBitSize (undefined::b)) of
           Ord.LT -> instr (Ext   (IntegralBoundedType i1) (IntegralBoundedType i2) v)
           Ord.GT -> instr (Trunc (IntegralBoundedType i1) (IntegralBoundedType i2) v)
           Ord.EQ -> instr (BitCast (NumScalarType n) v)
    FloatingNumType f
      -> instr (IntToFP (Left i1) f v)
  where
    v = op i1 x
-- | Convert any numeric operand to a floating-point type.
--
-- Integral sources use 'IntToFP'. Between floating types the value is
-- bit-cast when the storage sizes are equal, narrowed with 'FTrunc' when
-- the target is smaller, and widened with 'FExt' when it is larger.
toFloating :: forall a b. NumType a -> FloatingType b -> IR a -> CodeGen (IR b)
toFloating n1 f2 x =
  case n1 of
    FloatingNumType (f1 :: FloatingType a)
      | FloatingDict <- floatingDict f1
      , FloatingDict <- floatingDict f2
      -> case Ord.compare (sizeOf (undefined::a)) (sizeOf (undefined::b)) of
           Ord.LT -> instr (FExt   f1 f2 v)
           Ord.GT -> instr (FTrunc f1 f2 v)
           Ord.EQ -> instr (BitCast (NumScalarType (FloatingNumType f2)) v)
    IntegralNumType i1
      -> instr (IntToFP (Left i1) f2 v)
  where
    v = op n1 x
-- | Reinterpret the bits of a scalar as another scalar type with a
-- 'BitCast' instruction (LLVM requires both types to have the same size).
coerce :: forall a b. ScalarType a -> ScalarType b -> IR a -> CodeGen (IR b)
coerce ta tb x = instr (BitCast tb (op ta x))
-- | First component of an IR pair. Tuples are represented as left-nested
-- 'OP_Pair' chains that start from 'OP_Unit', so a pair is
-- @OP_Pair (OP_Pair OP_Unit a) b@.
fst :: IR (a, b) -> IR a
fst (IR (OP_Pair (OP_Pair OP_Unit u) _)) = IR u

-- | Second component of an IR pair (the outermost 'OP_Pair' field).
snd :: IR (a, b) -> IR b
snd (IR (OP_Pair _ v)) = IR v

-- | Build an IR pair from its two components.
pair :: IR a -> IR b -> IR (a, b)
pair (IR u) (IR v) = IR (OP_Pair (OP_Pair OP_Unit u) v)

-- | Split an IR pair into a Haskell tuple of its components.
unpair :: IR (a, b) -> (IR a, IR b)
unpair p = (fst p, snd p)

-- | Apply a two-argument function to the components of an IR pair.
uncurry :: (IR a -> IR b -> c) -> IR (a, b) -> c
uncurry f p = f (fst p) (snd p)

-- | Emit a binary instruction: unwrap both IR operands with the given type
-- dictionary and pass them, with that dictionary, to the instruction
-- constructor.
binop :: IROP dict => (dict a -> Operand a -> Operand a -> Instruction a) -> dict a -> IR a -> IR a -> CodeGen (IR a)
binop f dict x y = instr (f dict (op dict x) (op dict y))

-- | First component of an IR triple.
fst3 :: IR (a, b, c) -> IR a
fst3 (IR (OP_Pair (OP_Pair (OP_Pair OP_Unit u) _) _)) = IR u

-- | Second component of an IR triple.
snd3 :: IR (a, b, c) -> IR b
snd3 (IR (OP_Pair (OP_Pair _ v) _)) = IR v

-- | Third component of an IR triple.
thd3 :: IR (a, b, c) -> IR c
thd3 (IR (OP_Pair _ w)) = IR w

-- | Build an IR triple from its three components.
trip :: IR a -> IR b -> IR c -> IR (a, b, c)
trip (IR u) (IR v) (IR w) = IR (OP_Pair (OP_Pair (OP_Pair OP_Unit u) v) w)

-- | Split an IR triple into a Haskell tuple of its components.
untrip :: IR (a, b, c) -> (IR a, IR b, IR c)
untrip t = (fst3 t, snd3 t, thd3 t)
{-# INLINABLE lift #-}
-- | Embed a Haskell scalar value as a constant IR operand.
lift :: IsScalar a => a -> IR a
lift v = ir scalarType $ scalar scalarType v
-- | Code generation for a conditional. This module enables
-- @RebindableSyntax@, so every @if c then t else e@ in it desugars to
-- @ifThenElse c t e@: the condition and both arms are 'CodeGen' actions,
-- not plain values.
--
-- Allocates the @if.then@, @if.else@ and @if.exit@ blocks, starts a fresh
-- @if.entry@ block, generates the condition there and branches on it with
-- 'cbr'. Each arm is generated in its own block and ends with a 'br' to
-- @if.exit@, where a 'phi' merges the two results.
--
-- NOTE(review): the reason for the separate @if.entry@ block and for the
-- 'Elt' constraint (presumably needed by 'phi') is not visible here.
ifThenElse
    :: Elt a
    => CodeGen (IR Bool)
    -> CodeGen (IR a)
    -> CodeGen (IR a)
    -> CodeGen (IR a)
ifThenElse test yes no = do
  ifThen <- newBlock "if.then"
  ifElse <- newBlock "if.else"
  ifExit <- newBlock "if.exit"

  _  <- beginBlock "if.entry"
  p  <- test
  _  <- cbr p ifThen ifElse

  -- the phi predecessors are the blocks returned by 'br', presumably the
  -- blocks current at the end of each arm (these may differ from if.then
  -- and if.else when an arm itself generates control flow)
  setBlock ifThen
  tv <- yes
  tb <- br ifExit

  setBlock ifElse
  fv <- no
  fb <- br ifExit

  setBlock ifExit
  phi [(tv, tb), (fv, fb)]
-- | Generate @doit@ so that it only runs when the generated condition is
-- true.
--
-- The condition is evaluated in the current block, which then branches
-- either to a @when.body@ block holding the code of @doit@ or directly to
-- @when.exit@. The body falls through to @when.exit@, and code generation
-- continues there. No value is produced, so no phi is needed.
--
-- NOTE: Control.Monad, imported unqualified in this module, also exports
-- @when@ and @unless@, so an unqualified use inside this module would be
-- ambiguous.
when :: CodeGen (IR Bool) -> CodeGen () -> CodeGen ()
when test doit = do
  body <- newBlock "when.body"
  exit <- newBlock "when.exit"

  p <- test
  _ <- cbr p body exit

  setBlock body
  doit
  _ <- br exit

  setBlock exit

-- | Generate @doit@ so that it only runs when the generated condition is
-- false. Same structure as 'when', with the targets of the conditional
-- branch swapped.
unless :: CodeGen (IR Bool) -> CodeGen () -> CodeGen ()
unless test doit = do
  body <- newBlock "unless.body"
  exit <- newBlock "unless.exit"

  p <- test
  -- true goes straight to the exit, false runs the body
  _ <- cbr p exit body

  setBlock body
  doit
  _ <- br exit

  setBlock exit
-- | Call the one-argument floating-point function named @n@, with the name
-- resolved per precision by 'lm'. The call is marked 'NoUnwind' and
-- 'ReadOnly'.
mathf :: ShortByteString -> FloatingType t -> IR t -> CodeGen (IR t)
mathf n f x = do
  fn <- lm f n
  let prim = ScalarPrimType (NumScalarType (FloatingNumType f))
  call (Lam prim (op f x) (Body (PrimType prim) fn)) [NoUnwind, ReadOnly]

-- | Two-argument variant of 'mathf'.
mathf2 :: ShortByteString -> FloatingType t -> IR t -> IR t -> CodeGen (IR t)
mathf2 n f x y = do
  fn <- lm f n
  let prim = ScalarPrimType (NumScalarType (FloatingNumType f))
  call (Lam prim (op f x) (Lam prim (op f y) (Body (PrimType prim) fn))) [NoUnwind, ReadOnly]

-- | Resolve a C-math-style function name for the given precision through
-- 'intrinsic'. Single-precision types get an @f@ suffix (e.g. @sinf@);
-- double-precision types use the name unchanged.
lm :: FloatingType t -> ShortByteString -> CodeGen Label
lm t n =
  case t of
    TypeFloat{}   -> intrinsic (n <> "f")
    TypeCFloat{}  -> intrinsic (n <> "f")
    TypeDouble{}  -> intrinsic n
    TypeCDouble{} -> intrinsic n