module Feldspar.Primitive.Backend.C where
import Data.Complex
import Data.Constraint (Dict (..))
import Data.Proxy
import Language.C.Quote.C
import qualified Language.C.Syntax as C
import Language.C.Monad
import Language.Embedded.Backend.C
import Language.Syntactic
import Feldspar.Primitive.Representation
-- | Extract the Haskell value of a constant primitive expression.
--
-- Literal symbols ('Lit') yield their value, and the symbol 'Pi' yields
-- 'pi'. Every other expression yields 'Nothing'.
viewLitPrim :: ASTF (Primitive :&: PrimTypeRep) a -> Maybe a
viewLitPrim expr = case expr of
    Sym (Lit lit :&: _) -> Just lit
    Sym (Pi :&: _)      -> Just pi
    _                   -> Nothing
-- | C code generation for Feldspar primitive types and literals.
instance CompTypeClass PrimType'
  where
    -- Map a primitive type to its C type, registering the header that
    -- declares it where one is needed.
    compType _ (_ :: proxy a) = case primTypeRep :: PrimTypeRep a of
        BoolT          -> addInclude "<stdbool.h>" >> return [cty| typename bool |]
        Int8T          -> addInclude "<stdint.h>" >> return [cty| typename int8_t |]
        Int16T         -> addInclude "<stdint.h>" >> return [cty| typename int16_t |]
        Int32T         -> addInclude "<stdint.h>" >> return [cty| typename int32_t |]
        Int64T         -> addInclude "<stdint.h>" >> return [cty| typename int64_t |]
        Word8T         -> addInclude "<stdint.h>" >> return [cty| typename uint8_t |]
        Word16T        -> addInclude "<stdint.h>" >> return [cty| typename uint16_t |]
        Word32T        -> addInclude "<stdint.h>" >> return [cty| typename uint32_t |]
        Word64T        -> addInclude "<stdint.h>" >> return [cty| typename uint64_t |]
        FloatT         -> return [cty| float |]
        DoubleT        -> return [cty| double |]
        ComplexFloatT  -> addInclude "<tgmath.h>" >> return [cty| float _Complex |]
        ComplexDoubleT -> addInclude "<tgmath.h>" >> return [cty| double _Complex |]

    -- Compile a literal value to a C expression.
    compLit _ a = case primTypeOf a of
        BoolT -> do
            addInclude "<stdbool.h>"
            return $ if a then [cexp| true |] else [cexp| false |]
        Int8T   -> return [cexp| $a |]
        Int16T  -> return [cexp| $a |]
        Int32T  -> return [cexp| $a |]
        Int64T  -> return [cexp| $a |]
        Word8T  -> return [cexp| $a |]
        Word16T -> return [cexp| $a |]
        Word32T -> return [cexp| $a |]
        Word64T -> return [cexp| $a |]
        FloatT  -> return [cexp| $a |]
        DoubleT -> return [cexp| $a |]
        -- Complex literals may refer to the imaginary unit I, which is
        -- declared by <complex.h> (pulled in by <tgmath.h>). Register the
        -- header here, as compType does for complex types, so the literal
        -- does not rely on some other part of the program including it.
        ComplexFloatT -> do
            addInclude "<tgmath.h>"
            return $ compComplexLit a
        ComplexDoubleT -> do
            addInclude "<tgmath.h>"
            return $ compComplexLit a
-- | Build the C expression for a complex constant using the imaginary unit
-- @I@.
--
-- A zero imaginary part yields just the real part. Otherwise a zero real
-- part yields @im * I@. In all other cases the result is @re + im * I@.
compComplexLit :: (Eq a, Num a, ToExp a) => Complex a -> C.Exp
compComplexLit (re :+ im)
    | im == 0   = [cexp| $re |]
    | re == 0   = [cexp| $im * I |]
    | otherwise = [cexp| $re + $im * I |]
-- | Register the global @TAG@ macro definition.
--
-- @TAG(tag,exp)@ expands to @(exp)@ and discards @tag@. The tag is
-- therefore only a label in the generated C source text.
addTagMacro :: MonadC m => m ()
addTagMacro = addGlobal tagMacroDef
  where
    tagMacroDef = [cedecl|$esc:("#define TAG(tag,exp) (exp)")|]
-- | Compile the operand and apply the given C unary operator to it.
compUnOp :: MonadC m => C.UnOp -> ASTF PrimDomain a -> m C.Exp
compUnOp op a = applyOp <$> compPrim (Prim a)
  where
    applyOp operand = C.UnOp op operand mempty
