module Data.Array.Accelerate.CUDA.CodeGen (
CUTranslSkel, codegenAcc,
) where
import Prelude hiding ( id, exp, replicate )
import Control.Applicative ( (<$>), (<*>) )
import Control.Monad.State.Strict
import Data.Loc
import Data.Char
import Data.HashSet ( HashSet )
import Foreign.CUDA.Analysis
import Language.C.Quote.CUDA
import qualified Language.C as C
import qualified Data.HashSet as Set
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Trafo
import Data.Array.Accelerate.Pretty ()
import Data.Array.Accelerate.Analysis.Shape
import Data.Array.Accelerate.Array.Sugar ( Array, Shape, Elt, EltRepr )
import Data.Array.Accelerate.Array.Representation ( SliceIndex(..) )
import qualified Data.Array.Accelerate.Array.Sugar as Sugar
import qualified Data.Array.Accelerate.Analysis.Type as Sugar
import Data.Array.Accelerate.CUDA.AST hiding ( Val(..), prj )
import Data.Array.Accelerate.CUDA.CodeGen.Base
import Data.Array.Accelerate.CUDA.CodeGen.Type
import Data.Array.Accelerate.CUDA.CodeGen.Monad
import Data.Array.Accelerate.CUDA.CodeGen.Mapping
import Data.Array.Accelerate.CUDA.CodeGen.IndexSpace
import Data.Array.Accelerate.CUDA.CodeGen.PrefixSum
import Data.Array.Accelerate.CUDA.CodeGen.Reduction
import Data.Array.Accelerate.CUDA.CodeGen.Stencil
import Data.Array.Accelerate.CUDA.Foreign.Import ( canExecuteExp )
-- | Valuation of the scalar environment used during code generation. Each
-- entry holds the list of C expressions standing for one bound variable.
--
data Val env where
  -- The environment with no bindings.
  Empty ::                       Val ()
  -- An environment extended with the C expressions of one more binding.
  Push  :: Val env -> [C.Exp] -> Val (env, s)
-- | Look up the C expressions bound at the given environment index.
--
-- 'ZeroIdx' selects the most recently pushed entry; each 'SuccIdx' skips one
-- entry. Any other combination of index and valuation is an internal error.
--
prj :: Idx env t -> Val env -> [C.Exp]
prj idx val =
  case (idx, val) of
    (ZeroIdx,      Push _     here) -> here
    (SuccIdx rest, Push outer _   ) -> prj rest outer
    _                               -> $internalError "prj" "inconsistent valuation"
-- | Instantiate the CUDA skeleton(s) for a single array operation.
--
-- The operation must be a 'Manifest' node. Its scalar functions and
-- expressions are translated with 'codegenFun1', 'codegenFun2' and
-- 'codegenExp', and its array arguments must be 'Delayed' terms. Node kinds
-- that have no skeleton here raise an internal error naming the primitive.
--
-- Every header recorded in the 'headers' field of the final 'runCUDA' state
-- is prepended to each generated skeleton as an @#include@ line.
--
codegenAcc :: forall aenv arrs. DeviceProperties -> DelayedOpenAcc aenv arrs -> Gamma aenv -> [ CUTranslSkel aenv arrs ]
codegenAcc _   Delayed{}       _    = $internalError "codegenAcc" "expected manifest array"
codegenAcc dev (Manifest pacc) aenv
  = finalise
  $ case pacc of
      -- Operations with a skeleton
      Map f a                   -> mkMap dev aenv       <$> travF1 f <*> travD a
      Generate _ f              -> mkGenerate dev aenv  <$> travF1 f
      Transform _ p f a         -> mkTransform dev aenv <$> travF1 p <*> travF1 f        <*> travD a
      Backpermute _ p a         -> mkTransform dev aenv <$> travF1 p <*> travF1 identity <*> travD a

      Fold f z a                -> mkFold dev aenv      <$> travF2 f <*> travE z <*> travD a
      Fold1 f a                 -> mkFold1 dev aenv     <$> travF2 f <*> travD a
      FoldSeg f z a s           -> mkFoldSeg dev aenv   <$> travF2 f <*> travE z <*> travD a <*> travD s
      Fold1Seg f a s            -> mkFold1Seg dev aenv  <$> travF2 f <*> travD a <*> travD s

      Scanl f z a               -> mkScanl dev aenv     <$> travF2 f <*> travE z <*> travD a
      Scanr f z a               -> mkScanr dev aenv     <$> travF2 f <*> travE z <*> travD a
      Scanl' f z a              -> mkScanl' dev aenv    <$> travF2 f <*> travE z <*> travD a
      Scanr' f z a              -> mkScanr' dev aenv    <$> travF2 f <*> travE z <*> travD a
      Scanl1 f a                -> mkScanl1 dev aenv    <$> travF2 f <*> travD a
      Scanr1 f a                -> mkScanr1 dev aenv    <$> travF2 f <*> travD a

      Permute f _ p a           -> mkPermute dev aenv   <$> travF2 f <*> travF1 p <*> travD a
      Stencil f b a             -> mkStencil dev aenv   <$> travF1 f <*> travB a b
      Stencil2 f b1 a1 b2 a2    -> mkStencil2 dev aenv  <$> travF2 f <*> travB a1 b1 <*> travB a2 b2

      -- Node kinds for which no skeleton is generated here
      Alet{}                    -> unexpectedError
      Avar{}                    -> unexpectedError
      Apply{}                   -> unexpectedError
      Acond{}                   -> unexpectedError
      Awhile{}                  -> unexpectedError
      Atuple{}                  -> unexpectedError
      Aprj{}                    -> unexpectedError
      Use{}                     -> unexpectedError
      Unit{}                    -> unexpectedError
      Aforeign{}                -> unexpectedError
      Reshape{}                 -> unexpectedError

      -- Node kinds reported as "fusible material" when they reach this point
      Replicate{}               -> fusionError
      Slice{}                   -> fusionError
      ZipWith{}                 -> fusionError
  where
    -- Run the generator, then prepend one @#include@ declaration per recorded
    -- header to the code of every skeleton.
    finalise :: CUDA [CUTranslSkel aenv a] -> [CUTranslSkel aenv a]
    finalise cuda = map withHeaders skeletons
      where
        (skeletons, st) = runCUDA cuda

        include hdr     = [cedecl| $esc:("#include \"" ++ hdr ++ "\"") |]

        withHeaders (CUTranslSkel name code) =
          CUTranslSkel name (Set.foldr (\hdr rest -> include hdr : rest) code (headers st))

    -- The scalar identity function; lets Backpermute reuse the Transform skeleton.
    identity :: Elt a => DelayedFun aenv (a -> a)
    identity = Lam (Body (Var ZeroIdx))

    -- Translate the three components of a delayed array argument.
    travD :: (Shape sh, Elt e) => DelayedOpenAcc aenv (Array sh e) -> CUDA (CUDelayedAcc aenv sh e)
    travD arr =
      case arr of
        Delayed{..} -> CUDelayed <$> travE extentD
                                 <*> travF1 indexD
                                 <*> travF1 linearIndexD
        Manifest{}  -> $internalError "codegenAcc" "expected delayed array"

    -- Scalar code generation, specialised to this device and array environment.
    travF1 :: DelayedFun aenv (a -> b) -> CUDA (CUFun1 aenv (a -> b))
    travF1 f = codegenFun1 dev aenv f

    travF2 :: DelayedFun aenv (a -> b -> c) -> CUDA (CUFun2 aenv (a -> b -> c))
    travF2 f = codegenFun2 dev aenv f

    travE :: DelayedExp aenv t -> CUDA (CUExp aenv t)
    travE e = codegenExp dev aenv e

    -- Translate a stencil boundary condition. Only a constant boundary carries
    -- data; it is rendered with 'codegenConst' and an empty binding list.
    travB :: forall sh e. Elt e
          => DelayedOpenAcc aenv (Array sh e) -> Boundary (EltRepr e) -> CUDA (Boundary (CUExp aenv e))
    travB _ boundary =
      return $ case boundary of
        Clamp      -> Clamp
        Mirror     -> Mirror
        Wrap       -> Wrap
        Constant c -> Constant (CUExp ([], codegenConst (Sugar.eltType (undefined::e)) c))

    -- Name of the offending primitive, for the error messages below.
    prim :: String
    prim = showPreAccOp pacc

    unexpectedError = $internalError "codegenAcc" ("unexpected array primitive: "  ++ prim)
    fusionError     = $internalError "codegenAcc" ("unexpected fusible material: " ++ prim)
