{-# LANGUAGE QuasiQuotes #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Simple C runtime representation.
--
-- Most types use the same representation for array elements in memory
-- and for scalar variables.  For those that do not (as of this writing,
-- only 'Float16'), we use 'primStorageType' for the array element
-- representation, and 'primTypeToCType' for their scalar
-- representation.  Use 'toStorage' and 'fromStorage' to convert back
-- and forth.
module Futhark.CodeGen.Backends.SimpleRep
  ( tupleField,
    funName,
    defaultMemBlockType,
    intTypeToCType,
    primTypeToCType,
    primStorageType,
    primAPIType,
    arrayName,
    opaqueName,
    toStorage,
    fromStorage,
    cproduct,
    csum,
    scalarToPrim,

    -- * Primitive value operations
    cScalarDefs,

    -- * Storing/restoring values in byte sequences
    storageSize,
    storeValueHeader,
    loadValueHeader,
  )
where

import Data.Bits (shiftR, xor)
import Data.Char (isAlphaNum, isDigit, ord)
import Data.Text qualified as T
import Futhark.CodeGen.ImpCode
import Futhark.CodeGen.RTS.C (scalarF16H, scalarH)
import Futhark.Util (zEncodeString)
import Language.C.Quote.C qualified as C
import Language.C.Syntax qualified as C
import Text.Printf

-- | The C type corresponding to a signed integer type: the @intN_t@
-- type of the same bit width.
intTypeToCType :: IntType -> C.Type
intTypeToCType it =
  case it of
    Int8 -> [C.cty|typename int8_t|]
    Int16 -> [C.cty|typename int16_t|]
    Int32 -> [C.cty|typename int32_t|]
    Int64 -> [C.cty|typename int64_t|]

-- | The C type corresponding to an unsigned integer type: the
-- @uintN_t@ type of the same bit width.
uintTypeToCType :: IntType -> C.Type
uintTypeToCType width =
  case width of
    Int8 -> [C.cty|typename uint8_t|]
    Int16 -> [C.cty|typename uint16_t|]
    Int32 -> [C.cty|typename uint32_t|]
    Int64 -> [C.cty|typename uint64_t|]

-- | The C type used for scalar variables of a primitive type.
--
-- Integers are mapped to the /signed/ C types given by
-- 'intTypeToCType'; use 'primAPIType' when the signedness matters.
-- Both 'Bool' and 'Unit' are represented as @bool@, and 'Float16' as
-- the C type named @f16@ (its array element representation is given
-- by 'primStorageType' instead).
primTypeToCType :: PrimType -> C.Type
primTypeToCType (IntType t) = intTypeToCType t
primTypeToCType (FloatType Float16) = [C.cty|typename f16|]
primTypeToCType (FloatType Float32) = [C.cty|float|]
primTypeToCType (FloatType Float64) = [C.cty|double|]
primTypeToCType Bool = [C.cty|typename bool|]
primTypeToCType Unit = [C.cty|typename bool|]

-- | The C type used for array elements of this primitive type.  This
-- is the same as 'primTypeToCType' except for 'Float16', whose
-- elements are stored as @uint16_t@ (convert with 'toStorage' and
-- 'fromStorage').
primStorageType :: PrimType -> C.Type
primStorageType pt
  | FloatType Float16 <- pt = [C.cty|typename uint16_t|]
  | otherwise = primTypeToCType pt

-- | The C type exposed in the public API for a primitive type.
-- Integers get the C type of the requested signedness; every other
-- type uses its array storage representation ('primStorageType').
primAPIType :: Signedness -> PrimType -> C.Type
primAPIType sign pt =
  case (sign, pt) of
    (Unsigned, IntType it) -> uintTypeToCType it
    (Signed, IntType it) -> intTypeToCType it
    _ -> primStorageType pt

-- | Convert an expression from the scalar representation of the given
-- type to its storage representation.  Only 'Float16' needs a
-- conversion call; for every other type the expression is returned
-- unchanged.
toStorage :: PrimType -> C.Exp -> C.Exp
toStorage pt e =
  case pt of
    FloatType Float16 -> [C.cexp|futrts_to_bits16($exp:e)|]
    _ -> e

-- | Convert an expression from the storage representation of the given
-- type to its scalar representation.  This is the inverse of
-- 'toStorage': only 'Float16' needs a conversion call, and every other
-- expression is returned unchanged.
fromStorage :: PrimType -> C.Exp -> C.Exp
fromStorage pt e =
  case pt of
    FloatType Float16 -> [C.cexp|futrts_from_bits16($exp:e)|]
    _ -> e

-- | @tupleField i@ is the name of field number @i@ in a tuple: the
-- letter @v@ followed by the decimal index, e.g. @v0@.
tupleField :: Int -> String
tupleField = ('v' :) . show

-- | @funName f@ is the name of the C function corresponding to the
-- Futhark function @f@: its Z-encoded name prefixed with @futrts_@.
funName :: Name -> String
funName fname = "futrts_" ++ zEncodeString (nameToString fname)

-- | The C type of memory blocks in the default memory space: a pointer
-- to @unsigned char@.
defaultMemBlockType :: C.Type
defaultMemBlockType =
  [C.cty|unsigned char *|]

-- | The name of the exposed struct type for arrays with the given
-- element type, signedness and rank.  It consists of the element
-- type name produced by 'prettySigned', followed by @_@, the rank, and
-- @d@.
arrayName :: PrimType -> Signedness -> Int -> String
arrayName pt signed rank =
  concat [prettySigned isUnsigned pt, "_", show rank, "d"]
  where
    isUnsigned = signed == Unsigned

-- | The name of an exposed opaque type.
--
-- The unit type @()@ is special-cased as @opaque_unit@.  A name that is
-- already a plain identifier is simply prefixed with @opaque_@.  Such
-- a name is non-empty, does not start with an underscore or a digit,
-- and contains only alphanumeric characters and underscores.  Any
-- other name, including the empty string, is replaced by a
-- hexadecimal hash of its characters.
opaqueName :: String -> String
opaqueName "()" = "opaque_unit" -- Hopefully this ad-hoc convenience won't bite us.
opaqueName s
  | validIdent s = "opaque_" ++ s
  | otherwise = "opaque_" ++ hash (zipWith xor [0 ..] $ map ord s)
  where
    -- Pattern match instead of 'head', so that the empty string is
    -- treated as invalid (and hashed) rather than crashing.
    validIdent [] = False
    validIdent (c : cs) = c /= '_' && not (isDigit c) && all ok (c : cs)

    ok c = isAlphaNum c || c == '_'

    -- FIXME: a stupid hash algorithm; may have collisions.
    hash :: [Int] -> String
    hash =
      printf "%x"
        . foldl xor (0 :: Word32)
        . map (iter . (* 0x45d9f3b) . iter . (* 0x45d9f3b) . iter . fromIntegral)

    iter :: Word32 -> Word32
    iter x = (x `shiftR` 16) `xor` x

-- | The 'PrimType' (and signedness) denoted by a human-readable scalar
-- type name such as @f64@ or @u8@.  Non-integer types are reported as
-- 'Signed'.  Beware: partial!  Any unrecognised name is an 'error'.
scalarToPrim :: T.Text -> (Signedness, PrimType)
scalarToPrim tname =
  case tname of
    "bool" -> (Signed, Bool)
    "i8" -> (Signed, IntType Int8)
    "i16" -> (Signed, IntType Int16)
    "i32" -> (Signed, IntType Int32)
    "i64" -> (Signed, IntType Int64)
    "u8" -> (Unsigned, IntType Int8)
    "u16" -> (Unsigned, IntType Int16)
    "u32" -> (Unsigned, IntType Int32)
    "u64" -> (Unsigned, IntType Int64)
    "f16" -> (Signed, FloatType Float16)
    "f32" -> (Signed, FloatType Float32)
    "f64" -> (Signed, FloatType Float64)
    _ -> error $ "scalarToPrim: " <> T.unpack tname

-- | Return an expression multiplying together the given expressions,
-- associating to the left.  If an empty list is given, the expression
-- @1@ is returned.
cproduct :: [C.Exp] -> C.Exp
cproduct exps =
  case exps of
    [] -> [C.cexp|1|]
    first : rest -> foldl times first rest
  where
    times :: C.Exp -> C.Exp -> C.Exp
    times x y = [C.cexp|$exp:x * $exp:y|]

-- | Return an expression summing the given expressions, associating to
-- the left.  If an empty list is given, the expression @0@ is
-- returned.
csum :: [C.Exp] -> C.Exp
csum [] = [C.cexp|0|]
csum (e : es) = foldl add e es
  where
    -- Previously misnamed @mult@ even though it builds an addition.
    add :: C.Exp -> C.Exp -> C.Exp
    add x y = [C.cexp|$exp:x + $exp:y|]

instance C.ToIdent Name where
  -- The name is Z-encoded (via 'zEncodeString') before being used as a
  -- C identifier.
  toIdent name = C.toIdent (zEncodeString (nameToString name))

-- Orphan!  The text is used verbatim as the identifier, without any
-- Z-encoding.
instance C.ToIdent T.Text where
  toIdent txt = C.toIdent (T.unpack txt)

instance C.ToIdent VName where
  -- The pretty-printed variable name is Z-encoded (via
  -- 'zEncodeString') before being used as a C identifier.
  toIdent vname = C.toIdent (zEncodeString (prettyString vname))

instance C.ToExp VName where
  -- A variable becomes an identifier expression; the source location
  -- argument is ignored.
  toExp vname _loc = [C.cexp|$id:vname|]

instance C.ToExp IntValue where
  -- 8-, 16- and 64-bit literals are cast to the matching fixed-width
  -- type; 32-bit literals are emitted without a cast.
  toExp iv _loc =
    case iv of
      Int8Value k -> [C.cexp|(typename int8_t)$int:k|]
      Int16Value k -> [C.cexp|(typename int16_t)$int:k|]
      Int32Value k -> [C.cexp|$int:k|]
      Int64Value k -> [C.cexp|(typename int64_t)$int:k|]

instance C.ToExp FloatValue where
  -- Infinities and NaN are emitted as the C macros INFINITY / NAN;
  -- finite values become floating-point literals.
  toExp fv _loc =
    case fv of
      Float16Value x -> orSpecial x [C.cexp|$float:(fromRational (toRational x))|]
      Float32Value x -> orSpecial x [C.cexp|$float:x|]
      Float64Value x -> orSpecial x [C.cexp|$double:x|]
    where
      -- Use the given literal unless the value is infinite or NaN.
      orSpecial :: (RealFloat a) => a -> C.Exp -> C.Exp
      orSpecial x finite
        | isInfinite x =
            if x > 0 then [C.cexp|INFINITY|] else [C.cexp|-INFINITY|]
        | isNaN x = [C.cexp|NAN|]
        | otherwise = finite

instance C.ToExp PrimValue where
  -- Booleans are emitted as the 'Int8' literals 1 and 0; the unit value
  -- is emitted as 'Int8' 0.
  toExp pv loc =
    case pv of
      IntValue v -> C.toExp v loc
      FloatValue v -> C.toExp v loc
      BoolValue b -> C.toExp ((if b then 1 else 0) :: Int8) loc
      UnitValue -> C.toExp (0 :: Int8) loc

instance C.ToExp SubExp where
  -- Delegate to the 'VName' or 'PrimValue' instance.
  toExp se loc =
    case se of
      Var v -> C.toExp v loc
      Constant c -> C.toExp c loc

-- | Implementations of scalar operations: the contents of 'scalarH'
-- followed by those of 'scalarF16H'.
cScalarDefs :: T.Text
cScalarDefs = mconcat [scalarH, scalarF16H]

-- | @storageSize pt rank shape@ produces an expression giving the
-- number of bytes taken when storing this value in the binary value
-- format.  The size is a fixed header, one @int64_t@ per dimension, and
-- the product of the dimensions times the element size.  It is assumed
-- that @shape@ is an array with @rank@ dimensions.
storageSize :: PrimType -> Int -> C.Exp -> C.Exp
storageSize pt rank shape =
  [C.cexp|$int:headerBytes +
          $int:rank * sizeof(typename int64_t) +
          $exp:(cproduct extents) * $int:elemBytes|]
  where
    headerBytes, elemBytes :: Int
    -- 'b' <version> <num_dims> <type>
    headerBytes = 1 + 1 + 1 + 4
    elemBytes = primByteSize pt
    -- One indexing expression per dimension; empty for rank 0, in
    -- which case 'cproduct' yields 1.
    extents = map (\i -> [C.cexp|$exp:shape[$int:i]|]) [0 .. rank - 1]

-- | The four-character type tag written into, and checked against, the
-- header of the binary value format.  Signedness only matters for
-- integer types; shorter names are left-padded with spaces.
typeStr :: Signedness -> PrimType -> String
typeStr sign pt =
  case pt of
    Bool -> "bool"
    Unit -> "bool"
    FloatType ft -> floatTag ft
    IntType it -> intTag sign it
  where
    floatTag Float16 = " f16"
    floatTag Float32 = " f32"
    floatTag Float64 = " f64"

    intTag Signed Int8 = "  i8"
    intTag Signed Int16 = " i16"
    intTag Signed Int32 = " i32"
    intTag Signed Int64 = " i64"
    intTag Unsigned Int8 = "  u8"
    intTag Unsigned Int16 = " u16"
    intTag Unsigned Int32 = " u32"
    intTag Unsigned Int64 = " u64"

-- | Produce code for storing the header (everything besides the
-- actual payload) for a value of this type.
--
-- The generated statements write through @dest@ and advance it past
-- every byte written: the magic @'b'@, the version byte, the rank, the
-- four-character type tag, and, for @rank > 0@, the @rank@ @int64_t@
-- dimension sizes copied from @shape@.  The code assumes that @dest@
-- is an assignable byte pointer; confirm at call sites.
storeValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
storeValueHeader sign pt rank shape dest =
  [C.cstms|
          *$exp:dest++ = 'b';
          *$exp:dest++ = 2;
          *$exp:dest++ = $int:rank;
          memcpy($exp:dest, $string:tag, 4);
          $exp:dest += 4;
          $stms:shapeStms
          |]
  where
    tag = typeStr sign pt
    -- Scalars (rank 0) have no dimension sizes to store.
    shapeStms =
      if rank == 0
        then []
        else
          [C.cstms|
                memcpy($exp:dest, $exp:shape, $int:rank*sizeof(typename int64_t));
                $exp:dest += $int:rank*sizeof(typename int64_t);|]

-- | Produce code for loading and validating the header (everything
-- besides the actual payload) of a value of this type from @src@.
--
-- Each header mismatch is OR-ed into a C variable named @err@. That
-- variable must already be declared where this code is spliced; it
-- presumably starts at zero, so confirm at call sites.  The magic byte,
-- version, rank and four-character type tag are checked, and @src@ is
-- advanced past them.  If no error was recorded, the @rank@ @int64_t@
-- dimension sizes are copied into @shape@ and @src@ is advanced past
-- them as well.
loadValueHeader :: Signedness -> PrimType -> Int -> C.Exp -> C.Exp -> [C.Stm]
loadValueHeader sign pt rank shape src =
  [C.cstms|
     err |= (*$exp:src++ != 'b');
     err |= (*$exp:src++ != 2);
     err |= (*$exp:src++ != $int:rank);
     err |= (memcmp($exp:src, $string:(typeStr sign pt), 4) != 0);
     $exp:src += 4;
     if (err == 0) {
       $stms:load_shape
       $exp:src += $int:rank*sizeof(typename int64_t);
     }|]
  where
    load_shape
      | rank == 0 = []
      | otherwise =
          -- Copy from the caller-supplied source expression; this
          -- previously read from a hard-coded C variable named @src@.
          [C.cstms|memcpy($exp:shape, $exp:src, $int:rank*sizeof(typename int64_t));|]