-- | Compile both operands (left first, then right) and combine them with
-- the given C binary operator.
compBinOp :: MonadC m =>
    C.BinOp -> ASTF PrimDomain a -> ASTF PrimDomain b -> m C.Exp
compBinOp op a b =
    combine <$> compPrim (Prim a) <*> compPrim (Prim b)
  where
    combine lhs rhs = C.BinOp op lhs rhs mempty
-- | Compile a call to the named C function.
--
-- The arguments are compiled in order and passed positionally.
compFun :: MonadC m => String -> Args (AST PrimDomain) sig -> m C.Exp
compFun fun args =
    fmap mkCall . sequence $ listArgs (compPrim . Prim) args
  where
    mkCall argExps = [cexp| $id:fun($args:argExps) |]
-- | Compile the absolute value of an expression of the given primitive type.
--
-- * Signed integers use the C library @abs@ family from @<stdlib.h>@.
--   @Int64@ uses @llabs@, because @abs@ takes and returns a plain @int@
--   and would truncate 64-bit values wherever @int@ is narrower.
-- * Unsigned integers are already non-negative and are returned unchanged.
-- * All remaining (floating and complex) types use the type-generic @fabs@
--   from @<tgmath.h>@.
--
-- Booleans are rejected with an 'error'.
compAbs :: MonadC m => PrimTypeRep a -> ASTF PrimDomain a -> m C.Exp
compAbs t a = case viewPrimTypeRep t of
    PrimTypeBool -> error "compAbs: type BoolT not supported"
    PrimTypeIntWord (IntType Int64Type) ->
        addInclude "<stdlib.h>" >> compFun "llabs" (a :* Nil)
    PrimTypeIntWord (IntType _) ->
        addInclude "<stdlib.h>" >> compFun "abs" (a :* Nil)
    PrimTypeIntWord (WordType _) -> compPrim $ Prim a
    _ -> addInclude "<tgmath.h>" >> compFun "fabs" (a :* Nil)