-- | Generate code for a unary scalar function.
--
-- The body is first generated once against placeholder arguments (named from
-- @"undefined_x"@) purely to learn which argument components are used; that
-- usage is captured with 'mark'. The returned 'CUFun1' also carries a function
-- that regenerates the body for caller-supplied arguments, always starting
-- from the 'CUDA' state observed when this function was called.
--
-- Anything other than a single-argument lambda is an internal error.
--
codegenFun1
    :: forall aenv a b. DeviceProperties
    -> Gamma aenv
    -> DelayedFun aenv (a -> b)
    -> CUDA (CUFun1 aenv (a -> b))
codegenFun1 dev aenv fun =
  case fun of
    Lam (Body body) ->
      let
          -- Generate the body for the given argument expressions, returning
          -- the accumulated bindings together with the result expressions.
          gen :: Rvalue x => [x] -> Gen ([C.BlockItem], [C.Exp])
          gen args = do
            code     <- mapM use =<< codegenOpenExp dev aenv body (Empty `Push` map rvalue args)
            bindings <- getEnv
            return (bindings, code)

          -- Placeholder argument, used only for the usage analysis.
          (_, dummy, _) = locals "undefined_x" (undefined :: a)
      in do
        seed         <- get
        ExpST _ used <- execCGM (gen dummy)
        return (CUFun1 (mark used dummy)
                       (\xs -> evalState (evalCGM (gen xs)) seed))

    _ -> $internalError "codegenFun1" "expected unary function"
-- | Generate code for a binary scalar function.
--
-- Works like 'codegenFun1', with two placeholder arguments (named from
-- @"undefined_x"@ and @"undefined_y"@) for the usage analysis. The first
-- argument is pushed onto the valuation before the second.
--
-- Anything other than a two-argument lambda is an internal error.
--
codegenFun2
    :: forall aenv a b c. DeviceProperties
    -> Gamma aenv
    -> DelayedFun aenv (a -> b -> c)
    -> CUDA (CUFun2 aenv (a -> b -> c))
codegenFun2 dev aenv fun =
  case fun of
    Lam (Lam (Body body)) ->
      let
          -- Generate the body for the given pair of argument expressions.
          gen :: (Rvalue x, Rvalue y) => [x] -> [y] -> Gen ([C.BlockItem], [C.Exp])
          gen args1 args2 = do
            code     <- mapM use =<< codegenOpenExp dev aenv body (Empty `Push` map rvalue args1
                                                                         `Push` map rvalue args2)
            bindings <- getEnv
            return (bindings, code)

          -- Placeholder arguments, used only for the usage analysis.
          (_, dummyX, _) = locals "undefined_x" (undefined :: a)
          (_, dummyY, _) = locals "undefined_y" (undefined :: b)
      in do
        seed         <- get
        ExpST _ used <- execCGM (gen dummyX dummyY)
        return (CUFun2 (mark used dummyX) (mark used dummyY)
                       (\xs ys -> evalState (evalCGM (gen xs ys)) seed))

    _ -> $internalError "codegenFun2" "expected binary function"
