{-# LANGUAGE QuasiQuotes #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Simple C runtime representation.
--
-- Most types use the same memory and scalar variable representation.
-- For those that do not (as of this writing, only `Float16`), we use
-- 'primStorageType' for the array element representation, and
-- 'primTypeToCType' for their scalar representation.  Use 'toStorage'
-- and 'fromStorage' to convert back and forth.
module Futhark.CodeGen.Backends.SimpleRep
  ( tupleField,
    funName,
    defaultMemBlockType,
    intTypeToCType,
    primTypeToCType,
    primStorageType,
    primAPIType,
    arrayName,
    opaqueName,
    isValidCName,
    escapeName,
    toStorage,
    fromStorage,
    cproduct,
    csum,
    allEqual,
    allTrue,
    scalarToPrim,

    -- * Primitive value operations
    cScalarDefs,

    -- * Storing/restoring values in byte sequences
    storageSize,
    storeValueHeader,
    loadValueHeader,
  )
where

import Control.Monad (void)
import Data.Char (isAlpha, isAlphaNum, isDigit)
import Data.Text qualified as T
import Data.Void (Void)
import Futhark.CodeGen.ImpCode
import Futhark.CodeGen.RTS.C (scalarF16H, scalarH)
import Futhark.Util (hashText, showText, zEncodeText)
import Language.C.Quote.C qualified as C
import Language.C.Syntax qualified as C
import Text.Megaparsec
import Text.Megaparsec.Char (space)

-- | The signed C integer type (@intN_t@) whose width matches the given
-- 'IntType'.
intTypeToCType :: IntType -> C.Type
intTypeToCType width =
  case width of
    Int8 -> [C.cty|typename int8_t|]
    Int16 -> [C.cty|typename int16_t|]
    Int32 -> [C.cty|typename int32_t|]
    Int64 -> [C.cty|typename int64_t|]

-- | The unsigned C integer type (@uintN_t@) whose width matches the
-- given 'IntType'.
uintTypeToCType :: IntType -> C.Type
uintTypeToCType width =
  case width of
    Int8 -> [C.cty|typename uint8_t|]
    Int16 -> [C.cty|typename uint16_t|]
    Int32 -> [C.cty|typename uint32_t|]
    Int64 -> [C.cty|typename uint64_t|]

-- | The C type used for /scalar/ values of a primitive type.
--
-- Integers are mapped to the /signed/ C type of the same width (see
-- 'intTypeToCType'); use 'primAPIType' when the signedness matters.
-- 'Float16' maps to the @f16@ scalar type, which differs from its array
-- storage type (see 'primStorageType').  Both 'Bool' and 'Unit' are
-- represented as C @bool@.
primTypeToCType :: PrimType -> C.Type
primTypeToCType (IntType t) = intTypeToCType t
primTypeToCType (FloatType Float16) = [C.cty|typename f16|]
primTypeToCType (FloatType Float32) = [C.cty|float|]
primTypeToCType (FloatType Float64) = [C.cty|double|]
primTypeToCType Bool = [C.cty|typename bool|]
primTypeToCType Unit = [C.cty|typename bool|]

-- | The C element type used when storing values of this primitive type
-- in arrays.  This coincides with 'primTypeToCType' for every type
-- except 'Float16', whose elements are stored as @uint16_t@ (convert
-- with 'toStorage' and 'fromStorage').
primStorageType :: PrimType -> C.Type
primStorageType pt =
  case pt of
    FloatType Float16 -> [C.cty|typename uint16_t|]
    _ -> primTypeToCType pt

-- | The C type exposed in the public API for a primitive type.
-- Integers get the C type of the requested signedness.  Every other
-- type uses its storage representation ('primStorageType'), so e.g.
-- 'Float16' appears as @uint16_t@.
primAPIType :: Signedness -> PrimType -> C.Type
primAPIType sign pt =
  case (sign, pt) of
    (Unsigned, IntType t) -> uintTypeToCType t
    (Signed, IntType t) -> intTypeToCType t
    (_, other) -> primStorageType other

-- | Convert an expression from scalar to storage representation for the
-- given type.  Only 'Float16' needs a conversion (through the C helper
-- @futrts_to_bits16@).  For every other type the expression is returned
-- unchanged.
toStorage :: PrimType -> C.Exp -> C.Exp
toStorage pt e =
  case pt of
    FloatType Float16 -> [C.cexp|futrts_to_bits16($exp:e)|]
    _ -> e

-- | Convert an expression from storage to scalar representation for the
-- given type.  This is the inverse of 'toStorage': only 'Float16' is
-- converted (through @futrts_from_bits16@).  Other expressions are
-- returned unchanged.
fromStorage :: PrimType -> C.Exp -> C.Exp
fromStorage pt e =
  case pt of
    FloatType Float16 -> [C.cexp|futrts_from_bits16($exp:e)|]
    _ -> e

-- | @tupleField i@ is the name of field number @i@ in a tuple: the
-- letter @v@ followed by the decimal index, e.g. @tupleField 2 == "v2"@.
tupleField :: Int -> String
tupleField i = 'v' : show i

-- | @funName f@ is the name of the C function corresponding to the
-- Futhark function @f@: its Z-encoded name prefixed with @futrts_@.
funName :: Name -> T.Text
funName f = "futrts_" <> zEncodeText (nameToText f)

-- | The C type of memory blocks in the default memory space: a pointer
-- to @unsigned char@.
defaultMemBlockType :: C.Type
defaultMemBlockType = [C.cty|unsigned char *|]

-- | The name of exposed array type structs.  It is the element type as
-- rendered by 'prettySigned', then an underscore, the rank, and @d@.
arrayName :: PrimType -> Signedness -> Int -> T.Text
arrayName pt signed rank =
  T.concat [prettySigned isUnsigned pt, "_", prettyText rank, "d"]
  where
    isUnsigned = signed == Unsigned

-- | Is this name a valid C identifier?  If not, it should be escaped
-- before being emitted into C (see 'escapeName').
--
-- The check is deliberately conservative:
--
-- * the first character must be a letter, so names starting with @_@
--   are also rejected;
-- * the remaining characters must be alphanumeric or @_@;
-- * the empty text is not a valid identifier.
--
-- Note that 'isAlpha' and 'isAlphaNum' also accept non-ASCII letters.
isValidCName :: T.Text -> Bool
isValidCName = maybe False check . T.uncons
  where
    check (c, cs) = isAlpha c && T.all constituent cs
    constituent c = isAlphaNum c || c == '_'

-- | Return the text unchanged if 'isValidCName' accepts it.  Otherwise
-- Z-encode it with 'zEncodeText', which is presumably what makes it a
-- valid C identifier.
escapeName :: T.Text -> T.Text
escapeName v =
  if isValidCName v
    then v
    else zEncodeText v

-- | Is this text usable as (part of) a C identifier?
--
-- * The text must be non-empty.
-- * Its first character may be neither @_@ nor an ASCII digit.
-- * Every character must be alphanumeric or @_@.
--
-- Unlike the previous 'T.head'-based version, the empty text is
-- rejected rather than crashing.  This matters because 'opaqueName'
-- calls this on arbitrary type names.
valid :: T.Text -> Bool
valid s =
  case T.uncons s of
    Nothing -> False
    Just (c, _) -> c /= '_' && not (isDigit c) && T.all ok s
  where
    ok c = isAlphaNum c || c == '_'

-- | Find a readable C type name for a Futhark type name.  This solely
-- serves to make the generated header file easy to read; callers can
-- always fall back on an ugly hash.
--
-- * Array types (one or more @[]@ prefixes) become @arrNd_@ followed by
--   the element name.
-- * Parenthesised, comma-separated tuples become @tupN_@ followed by the
--   component names joined with @_@.
-- * Any other non-empty run of characters outside @[]{}(),@ is kept
--   verbatim.
--
-- Anything else (e.g. text containing braces) yields a pretty-printed
-- parse error in 'Left'.
findPrettyName :: T.Text -> Either String T.Text
findPrettyName input =
  case parse (p <* eof) "type name" input of
    Left bundle -> Left (errorBundlePretty bundle)
    Right pretty -> Right pretty
  where
    p :: Parsec Void T.Text T.Text
    p = choice [pArr, pTup, pAtom]

    -- One or more "[]" prefixes followed by the element type.
    pArr :: Parsec Void T.Text T.Text
    pArr = do
      dims <- some "[]"
      elemName <- p
      pure $ "arr" <> showText (length dims) <> "d_" <> elemName

    -- A parenthesised list of component types separated by commas.
    pTup :: Parsec Void T.Text T.Text
    pTup = between "(" ")" $ do
      ts <- p `sepBy` pComma
      pure $ "tup" <> showText (length ts) <> "_" <> T.intercalate "_" ts

    -- A non-empty run of characters that are not brackets, braces,
    -- parentheses, or commas.
    pAtom :: Parsec Void T.Text T.Text
    pAtom = T.pack <$> some (satisfy (`notElem` ("[]{}()," :: String)))

    -- A comma, plus any whitespace following it.
    pComma :: Parsec Void T.Text ()
    pComma = void ("," <* space)

-- | The name of exposed opaque types, tried in order:
--
-- 1. the unit type @()@ becomes @opaque_unit@;
-- 2. otherwise, @opaque_@ plus the 'findPrettyName' rendering, if that
--    passes 'valid';
-- 3. otherwise, @opaque_@ plus the raw name, if that passes 'valid';
-- 4. otherwise, @opaque_@ plus a 'hashText' of the raw name.
opaqueName :: Name -> T.Text
opaqueName name
  -- Hopefully this ad-hoc convenience won't bite us.
  | name == "()" = "opaque_unit"
  | Right pretty <- findPrettyName raw,
    valid pretty =
      "opaque_" <> pretty
  | valid raw = "opaque_" <> raw
  | otherwise = "opaque_" <> hashText raw
  where
    raw = nameToText name

-- | The 'PrimType' (and sign) corresponding to a human-readable scalar
-- type name (e.g. @f64@).  @bool@ and the floating-point types are
-- reported as 'Signed'.  Beware: partial!  Unknown names call 'error'.
scalarToPrim :: T.Text -> (Signedness, PrimType)
scalarToPrim tname =
  case lookup tname known of
    Just result -> result
    Nothing -> error $ "scalarToPrim: " <> T.unpack tname
  where
    known =
      [ ("bool", (Signed, Bool)),
        ("i8", (Signed, IntType Int8)),
        ("i16", (Signed, IntType Int16)),
        ("i32", (Signed, IntType Int32)),
        ("i64", (Signed, IntType Int64)),
        ("u8", (Unsigned, IntType Int8)),
        ("u16", (Unsigned, IntType Int16)),
        ("u32", (Unsigned, IntType Int32)),
        ("u64", (Unsigned, IntType Int64)),
        ("f16", (Signed, FloatType Float16)),
        ("f32", (Signed, FloatType Float32)),
        ("f64", (Signed, FloatType Float64))
      ]

-- | Return an expression multiplying together the given expressions,
-- associated to the left (@(a * b) * c@).  If an empty list is given,
-- the expression @1@ is returned.
cproduct :: [C.Exp] -> C.Exp
cproduct factors =
  case factors of
    [] -> [C.cexp|1|]
    _ -> foldl1 times factors
  where
    times x y = [C.cexp|$exp:x * $exp:y|]

-- | Return an expression summing the given expressions, associated to
-- the left (@(a + b) + c@).  If an empty list is given, the expression
-- @0@ is returned.
csum :: [C.Exp] -> C.Exp
csum terms =
  case terms of
    [] -> [C.cexp|0|]
    _ -> foldl1 plus terms
  where
    plus x y = [C.cexp|$exp:x + $exp:y|]

-- | An expression that is true iff all the given expressions are true.
-- The conjunction nests to the right (@a && (b && c)@).  A singleton
-- list yields its element unchanged, and the empty list yields @true@.
allTrue :: [C.Exp] -> C.Exp
allTrue conds =
  case conds of
    [] -> [C.cexp|true|]
    _ -> foldr1 conj conds
  where
    conj x y = [C.cexp|$exp:x && $exp:y|]

-- | An expression that is true if these expressions are all equal by
-- @==@.  Adjacent pairs are compared (@a == b && (b == c && ...)@).
-- Lists with fewer than two elements yield @true@.
allEqual :: [C.Exp] -> C.Exp
allEqual es =
  case zipWith eq es (drop 1 es) of
    [] -> [C.cexp|true|]
    comparisons -> foldr1 conj comparisons
  where
    eq x y = [C.cexp|$exp:x == $exp:y|]
    conj x y = [C.cexp|$exp:x && $exp:y|]

-- | A Futhark 'Name' becomes a C identifier through its Z-encoded text.
-- (Orphan instance: neither the class nor the type is defined in this
-- module.)
instance C.ToIdent Name where
  toIdent name = C.toIdent (zEncodeText (nameToText name))

-- Orphan!
-- | Text is used as a C identifier verbatim (no escaping is performed),
-- by converting it to a 'String'.
instance C.ToIdent T.Text where
  toIdent t loc = C.toIdent (T.unpack t) loc

-- | A 'VName' becomes a C identifier by Z-encoding its pretty-printed
-- form.  (Orphan instance.)
instance C.ToIdent VName where
  toIdent v = C.toIdent (zEncodeText (prettyText v))

-- | A 'VName' used as an expression is simply a reference to the
-- identifier produced by its 'C.ToIdent' instance.  (Orphan instance.)
instance C.ToExp VName where
  toExp v = const [C.cexp|$id:v|]

-- | Integer constants become C integer literals.  Every width except
-- 32 bits gets an explicit cast to the matching @intN_t@ type.
-- (Orphan instance.)
instance C.ToExp IntValue where
  toExp iv _ =
    case iv of
      Int8Value k -> [C.cexp|(typename int8_t)$int:k|]
      Int16Value k -> [C.cexp|(typename int16_t)$int:k|]
      -- No cast here; presumably because a bare literal already has C
      -- type int.
      Int32Value k -> [C.cexp|$int:k|]
      Int64Value k -> [C.cexp|(typename int64_t)$int:k|]

-- | Floating-point constants become C literals.  Infinities and NaN,
-- which have no literal syntax, become the @INFINITY@ / @NAN@ macros.
-- Half-precision values are additionally cast to @f16@.
-- (Orphan instance.)
instance C.ToExp FloatValue where
  toExp (Float16Value x) _
    | isNaN x = [C.cexp|(typename f16)NAN|]
    | isInfinite x, x > 0 = [C.cexp|(typename f16)INFINITY|]
    | isInfinite x = [C.cexp|(typename f16)-INFINITY|]
    | otherwise = [C.cexp|(typename f16)$float:(fromRational (toRational x))|]
  toExp (Float32Value x) _
    | isNaN x = [C.cexp|NAN|]
    | isInfinite x, x > 0 = [C.cexp|INFINITY|]
    | isInfinite x = [C.cexp|-INFINITY|]
    | otherwise = [C.cexp|$float:x|]
  toExp (Float64Value x) _
    | isNaN x = [C.cexp|NAN|]
    | isInfinite x, x > 0 = [C.cexp|INFINITY|]
    | isInfinite x = [C.cexp|-INFINITY|]
    | otherwise = [C.cexp|$double:x|]

-- | Primitive constants.  Numbers defer to their own instances.  Booleans
-- become the 'Int8' values 1 / 0, and the unit value becomes the 'Int8'
-- value 0.  (Orphan instance.)
instance C.ToExp PrimValue where
  toExp pv =
    case pv of
      IntValue v -> C.toExp v
      FloatValue v -> C.toExp v
      BoolValue b -> C.toExp ((if b then 1 else 0) :: Int8)
      UnitValue -> C.toExp (0 :: Int8)

-- | A 'SubExp' is either a variable reference or a constant.  Each case
-- defers to the corresponding instance.  (Orphan instance.)
instance C.ToExp SubExp where
  toExp se loc =
    case se of
      Var v -> C.toExp v loc
      Constant c -> C.toExp c loc

-- | C source implementing scalar operations: the contents of 'scalarH'
-- followed by those of 'scalarF16H'.
cScalarDefs :: T.Text
cScalarDefs = T.append scalarH scalarF16H

-- | @storageSize pt rank shape@ produces a C expression for the number
-- of bytes a value of this type occupies in the binary value format.
-- That is a 7-byte header, plus @rank@ 64-bit dimension sizes, plus the
-- elements themselves (each of size @sizeof@ of 'primStorageType').
-- The @shape@ expression is assumed to denote an array with @rank@
-- dimensions, which are multiplied together to get the element count.
storageSize :: PrimType -> Int -> C.Exp -> C.Exp
storageSize pt rank shape =
  [C.cexp|$int:header_size +
          $int:rank * sizeof(typename int64_t) +
          $exp:(cproduct dims) * sizeof($ty:(primStorageType pt))|]
  where
    -- 'b' <version> <num_dims> <type>, where <type> is 4 characters.
    header_size :: Int
    header_size = 1 + 1 + 1 + 4
    dims = map (\i -> [C.cexp|$exp:shape[$int:i]|]) [0 .. rank - 1]

-- | The four-character, space-padded type name written into the header
-- of the binary value format.  The signedness only matters for integer
-- types, and 'Unit' is encoded as @bool@.
typeStr :: Signedness -> PrimType -> String
typeStr sign pt =
  case pt of
    Bool -> "bool"
    Unit -> "bool"
    FloatType Float16 -> " f16"
    FloatType Float32 -> " f32"
    FloatType Float64 -> " f64"
    IntType width -> intStr sign width
  where
    intStr Signed Int8 = "  i8"
    intStr Signed Int16 = " i16"
    intStr Signed Int32 = " i32"
    intStr Signed Int64 = " i64"
    intStr Unsigned Int8 = "  u8"
    intStr Unsigned Int16 = " u16"
    intStr Unsigned Int32 = " u32"
    intStr Unsigned Int64 = " u64"

-- | Produce code for storing the header (everything besides the actual
-- payload) of a value of this type in the binary value format.
--
-- The generated statements write, in order:
--
-- 1. the byte @'b'@;
-- 2. the format version @2@;
-- 3. the rank;
-- 4. the four-character type name (see 'typeStr');
-- 5. for non-zero rank, @rank@ 64-bit dimension sizes copied from
--    @shape@.
--
-- @dest@ must be an assignable pointer expression.  It is advanced past
-- everything written.
storeValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
storeValueHeader sign pt rank shape dest =
  [C.cstms|
          *$exp:dest++ = 'b';
          *$exp:dest++ = 2;
          *$exp:dest++ = $int:rank;
          memcpy($exp:dest, $string:(typeStr sign pt), 4);
          $exp:dest += 4;
          $stms:copy_shape
          |]
  where
    shape_bytes = [C.cexp|$int:rank*sizeof(typename int64_t)|]
    copy_shape =
      if rank == 0
        then []
        else
          [C.cstms|
                memcpy($exp:dest, $exp:shape, $exp:shape_bytes);
                $exp:dest += $exp:shape_bytes;|]

-- | Produce code for loading and checking the header (everything besides
-- the actual payload) of a value of this type in the binary value format.
-- It mirrors 'storeValueHeader'.
--
-- The generated statements compare the magic byte @'b'@, the format
-- version @2@, the rank, and the four-character type name (see
-- 'typeStr') against the expected values.  Any mismatch is OR-ed into a
-- C variable named @err@, which must already be in scope in the
-- generated code (presumably an integer initialised to 0 by the caller;
-- confirm at call sites).
--
-- @src@ must be an assignable pointer expression.  It is always advanced
-- past the 7-byte header.  Only when @err@ is still zero are the @rank@
-- 64-bit dimension sizes copied into @shape@, and @src@ advanced past
-- them.
loadValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
loadValueHeader sign pt rank shape src =
  [C.cstms|
     err |= (*$exp:src++ != 'b');
     err |= (*$exp:src++ != 2);
     err |= (*$exp:src++ != $int:rank);
     err |= (memcmp($exp:src, $string:(typeStr sign pt), 4) != 0);
     $exp:src += 4;
     if (err == 0) {
       $stms:load_shape
       $exp:src += $int:rank*sizeof(typename int64_t);
     }|]
  where
    load_shape
      | rank == 0 = []
      -- Read from the caller-supplied source expression, not from a
      -- hard-coded C variable that happens to be called "src".
      | otherwise =
          [C.cstms|memcpy($exp:shape, $exp:src, $int:rank*sizeof(typename int64_t));|]