-- | C helper @feld_complexSign@: the signum of a @double _Complex@.
--
-- Returns @0@ when the magnitude @cabs(c)@ is zero. Otherwise it returns
-- @c@ with both components divided by the magnitude.
complexSign_def :: C.Definition
complexSign_def = [cedecl|
double _Complex feld_complexSign(double _Complex c) {
    double z = cabs(c);
    if (z == 0) {
        return 0;
    } else {
        return (creal(c)/z + I*(cimag(c)/z));
    }
}
|]
-- | C helper @feld_complexSignf@: the signum of a @float _Complex@.
--
-- This is the single-precision counterpart of the double helper. It
-- returns @0@ when @cabsf(c)@ is zero. Otherwise it returns @c@ with both
-- components divided by the magnitude.
complexSignf_def :: C.Definition
complexSignf_def = [cedecl|
float _Complex feld_complexSignf(float _Complex c) {
    float z = cabsf(c);
    if (z == 0) {
        return 0;
    } else {
        return (crealf(c)/z + I*(cimagf(c)/z));
    }
}
|]
-- | Compile @signum@ for an expression of the given primitive type.
--
-- * Unsigned integers: @a > 0@, i.e. 0 or 1.
-- * Signed integers: @(a > 0) - (a < 0)@, i.e. -1, 0 or 1.
-- * @float@ / @double@: the same idiom, cast to the floating type.
-- * Complex types: call the C helpers @feld_complexSign@ /
--   @feld_complexSignf@, which are registered as globals.
--
-- Real-valued results are wrapped in the @TAG@ macro to label them as
-- @"signum"@. Booleans are rejected with an 'error'.
compSign :: MonadC m => PrimTypeRep a -> ASTF PrimDomain a -> m C.Exp
compSign t a = case viewPrimTypeRep t of
    PrimTypeBool -> error "compSign: type BoolT not supported"
    PrimTypeIntWord (WordType _) -> do
        addTagMacro
        a' <- compPrim $ Prim a
        return [cexp| TAG("signum", $a' > 0) |]
    -- The "-" between the two comparisons is essential: without it the
    -- generated C is a call expression "(x > 0)(x < 0)", which does not
    -- compile.
    PrimTypeIntWord (IntType _) -> do
        addTagMacro
        a' <- compPrim $ Prim a
        return [cexp| TAG("signum", ($a' > 0) - ($a' < 0)) |]
    PrimTypeFloating FloatType -> do
        addTagMacro
        a' <- compPrim $ Prim a
        return [cexp| TAG("signum", (float) (($a' > 0) - ($a' < 0))) |]
    PrimTypeFloating DoubleType -> do
        addTagMacro
        a' <- compPrim $ Prim a
        return [cexp| TAG("signum", (double) (($a' > 0) - ($a' < 0))) |]
    PrimTypeComplex ComplexDoubleType -> do
        addInclude "<tgmath.h>"
        addGlobal complexSign_def
        a' <- compPrim $ Prim a
        return [cexp| feld_complexSign($a') |]
    PrimTypeComplex ComplexFloatType -> do
        addInclude "<complex.h>"
        addGlobal complexSignf_def
        a' <- compPrim $ Prim a
        return [cexp| feld_complexSignf($a') |]
-- | Compile an expression and cast the result to the C type corresponding
-- to the given primitive type.
compCast :: MonadC m => PrimTypeRep a -> ASTF PrimDomain b -> m C.Exp
compCast t a = do
    compiled <- compPrim (Prim a)
    compCastExp t compiled
-- | Wrap an already-compiled C expression in a cast to the C type of the
-- given primitive type.
--
-- The type is produced by 'compType', so it may register a header.
compCastExp :: MonadC m => PrimTypeRep a -> C.Exp -> m C.Exp
compCastExp t a = case witPrimType t of
    Dict -> do
        targetTy <- compType (Proxy :: Proxy PrimType') t
        return [cexp|($ty:targetTy) $a|]
-- | Compile rounding of a fractional expression to the target type @a@.
--
-- Integral targets use @lround@ and floating or complex targets use
-- @round@ (both type-generic via @<tgmath.h>@). The rounded value is then
-- cast to the C type of the target. Any other target type is an 'error'.
compRound :: (PrimType' a, Num a, RealFrac b, MonadC m) =>
    PrimTypeRep a -> ASTF PrimDomain b -> m C.Exp
compRound t a = do
    addInclude "<tgmath.h>"
    rounded <- case viewPrimTypeRep t of
        PrimTypeIntWord _  -> roundWith "lround"
        PrimTypeFloating _ -> roundWith "round"
        PrimTypeComplex _  -> roundWith "round"
        _ -> error $ "compRound: type " ++ show t ++ " not supported"
    compCastExp t rounded
  where
    roundWith fun = compFun fun (a :* Nil)
-- | C helper @feld_div@: Haskell-style 'div' on @int@ operands, rounding
-- the quotient towards negative infinity.
--
-- C's @/@ truncates towards zero, so the quotient is decremented when the
-- remainder is non-zero and its sign differs from the divisor's.
div_def :: C.Definition
div_def = [cedecl|
int feld_div(int x, int y) {
    int q = x/y;
    int r = x%y;
    if ((r!=0) && ((r<0) != (y<0))) --q;
    return q;
}
|]
-- | C helper @feld_ldiv@: Haskell-style 'div' on @long int@ operands.
--
-- The quotient and remainder are held in @long int@ as well. Storing them
-- in @int@ (as before) truncated 64-bit results. The quotient is
-- decremented when the remainder is non-zero and its sign differs from the
-- divisor's, turning C's truncating division into flooring division.
--
-- NOTE(review): @long int@ is only 32 bits on LLP64 targets (e.g. 64-bit
-- Windows). @long long@ would be required there for @Int64@.
ldiv_def :: C.Definition
ldiv_def = [cedecl|
long int feld_ldiv(long int x, long int y) {
    long int q = x/y;
    long int r = x%y;
    if ((r!=0) && ((r<0) != (y<0))) --q;
    return q;
}
|]
-- | C helper @feld_mod@: Haskell-style 'mod' on @int@ operands, where the
-- result takes the sign of the divisor.
--
-- C's @%@ takes the sign of the dividend, so the divisor is added when the
-- remainder is non-zero and its sign differs from the divisor's.
mod_def :: C.Definition
mod_def = [cedecl|
int feld_mod(int x, int y) {
    int r = x%y;
    if ((r!=0) && ((r<0) != (y<0))) { r += y; }
    return r;
}
|]
-- | C helper @feld_lmod@: Haskell-style 'mod' on @long int@ operands.
--
-- The remainder is held in @long int@. Storing it in @int@ (as before)
-- truncated 64-bit results. The divisor is added when the remainder is
-- non-zero and its sign differs from the divisor's.
--
-- NOTE(review): @long int@ is only 32 bits on LLP64 targets (e.g. 64-bit
-- Windows). @long long@ would be required there for @Int64@.
lmod_def :: C.Definition
lmod_def = [cedecl|
long int feld_lmod(long int x, long int y) {
    long int r = x%y;
    if ((r!=0) && ((r<0) != (y<0))) { r += y; }
    return r;
}
|]
-- | Compile Haskell 'div' (flooring division) for the given type.
--
-- Unsigned operands use plain C division. Signed operands call
-- @feld_ldiv@ (for @Int64@) or @feld_div@, and the helper's definition is
-- registered as a global first. Non-integral types are an 'error'.
compDiv :: MonadC m =>
    PrimTypeRep a -> ASTF PrimDomain a -> ASTF PrimDomain b -> m C.Exp
compDiv t a b = case viewPrimTypeRep t of
    PrimTypeIntWord (WordType _)        -> compBinOp C.Div a b
    PrimTypeIntWord (IntType Int64Type) -> viaHelper ldiv_def "feld_ldiv"
    PrimTypeIntWord _                   -> viaHelper div_def "feld_div"
    _ -> error $ "compDiv: type " ++ show t ++ " not supported"
  where
    viaHelper helperDef fun = do
        addGlobal helperDef
        compFun fun (a :* b :* Nil)
-- | Compile Haskell 'mod' (remainder with the sign of the divisor) for the
-- given type.
--
-- Unsigned operands use plain C @%@. Signed operands call @feld_lmod@ (for
-- @Int64@) or @feld_mod@, and the helper's definition is registered as a
-- global first. Non-integral types are an 'error'.
compMod :: MonadC m =>
    PrimTypeRep a -> ASTF PrimDomain a -> ASTF PrimDomain b -> m C.Exp
compMod t a b = case viewPrimTypeRep t of
    PrimTypeIntWord (WordType _)        -> compBinOp C.Mod a b
    PrimTypeIntWord (IntType Int64Type) -> viaHelper lmod_def "feld_lmod"
    PrimTypeIntWord _                   -> viaHelper mod_def "feld_mod"
    _ -> error $ "compMod: type " ++ show t ++ " not supported"
  where
    viaHelper helperDef fun = do
        addGlobal helperDef
        compFun fun (a :* b :* Nil)
-- | Compile a primitive expression to a C expression.
--
-- The local function @go@ receives the result type of the symbol, the
-- symbol and its arguments. It registers headers, macros and helper
-- definitions through 'addInclude' / 'addGlobal', and marks referenced
-- variables and arrays with 'touchVar'. A symbol without a case here
-- causes an 'error'.
compPrim :: MonadC m => Prim a -> m C.Exp
compPrim = simpleMatch (\(s :&: t) -> go t s) . unPrim
  where
    go :: forall m sig . MonadC m
       => PrimTypeRep (DenResult sig)
       -> Primitive sig
       -> Args (AST PrimDomain) sig
       -> m C.Exp
    -- Variables and literals
    go _ (FreeVar v) Nil = touchVar v >> return [cexp| $id:v |]
    go t (Lit a) Nil
      | Dict <- witPrimType t
      = compLit (Proxy :: Proxy PrimType') a
    -- Arithmetic
    go _ Add (a :* b :* Nil) = compBinOp C.Add a b
    go _ Sub (a :* b :* Nil) = compBinOp C.Sub a b
    go _ Mul (a :* b :* Nil) = compBinOp C.Mul a b
    go _ Neg (a :* Nil) = compUnOp C.Negate a
    go t Abs (a :* Nil) = compAbs t a
    go t Sign (a :* Nil) = compSign t a
    go _ Quot (a :* b :* Nil) = compBinOp C.Div a b
    go _ Rem (a :* b :* Nil) = compBinOp C.Mod a b
    go t Div (a :* b :* Nil) = compDiv t a b
    go t Mod (a :* b :* Nil) = compMod t a b
    go _ FDiv (a :* b :* Nil) = compBinOp C.Div a b
    -- Floating-point constants and functions
    go _ Pi Nil = addGlobal pi_def >> return [cexp| FELD_PI |]
      where
        pi_def = [cedecl|$esc:("#define FELD_PI 3.141592653589793")|]
    go _ Exp args = addInclude "<tgmath.h>" >> compFun "exp" args
    go _ Log args = addInclude "<tgmath.h>" >> compFun "log" args
    go _ Sqrt args = addInclude "<tgmath.h>" >> compFun "sqrt" args
    go _ Pow args = addInclude "<tgmath.h>" >> compFun "pow" args
    go _ Sin args = addInclude "<tgmath.h>" >> compFun "sin" args
    go _ Cos args = addInclude "<tgmath.h>" >> compFun "cos" args
    go _ Tan args = addInclude "<tgmath.h>" >> compFun "tan" args
    go _ Asin args = addInclude "<tgmath.h>" >> compFun "asin" args
    go _ Acos args = addInclude "<tgmath.h>" >> compFun "acos" args
    go _ Atan args = addInclude "<tgmath.h>" >> compFun "atan" args
    go _ Sinh args = addInclude "<tgmath.h>" >> compFun "sinh" args
    go _ Cosh args = addInclude "<tgmath.h>" >> compFun "cosh" args
    go _ Tanh args = addInclude "<tgmath.h>" >> compFun "tanh" args
    go _ Asinh args = addInclude "<tgmath.h>" >> compFun "asinh" args
    go _ Acosh args = addInclude "<tgmath.h>" >> compFun "acosh" args
    go _ Atanh args = addInclude "<tgmath.h>" >> compFun "atanh" args
    -- Complex numbers
    go _ Complex (a :* b :* Nil) = do
        addInclude "<tgmath.h>"
        a' <- compPrim $ Prim a
        b' <- compPrim $ Prim b
        return $ case (viewLitPrim a, viewLitPrim b) of
            (Just 0, _) -> [cexp| I*$b' |]
            (_, Just 0) -> [cexp| $a' |]
            _ -> [cexp| $a' + I*$b' |]
    go _ Polar (m :* p :* Nil)
      | Just 0 <- viewLitPrim m = return [cexp| 0 |]
      | Just 0 <- viewLitPrim p = do
          m' <- compPrim $ Prim m
          return [cexp| $m' |]
      | Just 1 <- viewLitPrim m = do
          -- The generated code uses the imaginary unit I and needs the
          -- type-generic (complex) exp, so <tgmath.h> must be included,
          -- as in the Complex case above.
          addInclude "<tgmath.h>"
          p' <- compPrim $ Prim p
          return [cexp| exp(I*$p') |]
      | otherwise = do
          addInclude "<tgmath.h>"
          m' <- compPrim $ Prim m
          p' <- compPrim $ Prim p
          return [cexp| $m' * exp(I*$p') |]
    go _ Real args = addInclude "<tgmath.h>" >> compFun "creal" args
    go _ Imag args = addInclude "<tgmath.h>" >> compFun "cimag" args
    go _ Magnitude args = addInclude "<tgmath.h>" >> compFun "cabs" args
    go _ Phase args = addInclude "<tgmath.h>" >> compFun "carg" args
    go _ Conjugate args = addInclude "<tgmath.h>" >> compFun "conj" args
    -- Conversions
    go t I2N (a :* Nil) = compCast t a
    go t I2B (a :* Nil) = compCast t a
    go t B2I (a :* Nil) = compCast t a
    go t Round (a :* Nil) = compRound t a
    -- Logic and comparison
    go _ Not (a :* Nil) = compUnOp C.Lnot a
    go _ And (a :* b :* Nil) = compBinOp C.Land a b
    go _ Or (a :* b :* Nil) = compBinOp C.Lor a b
    go _ Eq (a :* b :* Nil) = compBinOp C.Eq a b
    go _ NEq (a :* b :* Nil) = compBinOp C.Ne a b
    go _ Lt (a :* b :* Nil) = compBinOp C.Lt a b
    go _ Gt (a :* b :* Nil) = compBinOp C.Gt a b
    go _ Le (a :* b :* Nil) = compBinOp C.Le a b
    go _ Ge (a :* b :* Nil) = compBinOp C.Ge a b
    -- Bit manipulation
    go _ BitAnd (a :* b :* Nil) = compBinOp C.And a b
    go _ BitOr (a :* b :* Nil) = compBinOp C.Or a b
    go _ BitXor (a :* b :* Nil) = compBinOp C.Xor a b
    go _ BitCompl (a :* Nil) = compUnOp C.Not a
    go _ ShiftL (a :* b :* Nil) = compBinOp C.Lsh a b
    go _ ShiftR (a :* b :* Nil) = compBinOp C.Rsh a b
    -- Array indexing and conditionals
    go _ (ArrIx arr) (i :* Nil) = do
        i' <- compPrim $ Prim i
        touchVar arr
        return [cexp| $id:arr[$i'] |]
    go _ Cond (c :* t :* f :* Nil) = do
        c' <- compPrim $ Prim c
        t' <- compPrim $ Prim t
        f' <- compPrim $ Prim f
        return $ C.Cond c' t' f' mempty
    go _ s _ = error $ "compPrim: no handling of symbol " ++ renderSym s
-- | 'Prim' expressions are compiled to C by 'compPrim'.
instance CompExp Prim
  where
    compExp expr = compPrim expr