-- | Build a tagging function from a set of used expressions.
--
-- For each expression in @xs@ a flag records whether it is a member of
-- @used@. The returned function pairs those flags, position by position,
-- with the elements of whatever list it is given.
--
mark :: HashSet C.Exp -> [C.Exp] -> ([a] -> [(Bool,a)])
mark used xs = zip (map (`Set.member` used) xs)
-- | Register a result with 'use' when it is exactly one expression. The list
-- itself is always returned unchanged.
--
visit :: [C.Exp] -> Gen [C.Exp]
visit exps =
  case exps of
    [x] -> do
      _ <- use x
      return exps
    _   -> return exps
-- | Generate code for a closed scalar expression. The result packages the
-- bindings reported by 'getEnv' together with the result expressions.
--
codegenExp :: DeviceProperties -> Gamma aenv -> DelayedExp aenv t -> CUDA (CUExp aenv t)
codegenExp dev aenv e = evalCGM generate
  where
    generate = do
      body     <- codegenOpenExp dev aenv e Empty
      bindings <- getEnv
      return $! CUExp (bindings, body)
-- | Generate the C expressions for an open scalar expression.
--
-- [@dev@]  device properties, forwarded to 'indexArray' when reading arrays
--
-- [@aenv@] array environment, consulted through 'namesOfAvar' to name the
--          shape and data of array variables
--
-- The two remaining arguments (taken by the local 'cvtE') are the expression
-- to translate and the valuation giving the C expressions of its free scalar
-- variables. The result is a list of C expressions; tuples and indices are
-- represented by concatenating the lists of their components. Statements
-- required to compute the result (conditionals, loops, temporaries) are
-- accumulated in the 'localBindings' field of the 'Gen' state.
--
codegenOpenExp
    :: forall aenv env' t'. DeviceProperties
    -> Gamma aenv
    -> DelayedOpenExp env' aenv t'
    -> Val env'
    -> Gen [C.Exp]
