{-# LANGUAGE QuasiQuotes #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Futhark.CodeGen.Backends.SimpleRep
( tupleField,
funName,
defaultMemBlockType,
intTypeToCType,
primTypeToCType,
primStorageType,
primAPIType,
arrayName,
opaqueName,
isValidCName,
escapeName,
toStorage,
fromStorage,
cproduct,
csum,
allEqual,
allTrue,
scalarToPrim,
cScalarDefs,
storageSize,
storeValueHeader,
loadValueHeader,
)
where
import Control.Monad (void)
import Data.Char (isAlpha, isAlphaNum, isDigit)
import Data.Text qualified as T
import Data.Void (Void)
import Futhark.CodeGen.ImpCode
import Futhark.CodeGen.RTS.C (scalarF16H, scalarH)
import Futhark.Util (hashText, showText, zEncodeText)
import Language.C.Quote.C qualified as C
import Language.C.Syntax qualified as C
import Text.Megaparsec
import Text.Megaparsec.Char (space)
-- | The signed fixed-width C integer type (@int8_t@ .. @int64_t@) with the
-- same bit width as the given Futhark integer type.
intTypeToCType :: IntType -> C.Type
intTypeToCType it =
  case it of
    Int8 -> [C.cty|typename int8_t|]
    Int16 -> [C.cty|typename int16_t|]
    Int32 -> [C.cty|typename int32_t|]
    Int64 -> [C.cty|typename int64_t|]
-- | The unsigned fixed-width C integer type (@uint8_t@ .. @uint64_t@) with
-- the same bit width as the given Futhark integer type.
uintTypeToCType :: IntType -> C.Type
uintTypeToCType it =
  case it of
    Int8 -> [C.cty|typename uint8_t|]
    Int16 -> [C.cty|typename uint16_t|]
    Int32 -> [C.cty|typename uint32_t|]
    Int64 -> [C.cty|typename uint64_t|]
-- | The C type used for scalar (non-array) values of a primitive type.
-- Integer types map to the /signed/ C integer types via 'intTypeToCType';
-- both 'Bool' and 'Unit' are represented as C @bool@.
primTypeToCType :: PrimType -> C.Type
primTypeToCType pt =
  case pt of
    IntType it -> intTypeToCType it
    FloatType Float16 -> [C.cty|typename f16|]
    FloatType Float32 -> [C.cty|float|]
    FloatType Float64 -> [C.cty|double|]
    Bool -> [C.cty|typename bool|]
    Unit -> [C.cty|typename bool|]
-- | The C element type used when values of this primitive type are stored
-- in arrays.  'Float16' is stored as a @uint16_t@ bit pattern; every other
-- type is stored exactly as its scalar type ('primTypeToCType').
primStorageType :: PrimType -> C.Type
primStorageType pt
  | FloatType Float16 <- pt = [C.cty|typename uint16_t|]
  | otherwise = primTypeToCType pt
-- | The C type exposed in the public API for a primitive type with the
-- given signedness.  Integer types honour the signedness; every other type
-- uses its storage type ('primStorageType'), so 'Float16' appears as
-- @uint16_t@.
primAPIType :: Signedness -> PrimType -> C.Type
primAPIType sign pt =
  case (sign, pt) of
    (Unsigned, IntType it) -> uintTypeToCType it
    (Signed, IntType it) -> intTypeToCType it
    _ -> primStorageType pt
-- | Convert a C expression from the scalar representation of a primitive
-- type to its array-storage representation.  Only 'Float16' differs and is
-- wrapped in a call to @futrts_to_bits16@; all other expressions are
-- returned unchanged.
toStorage :: PrimType -> C.Exp -> C.Exp
toStorage pt e =
  case pt of
    FloatType Float16 -> [C.cexp|futrts_to_bits16($exp:e)|]
    _ -> e
-- | Convert a C expression from the array-storage representation of a
-- primitive type back to its scalar representation.  Only 'Float16'
-- differs and is wrapped in a call to @futrts_from_bits16@; all other
-- expressions are returned unchanged.
fromStorage :: PrimType -> C.Exp -> C.Exp
fromStorage pt e =
  case pt of
    FloatType Float16 -> [C.cexp|futrts_from_bits16($exp:e)|]
    _ -> e
-- | @tupleField i@ is the C struct field name used for component @i@ of a
-- tuple: the letter @v@ followed by 'show' of the index.
tupleField :: Int -> String
tupleField i = 'v' : show i
-- | The name of the C function generated for a Futhark function: the
-- Z-encoded function name prefixed with @futrts_@.
funName :: Name -> T.Text
funName fname = "futrts_" <> zEncodeText (nameToText fname)
-- | The C type of memory blocks in the default memory space: a plain
-- pointer to @unsigned char@.
defaultMemBlockType :: C.Type
defaultMemBlockType =
  [C.cty|unsigned char*|]
-- | The name of the exposed C struct for arrays with the given element
-- type, signedness and rank.  It is the element type as rendered by
-- 'prettySigned' (asked for the unsigned spelling when the signedness is
-- 'Unsigned'), then @_@, the rank, and @d@.
arrayName :: PrimType -> Signedness -> Int -> T.Text
arrayName pt signed rank = mconcat [elemPart, "_", prettyText rank, "d"]
  where
    elemPart = prettySigned (signed == Unsigned) pt
-- | Is this text usable verbatim as a C identifier, i.e. a letter followed
-- by letters, digits or underscores?  If not, it should be escaped (see
-- 'escapeName') before being emitted into C.
--
-- NOTE(review): the empty text is reported as valid, a leading underscore
-- is rejected, and 'isAlpha'/'isAlphaNum' also accept non-ASCII letters.
isValidCName :: T.Text -> Bool
isValidCName ident =
  case T.uncons ident of
    Nothing -> True
    Just (lead, rest) -> isAlpha lead && T.all isIdentChar rest
  where
    isIdentChar ch = ch == '_' || isAlphaNum ch
-- | Return the text unchanged when 'isValidCName' accepts it; otherwise
-- return its Z-encoding so that it can be used as a C identifier.
escapeName :: T.Text -> T.Text
escapeName v =
  if isValidCName v
    then v
    else zEncodeText v
-- | Is this text acceptable as the suffix of an opaque type name?  It must
-- be non-empty, must not start with an underscore or a (decimal) digit, and
-- may only contain alphanumeric characters and underscores.
--
-- Total: the empty text is rejected rather than crashing on 'T.head'.
valid :: T.Text -> Bool
valid s =
  case T.uncons s of
    Nothing -> False
    Just (lead, _) -> lead /= '_' && not (isDigit lead) && T.all ok s
  where
    ok c = isAlphaNum c || c == '_'
-- | Try to derive a readable C type-name fragment from the textual form of
-- a Futhark type.  One or more leading @[]@ pairs become an @arrNd_@
-- prefix (N being the number of pairs); a parenthesised, comma-separated
-- tuple becomes @tupN_@ followed by its components joined with @_@; any
-- other run of characters outside @[]{}(),@ is kept verbatim.  Input the
-- grammar cannot fully consume (e.g. containing braces) yields 'Left'
-- with the parser's error message, so callers can fall back on a hash.
findPrettyName :: T.Text -> Either String T.Text
findPrettyName txt =
  case parse (pType <* eof) "type name" txt of
    Left bundle -> Left (errorBundlePretty bundle)
    Right pretty -> Right pretty
  where
    pType :: Parsec Void T.Text T.Text
    pType = choice [pArray, pTuple, pAtom]

    -- One or more "[]" pairs followed by the element type.
    pArray = do
      rank <- length <$> some "[]"
      inner <- pType
      pure $ "arr" <> showText rank <> "d_" <> inner

    -- "(t1, t2, ...)"; whitespace following each comma is skipped.
    pTuple = between "(" ")" $ do
      components <- pType `sepBy` pComma
      pure $
        "tup"
          <> showText (length components)
          <> "_"
          <> T.intercalate "_" components

    pAtom = T.pack <$> some (satisfy (`notElem` ("[]{}()," :: String)))

    pComma = void ("," <* space)
-- | The C name of an exposed opaque type, always prefixed with @opaque_@.
-- The unit type @()@ is special-cased to @opaque_unit@.  Otherwise the
-- suffix is the readable name from 'findPrettyName' if that succeeds and
-- 'valid' accepts it, else the raw type name if 'valid' accepts that, and
-- finally a hash of the raw type name.
opaqueName :: Name -> T.Text
opaqueName "()" = "opaque_unit"
opaqueName s = "opaque_" <> suffix
  where
    raw = nameToText s
    suffix =
      case findPrettyName raw of
        Right pretty | valid pretty -> pretty
        _
          | valid raw -> raw
          | otherwise -> hashText raw
-- | The signedness and 'PrimType' denoted by a human-readable scalar type
-- name such as @i32@, @u8@ or @f64@.  Non-integer types are reported as
-- 'Signed'.  Partial: an unrecognised name is a fatal 'error'.
scalarToPrim :: T.Text -> (Signedness, PrimType)
scalarToPrim tname =
  case lookup tname scalarNames of
    Just found -> found
    Nothing -> error $ "scalarToPrim: " <> T.unpack tname
  where
    scalarNames :: [(T.Text, (Signedness, PrimType))]
    scalarNames =
      [ ("bool", (Signed, Bool)),
        ("i8", (Signed, IntType Int8)),
        ("i16", (Signed, IntType Int16)),
        ("i32", (Signed, IntType Int32)),
        ("i64", (Signed, IntType Int64)),
        ("u8", (Unsigned, IntType Int8)),
        ("u16", (Unsigned, IntType Int16)),
        ("u32", (Unsigned, IntType Int32)),
        ("u64", (Unsigned, IntType Int64)),
        ("f16", (Signed, FloatType Float16)),
        ("f32", (Signed, FloatType Float32)),
        ("f64", (Signed, FloatType Float64))
      ]
-- | A C expression multiplying all the given expressions together, nested
-- to the left (the first two factors are multiplied first).  The empty
-- product is the literal @1@.
cproduct :: [C.Exp] -> C.Exp
cproduct [] = [C.cexp|1|]
cproduct (e : es) = go e es
  where
    go acc [] = acc
    go acc (x : xs) = go [C.cexp|$exp:acc * $exp:x|] xs
-- | A C expression adding all the given expressions together, nested to
-- the left (the first two terms are added first).  The empty sum is the
-- literal @0@.
csum :: [C.Exp] -> C.Exp
csum [] = [C.cexp|0|]
csum (e : es) = go e es
  where
    go acc [] = acc
    go acc (x : xs) = go [C.cexp|$exp:acc + $exp:x|] xs
-- | A C expression that is true exactly when all the given expressions are
-- true: the expressions joined with @&&@, nested to the right.  A single
-- expression is returned as-is, and the empty list gives @true@.
allTrue :: [C.Exp] -> C.Exp
allTrue [] = [C.cexp|true|]
allTrue (e : es) = go e es
  where
    go x [] = x
    go x (y : ys) = [C.cexp|$exp:x && $exp:(go y ys)|]
-- | A C expression that is true when all the given expressions are equal
-- by @==@: each adjacent pair is compared, and the comparisons are joined
-- with @&&@ (nested to the right).  Fewer than two expressions give @true@.
allEqual :: [C.Exp] -> C.Exp
allEqual es =
  case zipWith eq es (drop 1 es) of
    [] -> [C.cexp|true|]
    c : cs -> chain c cs
  where
    eq x y = [C.cexp|$exp:x == $exp:y|]
    chain c [] = c
    chain c (d : ds) = [C.cexp|$exp:c && $exp:(chain d ds)|]
-- | Futhark names become C identifiers by Z-encoding their text.
instance C.ToIdent Name where
  toIdent name = C.toIdent (zEncodeText (nameToText name))
-- | Text is used verbatim as a C identifier (no escaping).  This is an
-- orphan instance; orphan warnings are disabled for this module.
instance C.ToIdent T.Text where
  toIdent txt = C.toIdent (T.unpack txt)
-- | Variable names become C identifiers by Z-encoding their pretty-printed
-- form.
instance C.ToIdent VName where
  toIdent vname = C.toIdent (zEncodeText (prettyText vname))
-- | A variable used as a C expression is a reference to the identifier
-- given by its 'C.ToIdent' instance.
instance C.ToExp VName where
  toExp vname _loc = [C.cexp|$id:vname|]
-- | Integer constants.  8-, 16- and 64-bit values are cast to their
-- fixed-width C type; 32-bit values are emitted as plain literals
-- (presumably relying on C @int@ being 32 bits on the targeted platforms).
instance C.ToExp IntValue where
  toExp iv _ =
    case iv of
      Int8Value k -> [C.cexp|(typename int8_t)$int:k|]
      Int16Value k -> [C.cexp|(typename int16_t)$int:k|]
      Int32Value k -> [C.cexp|$int:k|]
      Int64Value k -> [C.cexp|(typename int64_t)$int:k|]
-- | Floating-point constants.  NaN and the infinities are emitted with the
-- C @NAN@ and @INFINITY@ macros (negated for negative infinity).  'Float16'
-- constants are always cast to @f16@, and finite ones are converted via
-- 'toRational'/'fromRational' before being spliced as a float literal.
instance C.ToExp FloatValue where
  toExp (Float16Value x) _
    | isNaN x = [C.cexp|(typename f16)NAN|]
    | isInfinite x, x > 0 = [C.cexp|(typename f16)INFINITY|]
    | isInfinite x = [C.cexp|(typename f16)-INFINITY|]
    | otherwise = [C.cexp|(typename f16)$float:(fromRational (toRational x))|]
  toExp (Float32Value x) _
    | isNaN x = [C.cexp|NAN|]
    | isInfinite x, x > 0 = [C.cexp|INFINITY|]
    | isInfinite x = [C.cexp|-INFINITY|]
    | otherwise = [C.cexp|$float:x|]
  toExp (Float64Value x) _
    | isNaN x = [C.cexp|NAN|]
    | isInfinite x, x > 0 = [C.cexp|INFINITY|]
    | isInfinite x = [C.cexp|-INFINITY|]
    | otherwise = [C.cexp|$double:x|]
-- | Primitive constants.  Integers and floats defer to their own
-- instances; booleans and the unit value are emitted as 'Int8' constants
-- (1 for true, 0 for false and for unit).
instance C.ToExp PrimValue where
  toExp pv loc =
    case pv of
      IntValue v -> C.toExp v loc
      FloatValue v -> C.toExp v loc
      BoolValue b -> C.toExp ((if b then 1 else 0) :: Int8) loc
      UnitValue -> C.toExp (0 :: Int8) loc
-- | A subexpression is either a variable reference or a constant, each
-- rendered by its own 'C.ToExp' instance.
instance C.ToExp SubExp where
  toExp se loc =
    case se of
      Var v -> C.toExp v loc
      Constant c -> C.toExp c loc
-- | C source text implementing scalar operations: the contents of
-- 'scalarH' followed by 'scalarF16H' (presumably the half-precision
-- support).
cScalarDefs :: T.Text
cScalarDefs = mconcat [scalarH, scalarF16H]
-- | @storageSize pt rank shape@ is a C expression for the number of bytes
-- taken by a value in the binary value format: the fixed-size header, then
-- @rank@ 64-bit dimension sizes, then the elements in their storage type
-- ('primStorageType').  @shape@ must be a C expression that can be indexed
-- with @[0]@ up to @[rank-1]@.
storageSize :: PrimType -> Int -> C.Exp -> C.Exp
storageSize pt rank shape =
  [C.cexp|$int:header_size +
          $int:rank * sizeof(typename int64_t) +
          $exp:(cproduct dims) * sizeof($ty:(primStorageType pt))|]
  where
    -- Fixed header bytes, matching what 'storeValueHeader' writes: the
    -- 'b' marker byte, the version byte, the rank byte and a 4-byte type tag.
    header_size, marker_size, version_size, rank_size, type_tag_size :: Int
    header_size = marker_size + version_size + rank_size + type_tag_size
    marker_size = 1
    version_size = 1
    rank_size = 1
    type_tag_size = 4
    dims = map (\i -> [C.cexp|$exp:shape[$int:i]|]) [0 .. rank - 1]
-- | The type tag of the binary value format for a primitive type with the
-- given signedness.  Both 'storeValueHeader' and 'loadValueHeader' copy or
-- compare exactly 4 bytes of this string, so every tag must be exactly 4
-- characters long; shorter names are left-padded with spaces.  'Bool' and
-- 'Unit' share the @bool@ tag, and float tags ignore the signedness.
typeStr :: Signedness -> PrimType -> String
typeStr sign pt =
  case (sign, pt) of
    (_, Bool) -> "bool"
    (_, Unit) -> "bool"
    (_, FloatType Float16) -> " f16"
    (_, FloatType Float32) -> " f32"
    (_, FloatType Float64) -> " f64"
    -- Two spaces of padding: a 3-character tag would make the 4-byte
    -- memcpy/memcmp pick up the string's NUL terminator instead.
    (Signed, IntType Int8) -> "  i8"
    (Signed, IntType Int16) -> " i16"
    (Signed, IntType Int32) -> " i32"
    (Signed, IntType Int64) -> " i64"
    (Unsigned, IntType Int8) -> "  u8"
    (Unsigned, IntType Int16) -> " u16"
    (Unsigned, IntType Int32) -> " u32"
    (Unsigned, IntType Int64) -> " u64"
-- | @storeValueHeader sign pt rank shape dest@ produces C statements that
-- write the header of a binary value (everything except the payload)
-- through the pointer expression @dest@.  @dest@ must be an assignable
-- lvalue because it is advanced past each part written.  The header is
-- the byte @'b'@, the byte @2@, the rank byte, the 4-byte type tag from
-- 'typeStr', and, for non-scalars, @rank@ 64-bit dimension sizes copied
-- from @shape@.
storeValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
storeValueHeader sign pt rank shape dest =
  [C.cstms|
          *$exp:dest++ = 'b';
          *$exp:dest++ = 2;
          *$exp:dest++ = $int:rank;
          memcpy($exp:dest, $string:(typeStr sign pt), 4);
          $exp:dest += 4;
          $stms:copy_shape
          |]
  where
    -- Number of bytes occupied by the shape.
    shape_bytes = [C.cexp|$int:rank*sizeof(typename int64_t)|]
    -- Scalars (rank 0) have no shape to store.
    copy_shape =
      if rank == 0
        then []
        else
          [C.cstms|
                memcpy($exp:dest, $exp:shape, $exp:shape_bytes);
                $exp:dest += $exp:shape_bytes;|]
-- | @loadValueHeader sign pt rank shape src@ produces C statements that
-- check the header of a binary value read through the pointer expression
-- @src@ against the expected rank and type tag ('typeStr').  @src@ must be
-- an assignable lvalue, because it is advanced past the fixed header.  If
-- the check succeeds, the @rank@ 64-bit dimension sizes are copied into
-- @shape@ and @src@ is advanced past them as well.
--
-- Mismatches are OR-ed into a C variable named @err@, which the generated
-- code does not declare; it must already be in scope at the splice site
-- (and presumably be zero-initialised beforehand — confirm at call sites).
loadValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
loadValueHeader sign pt rank shape src =
  [C.cstms|
     err |= (*$exp:src++ != 'b');
     err |= (*$exp:src++ != 2);
     err |= (*$exp:src++ != $int:rank);
     err |= (memcmp($exp:src, $string:(typeStr sign pt), 4) != 0);
     $exp:src += 4;
     if (err == 0) {
       $stms:load_shape
       $exp:src += $int:rank*sizeof(typename int64_t);
     }|]
  where
    -- Scalars (rank 0) have no shape to read.  The copy must read from the
    -- caller-supplied @src@ expression, not from a C variable that merely
    -- happens to be called "src".
    load_shape
      | rank == 0 = []
      | otherwise =
          [C.cstms|memcpy($exp:shape, $exp:src, $int:rank*sizeof(typename int64_t));|]