codegenOpenExp dev aenv = cvtE
  where
    -- Translate one expression node by dispatching to the helper for its
    -- constructor, then pass the result through 'visit'.
    --
    cvtE :: forall env t. DelayedOpenExp env aenv t -> Val env -> Gen [C.Exp]
    cvtE exp env = visit =<<
      case exp of
        Let bnd body            -> elet bnd body env
        Var ix                  -> return $ prj ix env
        PrimConst c             -> return [codegenPrimConst c]
        Const c                 -> return $ codegenConst (Sugar.eltType (undefined::t)) c
        PrimApp f x             -> return <$> primApp f x env
        Tuple t                 -> cvtT t env
        Prj i t                 -> prjT i t exp env
        Cond p t e              -> cond p t e env
        While p f x             -> while p f x env

        -- Shapes and indices
        IndexNil                -> return []
        IndexAny                -> return []
        IndexCons sh sz         -> (++) <$> cvtE sh env <*> cvtE sz env
        IndexHead ix            -> return . cindexHead <$> cvtE ix env
        IndexTail ix            ->          cindexTail <$> cvtE ix env
        IndexSlice ix slix sh   -> indexSlice ix slix sh env
        IndexFull  ix slix sl   -> indexFull  ix slix sl env
        ToIndex sh ix           -> toIndex   sh ix env
        FromIndex sh ix         -> fromIndex sh ix env

        -- Arrays and indexing
        Index acc ix            -> index acc ix env
        LinearIndex acc ix      -> linearIndex acc ix env
        Shape acc               -> shape acc env
        ShapeSize sh            -> shapeSize sh env
        Intersect sh1 sh2       -> intersect sh1 sh2 env

        -- Foreign functions
        Foreign ff _ e          -> foreignE ff e env

    -- Let bindings: translate the bound expression, hand it to 'pushEnv', and
    -- translate the body with the resulting expressions pushed onto the
    -- valuation.
    --
    elet :: DelayedOpenExp env aenv bnd -> DelayedOpenExp (env, bnd) aenv body -> Val env -> Gen [C.Exp]
    elet bnd body env = do
      bnd' <- cvtE bnd env >>= pushEnv bnd
      cvtE body (env `Push` bnd')

    -- Primitive applications. An argument that is syntactically a pair is
    -- treated as a binary application; anything else as a unary one.
    --
    primApp :: PrimFun (a -> b) -> DelayedOpenExp env aenv a -> Val env -> Gen C.Exp
    primApp f x env
      | Tuple (NilTup `SnocTup` a `SnocTup` b) <- x
      = codegenPrim2 f <$> cvtE' a env <*> cvtE' b env
      | otherwise
      = codegenPrim1 f <$> cvtE' x env
      where
        -- Translate an operand under 'clean'. If doing so produced any
        -- statements, wrap them with the result in a @({ ...; r; })@
        -- statement expression so they stay attached to this operand.
        cvtE' :: DelayedOpenExp env aenv a -> Val env -> Gen C.Exp
        cvtE' e val = do
          (b,r) <- clean $ single "primApp" <$> cvtE e val
          if null b
             then return r
             else return [cexp| ({ $items:b; $exp:r; }) |]

    -- Tuple construction: concatenate the expressions of every component.
    --
    cvtT :: Tuple (DelayedOpenExp env aenv) t -> Val env -> Gen [C.Exp]
    cvtT tup env =
      case tup of
        NilTup      -> return []
        SnocTup t e -> (++) <$> cvtT t env <*> cvtE e env

    -- Tuple projection: translate the whole tuple, then keep only the slice of
    -- expressions belonging to the selected component. Offsets are counted
    -- from the end of the list, hence the surrounding 'reverse's.
    --
    prjT :: forall env t e. TupleIdx (TupleRepr t) e
         -> DelayedOpenExp env aenv t
         -> DelayedOpenExp env aenv e
         -> Val env
         -> Gen [C.Exp]
    prjT ix t e env =
      let subset = reverse
                 . take (length      $ expType e)
                 . drop (prjToInt ix $ Sugar.preExpType Sugar.delayedAccType t)
                 . reverse
      in
      subset <$> cvtE t env

    -- Number of scalar components that come after the selected tuple
    -- component, i.e. how many expressions to drop from the reversed list.
    --
    prjToInt :: TupleIdx t e -> TupleType a -> Int
    prjToInt ZeroTupIdx     _                 = 0
    prjToInt (SuccTupIdx i) (b `PairTuple` a) = sizeTupleType a + prjToInt i b
    prjToInt _              _                 = $internalError "prjToInt" "inconsistent valuation"

    -- Number of scalar components in a tuple type.
    --
    sizeTupleType :: TupleType a -> Int
    sizeTupleType UnitTuple       = 0
    sizeTupleType (SingleTuple _) = 1
    sizeTupleType (PairTuple a b) = sizeTupleType a + sizeTupleType b

    -- Conditionals: emit an if/else statement that assigns the chosen branch
    -- to fresh result variables, and return those variables. Each branch is
    -- generated under 'clean' so its statements stay inside that branch.
    --
    cond :: forall env t. Elt t
         => DelayedOpenExp env aenv Bool
         -> DelayedOpenExp env aenv t
         -> DelayedOpenExp env aenv t
         -> Val env -> Gen [C.Exp]
    cond p t f env = do
      p'        <- cvtE p env
      ok        <- single "Cond" <$> pushEnv p p'
      ifTrue    <- clean $ cvtE t env
      ifFalse   <- clean $ cvtE f env
      var_r     <- lift fresh
      let (_, r, declr) = locals ('l':var_r) (undefined :: t)
          -- 'clean' reverses 'localBindings' on the way out, so the list
          -- appears to be kept newest-first: the declarations listed after
          -- the statement here would then precede it in the emitted code.
          branch        = [citem| if ( $exp:ok ) {
                                      $items:(r .=. ifTrue)
                                  }
                                  else {
                                      $items:(r .=. ifFalse)
                                  } |]
                        : map C.BlockDecl declr
      modify (\s -> s { localBindings = branch ++ localBindings s })
      return r

    -- Value recursion: emit a C while loop over an accumulator. The
    -- accumulator is initialised from @x@, the predicate is evaluated once
    -- before the loop, and each iteration computes the step into temporaries,
    -- copies them to the accumulator, and re-evaluates the predicate.
    --
    while :: forall env a. Elt a
          => DelayedOpenFun env aenv (a -> Bool)
          -> DelayedOpenFun env aenv (a -> a)
          -> DelayedOpenExp env aenv a
          -> Val env
          -> Gen [C.Exp]
    while test step x env
      | Lam (Body p) <- test
      , Lam (Body f) <- step
      = do
          x'            <- cvtE x env
          var_acc       <- lift fresh
          var_ok        <- lift fresh
          var_tmp       <- lift fresh
          let (_, acc, decl_acc) = locals ('l':var_acc) (undefined :: a)
              (_, ok,  decl_ok)  = locals ('l':var_ok)  (undefined :: Bool)
              (tmp, _, _)        = locals ('l':var_tmp) (undefined :: a)
          p'            <- clean $ cvtE p (env `Push` acc)
          f'            <- clean $ cvtE f (env `Push` acc)
          let loop = [citem| while ( $exp:(single "while" ok) ) {
                                 $items:(tmp .=. f')
                                 $items:(acc .=. tmp)
                                 $items:(ok  .=. p')
                             } |]
                       : reverse (ok  .=. p')
                      ++ reverse (acc .=. x')
                      ++ map C.BlockDecl decl_ok
                      ++ map C.BlockDecl decl_acc
          modify (\s -> s { localBindings = loop ++ localBindings s })
          return acc
      | otherwise
      = $internalError "while" "expected unary function"

    -- Restrict a full shape to a slice: keep the dimensions marked 'SliceAll'
    -- and drop those marked 'SliceFixed'.
    --
    indexSlice :: SliceIndex (EltRepr slix) sl co (EltRepr sh)
               -> DelayedOpenExp env aenv slix
               -> DelayedOpenExp env aenv sh
               -> Val env
               -> Gen [C.Exp]
    indexSlice sliceIndex slix sh env =
      let restrict :: SliceIndex slix sl co sh -> [C.Exp] -> [C.Exp] -> [C.Exp]
          restrict SliceNil              _       _       = []
          restrict (SliceAll   sliceIdx) slx     (sz:sl) = sz : restrict sliceIdx slx sl
          restrict (SliceFixed sliceIdx) (_:slx) ( _:sl) =      restrict sliceIdx slx sl
          restrict _ _ _ = $internalError "IndexSlice" "unexpected shapes"
          --
          slice slix' sh' = reverse $ restrict sliceIndex (reverse slix') (reverse sh')
      in
      slice <$> cvtE slix env <*> cvtE sh env

    -- Extend a slice shape to a full shape: 'SliceAll' dimensions come from
    -- the slice, 'SliceFixed' dimensions from the slice index.
    --
    indexFull :: SliceIndex (EltRepr slix) (EltRepr sl) co sh
              -> DelayedOpenExp env aenv slix
              -> DelayedOpenExp env aenv sl
              -> Val env
              -> Gen [C.Exp]
    indexFull sliceIndex slix sl env =
      let extend :: SliceIndex slix sl co sh -> [C.Exp] -> [C.Exp] -> [C.Exp]
          extend SliceNil              _        _       = []
          extend (SliceAll   sliceIdx) slx      (sz:sh) = sz : extend sliceIdx slx sh
          extend (SliceFixed sliceIdx) (sz:slx) sh      = sz : extend sliceIdx slx sh
          extend _ _ _ = $internalError "IndexFull" "unexpected shapes"
          --
          replicate slix' sl' = reverse $ extend sliceIndex (reverse slix') (reverse sl')
      in
      replicate <$> cvtE slix env <*> cvtE sl env

    -- Multidimensional index to linear index, via 'ctoIndex'.
    --
    toIndex :: DelayedOpenExp env aenv sh -> DelayedOpenExp env aenv sh -> Val env -> Gen [C.Exp]
    toIndex sh ix env = do
      sh' <- mapM use =<< cvtE sh env
      ix' <- mapM use =<< cvtE ix env
      return [ ctoIndex sh' ix' ]

    -- Linear index to multidimensional index, via 'cfromIndex'. The
    -- statements it returns are prepended (reversed) to the local bindings.
    --
    fromIndex :: DelayedOpenExp env aenv sh -> DelayedOpenExp env aenv Int -> Val env -> Gen [C.Exp]
    fromIndex sh ix env = do
      sh'   <- mapM use =<< cvtE sh env
      ix'   <- cvtE ix env
      tmp   <- lift fresh
      let (ls, sz) = cfromIndex sh' (single "fromIndex" ix') tmp
      modify (\st -> st { localBindings = reverse ls ++ localBindings st })
      return sz

    -- Read an array variable at a multidimensional index. The linear offset
    -- is bound once with 'bind' and shared by every component array.
    --
    index :: (Shape sh, Elt e)
          => DelayedOpenAcc aenv (Array sh e)
          -> DelayedOpenExp env aenv sh
          -> Val env
          -> Gen [C.Exp]
    index acc ix env
      | Manifest (Avar idx) <- acc
      = let (sh, arr) = namesOfAvar aenv idx
            ty        = accType acc
        in do
          ix' <- mapM use =<< cvtE ix env
          i   <- bind cint $ ctoIndex (cshape (expDim ix) sh) ix'
          return $ zipWith (\t a -> indexArray dev t (cvar a) i) ty arr
      | otherwise
      = $internalError "Index" "expected array variable"

    -- Read an array variable at a linear index.
    --
    linearIndex :: (Shape sh, Elt e)
                => DelayedOpenAcc aenv (Array sh e)
                -> DelayedOpenExp env aenv Int
                -> Val env
                -> Gen [C.Exp]
    linearIndex acc ix env
      | Manifest (Avar idx) <- acc
      = let (_, arr) = namesOfAvar aenv idx
            ty       = accType acc
        in do
          ix' <- mapM use =<< cvtE ix env
          i   <- bind [cty| int |] $ single "LinearIndex" ix'
          return $ zipWith (\t a -> indexArray dev t (cvar a) i) ty arr
      | otherwise
      = $internalError "LinearIndex" "expected array variable"

    -- Shape of an array variable, from the shape name of 'namesOfAvar'.
    --
    shape :: (Shape sh, Elt e) => DelayedOpenAcc aenv (Array sh e) -> Val env -> Gen [C.Exp]
    shape acc _env
      | Manifest (Avar idx) <- acc
      = return $ cshape (delayedDim acc) (fst (namesOfAvar aenv idx))
      | otherwise
      = $internalError "Shape" "expected array variable"

    -- Size of a shape, via 'csize'.
    --
    shapeSize :: DelayedOpenExp env aenv sh -> Val env -> Gen [C.Exp]
    shapeSize sh env = return . csize <$> cvtE sh env

    -- Intersection of two shapes: a call to @min@ for each pair of extents.
    --
    intersect :: forall env sh. Elt sh
              => DelayedOpenExp env aenv sh
              -> DelayedOpenExp env aenv sh
              -> Val env -> Gen [C.Exp]
    intersect sh1 sh2 env =
      zipWith (\a b -> ccall "min" [a,b]) <$> cvtE sh1 env <*> cvtE sh2 env

    -- Foreign scalar functions: record the headers the function needs in the
    -- 'CUDA' state, then call it on the arguments cast to their element types.
    --
    foreignE :: forall f a b env. (Sugar.Foreign f, Elt a, Elt b)
             => f a b
             -> DelayedOpenExp env aenv a
             -> Val env
             -> Gen [C.Exp]
    foreignE ff x env = case canExecuteExp ff of
      Nothing      -> $internalError "codegenOpenExp" "Non-CUDA foreign expression encountered"
      Just (hs, f) -> do
        lift $ modify (\st -> st { headers = foldr Set.insert (headers st) hs })
        args <- cvtE x env
        mapM_ use args
        return [ccall f (ccastTup (Sugar.eltType (undefined::a)) args)]
-- | Run a generator with an initially empty list of local bindings.
--
-- The bindings in effect beforehand are saved and restored afterwards. The
-- bindings produced by the action are returned, reversed, alongside its
-- result rather than being left in the state.
--
clean :: Gen a -> Gen ([C.BlockItem], a)
clean this = do
  saved  <- gets localBindings
  modify (\s -> s { localBindings = [] })
  result <- this
  inner  <- gets localBindings
  modify (\s -> s { localBindings = saved })
  return (reverse inner, result)
-- | Extract the only element of a singleton list of expressions. Any other
-- length is an internal error reported against the given location.
--
single :: String -> [C.Exp] -> C.Exp
single loc exps =
  case exps of
    [x] -> x
    _   -> $internalError loc "expected single expression"
-- | Generate the C expression for a primitive constant.
--
codegenPrimConst :: PrimConst a -> C.Exp
codegenPrimConst constant =
  case constant of
    PrimMinBound ty -> codegenMinBound ty
    PrimMaxBound ty -> codegenMaxBound ty
    PrimPi       ty -> codegenPi ty
-- | Generate the C expression for a unary primitive applied to @a@.
--
-- Floating-point math functions are named with 'postfix', which appends an
-- @f@ suffix for the single-precision types. A primitive that is not listed
-- here is an internal error.
--
codegenPrim1 :: PrimFun f -> C.Exp -> C.Exp
-- Negation must emit the unary minus; without it the operand is returned
-- unchanged.
codegenPrim1 (PrimNeg         _) a = [cexp| - $exp:a|]
codegenPrim1 (PrimAbs        ty) a = codegenAbs ty a
codegenPrim1 (PrimSig        ty) a = codegenSig ty a
codegenPrim1 (PrimBNot        _) a = [cexp|~ $exp:a|]
codegenPrim1 (PrimRecip      ty) a = codegenRecip ty a
codegenPrim1 (PrimSin        ty) a = ccall (FloatingNumType ty `postfix` "sin")   [a]
codegenPrim1 (PrimCos        ty) a = ccall (FloatingNumType ty `postfix` "cos")   [a]
codegenPrim1 (PrimTan        ty) a = ccall (FloatingNumType ty `postfix` "tan")   [a]
codegenPrim1 (PrimAsin       ty) a = ccall (FloatingNumType ty `postfix` "asin")  [a]
codegenPrim1 (PrimAcos       ty) a = ccall (FloatingNumType ty `postfix` "acos")  [a]
codegenPrim1 (PrimAtan       ty) a = ccall (FloatingNumType ty `postfix` "atan")  [a]
codegenPrim1 (PrimAsinh      ty) a = ccall (FloatingNumType ty `postfix` "asinh") [a]
codegenPrim1 (PrimAcosh      ty) a = ccall (FloatingNumType ty `postfix` "acosh") [a]
codegenPrim1 (PrimAtanh      ty) a = ccall (FloatingNumType ty `postfix` "atanh") [a]
codegenPrim1 (PrimExpFloating ty) a = ccall (FloatingNumType ty `postfix` "exp")  [a]
codegenPrim1 (PrimSqrt       ty) a = ccall (FloatingNumType ty `postfix` "sqrt")  [a]
codegenPrim1 (PrimLog        ty) a = ccall (FloatingNumType ty `postfix` "log")   [a]
codegenPrim1 (PrimTruncate ta tb) a = codegenTruncate ta tb a
codegenPrim1 (PrimRound    ta tb) a = codegenRound ta tb a
codegenPrim1 (PrimFloor    ta tb) a = codegenFloor ta tb a
codegenPrim1 (PrimCeiling  ta tb) a = codegenCeiling ta tb a
codegenPrim1 PrimLNot            a = [cexp| ! $exp:a|]
codegenPrim1 PrimOrd             a = codegenOrd a
codegenPrim1 PrimChr             a = codegenChr a
codegenPrim1 PrimBoolToInt       a = codegenBoolToInt a
codegenPrim1 (PrimFromIntegral ta tb) a = codegenFromIntegral ta tb a
codegenPrim1 _                   _ =
  $internalError "codegenPrim1" "unknown primitive function"
-- | Generate the C expression for a binary primitive applied to @a@ and @b@.
--
-- Most arithmetic, bitwise, comparison and logical primitives map directly
-- onto the corresponding C operator; the rest call a named function. A
-- primitive that is not listed here is an internal error.
--
codegenPrim2 :: PrimFun f -> C.Exp -> C.Exp -> C.Exp
codegenPrim2 (PrimAdd      _) a b = [cexp|$exp:a + $exp:b|]
-- Subtraction must emit the binary minus between its operands.
codegenPrim2 (PrimSub      _) a b = [cexp|$exp:a - $exp:b|]
codegenPrim2 (PrimMul      _) a b = [cexp|$exp:a * $exp:b|]
codegenPrim2 (PrimQuot     _) a b = [cexp|$exp:a / $exp:b|]
codegenPrim2 (PrimRem      _) a b = [cexp|$exp:a % $exp:b|]
codegenPrim2 (PrimIDiv     _) a b = ccall "idiv" [a,b]
codegenPrim2 (PrimMod      _) a b = ccall "mod"  [a,b]
codegenPrim2 (PrimBAnd     _) a b = [cexp|$exp:a & $exp:b|]
codegenPrim2 (PrimBOr      _) a b = [cexp|$exp:a | $exp:b|]
codegenPrim2 (PrimBXor     _) a b = [cexp|$exp:a ^ $exp:b|]
codegenPrim2 (PrimBShiftL  _) a b = [cexp|$exp:a << $exp:b|]
codegenPrim2 (PrimBShiftR  _) a b = [cexp|$exp:a >> $exp:b|]
codegenPrim2 (PrimBRotateL _) a b = ccall "rotateL" [a,b]
codegenPrim2 (PrimBRotateR _) a b = ccall "rotateR" [a,b]
codegenPrim2 (PrimFDiv     _) a b = [cexp|$exp:a / $exp:b|]
codegenPrim2 (PrimFPow    ty) a b = ccall (FloatingNumType ty `postfix` "pow") [a,b]
codegenPrim2 (PrimLogBase ty) a b = codegenLogBase ty a b
codegenPrim2 (PrimAtan2   ty) a b = ccall (FloatingNumType ty `postfix` "atan2") [a,b]
codegenPrim2 (PrimLt       _) a b = [cexp|$exp:a < $exp:b|]
codegenPrim2 (PrimGt       _) a b = [cexp|$exp:a > $exp:b|]
codegenPrim2 (PrimLtEq     _) a b = [cexp|$exp:a <= $exp:b|]
codegenPrim2 (PrimGtEq     _) a b = [cexp|$exp:a >= $exp:b|]
codegenPrim2 (PrimEq       _) a b = [cexp|$exp:a == $exp:b|]
codegenPrim2 (PrimNEq      _) a b = [cexp|$exp:a != $exp:b|]
codegenPrim2 (PrimMax     ty) a b = codegenMax ty a b
codegenPrim2 (PrimMin     ty) a b = codegenMin ty a b
codegenPrim2 PrimLAnd         a b = [cexp|$exp:a && $exp:b|]
codegenPrim2 PrimLOr          a b = [cexp|$exp:a || $exp:b|]
codegenPrim2 _                _ _ =
  $internalError "codegenPrim2" "unknown primitive function"
-- | Generate the C expressions for a constant by walking its tuple type:
-- unit contributes nothing, a scalar contributes one expression, and a pair
-- contributes the expressions of its left part followed by its right part.
--
codegenConst :: TupleType a -> a -> [C.Exp]
codegenConst tup c =
  case tup of
    UnitTuple         -> []
    SingleTuple ty    -> [codegenScalar ty c]
    PairTuple ty1 ty0 ->
      case c of
        (cs, c0) -> codegenConst ty1 cs ++ codegenConst ty0 c0
-- | Generate the C expression for a scalar constant, dispatching on whether
-- its type is numeric.
--
codegenScalar :: ScalarType a -> a -> C.Exp
codegenScalar sty x =
  case sty of
    NumScalarType    ty -> codegenNumScalar    ty x
    NonNumScalarType ty -> codegenNonNumScalar ty x
-- | Generate the C expression for a numeric constant, dispatching on whether
-- its type is integral or floating point.
--
codegenNumScalar :: NumType a -> a -> C.Exp
codegenNumScalar nty x =
  case nty of
    IntegralNumType ty -> codegenIntegralScalar ty x
    FloatingNumType ty -> codegenFloatingScalar ty x
-- | Generate an integral constant, written as the value from 'cintegral'
-- cast to the C type given by 'codegenIntegralType'.
--
codegenIntegralScalar :: IntegralType a -> a -> C.Exp
codegenIntegralScalar ty x =
  case integralDict ty of
    IntegralDict -> [cexp| ( $ty:(codegenIntegralType ty) ) $exp:(cintegral x) |]
-- | Generate a floating-point constant. Single-precision types produce a
-- float constant whose text is the shown value with an @f@ suffix;
-- double-precision types produce a double constant with the shown value.
--
codegenFloatingScalar :: FloatingType a -> a -> C.Exp
codegenFloatingScalar ty x =
  case ty of
    TypeFloat   _ -> floatLit  x
    TypeCFloat  _ -> floatLit  x
    TypeDouble  _ -> doubleLit x
    TypeCDouble _ -> doubleLit x
  where
    floatLit :: (Show n, Real n) => n -> C.Exp
    floatLit v = C.Const (C.FloatConst (shows v "f") (toRational v) noLoc) noLoc

    doubleLit :: (Show n, Real n) => n -> C.Exp
    doubleLit v = C.Const (C.DoubleConst (show v) (toRational v) noLoc) noLoc
-- | Generate a non-numeric constant. Booleans go through 'cbool'; characters
-- become C character constants, with the foreign character types first
-- converted to a 'Char' by their numeric value.
--
codegenNonNumScalar :: NonNumType a -> a -> C.Exp
codegenNonNumScalar ty x =
  case ty of
    TypeBool   _ -> cbool x
    TypeChar   _ -> charLit x
    TypeCChar  _ -> charLit (chr (fromIntegral x))
    TypeCUChar _ -> charLit (chr (fromIntegral x))
    TypeCSChar _ -> charLit (chr (fromIntegral x))
  where
    charLit :: Char -> C.Exp
    charLit c = [cexp|$char:c|]
-- | The constant pi at the given floating-point type.
--
codegenPi :: FloatingType a -> C.Exp
codegenPi ty =
  case floatingDict ty of
    FloatingDict -> codegenFloatingScalar ty pi
-- | The smallest value of a bounded type, as a C constant.
--
codegenMinBound :: BoundedType a -> C.Exp
codegenMinBound bounded =
  case bounded of
    IntegralBoundedType ty ->
      case integralDict ty of
        IntegralDict -> codegenIntegralScalar ty minBound
    NonNumBoundedType ty ->
      case nonNumDict ty of
        NonNumDict -> codegenNonNumScalar ty minBound
-- | The largest value of a bounded type, as a C constant.
--
codegenMaxBound :: BoundedType a -> C.Exp
codegenMaxBound bounded =
  case bounded of
    IntegralBoundedType ty ->
      case integralDict ty of
        IntegralDict -> codegenIntegralScalar ty maxBound
    NonNumBoundedType ty ->
      case nonNumDict ty of
        NonNumDict -> codegenNonNumScalar ty maxBound
-- | Absolute value. Floating-point types call @fabs@ (suffixed by 'postfix');
-- the unsigned integral types listed below return the operand unchanged; all
-- other integral types call @abs@.
--
codegenAbs :: NumType a -> C.Exp -> C.Exp
codegenAbs numTy x =
  case numTy of
    FloatingNumType ty -> ccall (FloatingNumType ty `postfix` "fabs") [x]
    IntegralNumType ty
      | isUnsigned ty  -> x
      | otherwise      -> ccall "abs" [x]
  where
    isUnsigned :: IntegralType t -> Bool
    isUnsigned ty =
      case ty of
        TypeWord     _ -> True
        TypeWord8    _ -> True
        TypeWord16   _ -> True
        TypeWord32   _ -> True
        TypeWord64   _ -> True
        TypeCUShort  _ -> True
        TypeCUInt    _ -> True
        TypeCULong   _ -> True
        TypeCULLong  _ -> True
        _              -> False
-- | Sign of a number, dispatching on integral versus floating-point type.
--
codegenSig :: NumType a -> C.Exp -> C.Exp
codegenSig numTy x =
  case numTy of
    IntegralNumType ty -> codegenIntegralSig ty x
    FloatingNumType ty -> codegenFloatingSig ty x
-- | Sign of an integral value: zero when the operand equals zero, otherwise
-- the result of @copysign(1, x)@.
--
codegenIntegralSig :: IntegralType a -> C.Exp -> C.Exp
codegenIntegralSig ty x =
  case integralDict ty of
    IntegralDict ->
      let zero = codegenIntegralScalar ty 0
          one  = codegenIntegralScalar ty 1
      in
      [cexp|$exp:x == $exp:zero ? $exp:zero : $exp:(ccall "copysign" [one,x]) |]
-- | Sign of a floating-point value: zero when the operand equals zero,
-- otherwise the result of @copysign(1, x)@ (suffixed by 'postfix').
--
codegenFloatingSig :: FloatingType a -> C.Exp -> C.Exp
codegenFloatingSig ty x =
  case floatingDict ty of
    FloatingDict ->
      let zero = codegenFloatingScalar ty 0
          one  = codegenFloatingScalar ty 1
      in
      [cexp|$exp:x == $exp:zero
              ? $exp:zero
              : $exp:(ccall (FloatingNumType ty `postfix` "copysign") [one,x]) |]
-- | Reciprocal: the constant one at the given type, divided by the operand.
--
codegenRecip :: FloatingType a -> C.Exp -> C.Exp
codegenRecip ty x =
  case floatingDict ty of
    FloatingDict -> [cexp|$exp:(codegenFloatingScalar ty 1) / $exp:x|]
-- | Logarithm to an arbitrary base, as @log(y) / log(x)@ where @x@ is the
-- first operand and @y@ the second.
--
codegenLogBase :: FloatingType a -> C.Exp -> C.Exp -> C.Exp
codegenLogBase ty x y = [cexp|$exp:(logOf y) / $exp:(logOf x)|]
  where
    logOf v = ccall (FloatingNumType ty `postfix` "log") [v]
-- | Minimum of two values: @min@ for integral types, @fmin@ (suffixed by
-- 'postfix') for floating-point types. Non-numeric operands are first cast
-- to 'Int32' and compared as integers.
--
codegenMin :: ScalarType a -> C.Exp -> C.Exp -> C.Exp
codegenMin sty a b =
  case sty of
    NumScalarType ty@(IntegralNumType _) -> ccall (ty `postfix` "min")  [a,b]
    NumScalarType ty@(FloatingNumType _) -> ccall (ty `postfix` "fmin") [a,b]
    NonNumScalarType _                   ->
      let i32 = scalarType :: ScalarType Int32
      in  codegenMin i32 (ccast i32 a) (ccast i32 b)
-- | Maximum of two values: @max@ for integral types, @fmax@ (suffixed by
-- 'postfix') for floating-point types. Non-numeric operands are first cast
-- to 'Int32' and compared as integers.
--
codegenMax :: ScalarType a -> C.Exp -> C.Exp -> C.Exp
codegenMax sty a b =
  case sty of
    NumScalarType ty@(IntegralNumType _) -> ccall (ty `postfix` "max")  [a,b]
    NumScalarType ty@(FloatingNumType _) -> ccall (ty `postfix` "fmax") [a,b]
    NonNumScalarType _                   ->
      let i32 = scalarType :: ScalarType Int32
      in  codegenMax i32 (ccast i32 a) (ccast i32 b)
-- | Character to integer: a cast to the C type of 'Int'.
--
codegenOrd :: C.Exp -> C.Exp
codegenOrd x = ccast intTy x
  where
    intTy = scalarType :: ScalarType Int
-- | Integer to character: a cast to the C type of 'Char'.
--
codegenChr :: C.Exp -> C.Exp
codegenChr x = ccast charTy x
  where
    charTy = scalarType :: ScalarType Char
-- | Boolean to integer: a cast to the C type of 'Int'.
--
codegenBoolToInt :: C.Exp -> C.Exp
codegenBoolToInt x = ccast intTy x
  where
    intTy = scalarType :: ScalarType Int
-- | Integral conversion: a cast to the target numeric type. The source type
-- is not consulted.
--
codegenFromIntegral :: IntegralType a -> NumType b -> C.Exp -> C.Exp
codegenFromIntegral _ target x = ccast (NumScalarType target) x
-- | Truncation: call @trunc@ (suffixed by 'postfix' for the source type) and
-- cast the result to the target integral type.
--
codegenTruncate :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenTruncate ta tb x = ccast target truncated
  where
    target    = NumScalarType (IntegralNumType tb)
    truncated = ccall (FloatingNumType ta `postfix` "trunc") [x]
-- | Rounding: call @round@ (suffixed by 'postfix' for the source type) and
-- cast the result to the target integral type.
--
codegenRound :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenRound ta tb x = ccast target rounded
  where
    target  = NumScalarType (IntegralNumType tb)
    rounded = ccall (FloatingNumType ta `postfix` "round") [x]
-- | Floor: call @floor@ (suffixed by 'postfix' for the source type) and cast
-- the result to the target integral type.
--
codegenFloor :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenFloor ta tb x = ccast target floored
  where
    target  = NumScalarType (IntegralNumType tb)
    floored = ccall (FloatingNumType ta `postfix` "floor") [x]
-- | Ceiling: call @ceil@ (suffixed by 'postfix' for the source type) and cast
-- the result to the target integral type.
--
codegenCeiling :: FloatingType a -> IntegralType b -> C.Exp -> C.Exp
codegenCeiling ta tb x = ccast target raised
  where
    target = NumScalarType (IntegralNumType tb)
    raised = ccall (FloatingNumType ta `postfix` "ceil") [x]
-- | Cast an expression to the C type corresponding to a scalar type.
--
ccast :: ScalarType a -> C.Exp -> C.Exp
ccast ty x = [cexp|($ty:target) $exp:x|]
  where
    target = codegenScalarType ty
-- | Cast a flat list of expressions component-wise to the scalar types of a
-- tuple type. Expressions are consumed left to right; expressions left over
-- after the type has been traversed are discarded. Running out of
-- expressions at a scalar position is an internal error.
--
ccastTup :: TupleType e -> [C.Exp] -> [C.Exp]
ccastTup ty exps = fst (go ty exps)
  where
    -- Returns the cast expressions together with the unconsumed remainder.
    go :: TupleType t -> [C.Exp] -> ([C.Exp], [C.Exp])
    go tup rest =
      case (tup, rest) of
        (UnitTuple,      _     ) -> ([], rest)
        (SingleTuple st, e : es) -> ([ccast st e], es)
        (PairTuple l r,  _     ) ->
          let (ls, afterL) = go l rest
              (rs, afterR) = go r afterL
          in  (ls ++ rs, afterR)
        _                        -> $internalError "ccastTup" "not enough expressions to match type"
-- | Append an @f@ to a function name for the single-precision floating-point
-- types; every other type leaves the name unchanged.
--
postfix :: NumType a -> String -> String
postfix ty name =
  case ty of
    FloatingNumType (TypeFloat  _) -> name ++ "f"
    FloatingNumType (TypeCFloat _) -> name ++ "f"
    _                              -> name