{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Type (
Half(..), Float, Double, Char, Bool(..),
module Data.Int,
module Data.Word,
module Foreign.C.Types,
module Data.Array.Accelerate.Type
) where
import Data.Orphans ()
import Data.Bits
import Data.Int
import Data.Type.Equality
import Data.Typeable
import Data.Word
import GHC.TypeLits
import Language.Haskell.TH
import Numeric.Half
import Text.Printf
import Foreign.Storable
import Foreign.C.Types
(CChar, CSChar, CUChar, CShort, CUShort, CInt, CUInt, CLong, CULong, CLLong, CULLong, CFloat, CDouble)
-- | Reified class dictionary for integral types.
--
-- Pattern matching on the 'IntegralDict' constructor brings every instance
-- listed in its context into scope for the type @a@.
--
data IntegralDict a where
  IntegralDict
    :: ( Bounded a
       , Eq a
       , Ord a
       , Show a
       , Bits a
       , FiniteBits a
       , Integral a
       , Num a
       , Real a
       , Storable a
       )
    => IntegralDict a
-- | Reified class dictionary for floating-point types.
--
-- Pattern matching on the 'FloatingDict' constructor brings every instance
-- listed in its context into scope for the type @a@.
--
data FloatingDict a where
  FloatingDict
    :: ( Eq a
       , Ord a
       , Show a
       , Floating a
       , Fractional a
       , Num a
       , Real a
       , RealFrac a
       , RealFloat a
       , Storable a
       )
    => FloatingDict a
-- | Reified class dictionary for non-numeric types.
--
-- Pattern matching on the 'NonNumDict' constructor brings every instance
-- listed in its context into scope for the type @a@.
--
data NonNumDict a where
  NonNumDict
    :: ( Bounded a
       , Eq a
       , Ord a
       , Show a
       , Storable a
       )
    => NonNumDict a
-- | Value-level representation of the supported integral types.
--
-- Each constructor fixes the type index to one concrete type and carries the
-- 'IntegralDict' for that type.
--
data IntegralType a where
  -- Haskell signed integers
  TypeInt     :: IntegralDict Int     -> IntegralType Int
  TypeInt8    :: IntegralDict Int8    -> IntegralType Int8
  TypeInt16   :: IntegralDict Int16   -> IntegralType Int16
  TypeInt32   :: IntegralDict Int32   -> IntegralType Int32
  TypeInt64   :: IntegralDict Int64   -> IntegralType Int64
  -- Haskell unsigned integers
  TypeWord    :: IntegralDict Word    -> IntegralType Word
  TypeWord8   :: IntegralDict Word8   -> IntegralType Word8
  TypeWord16  :: IntegralDict Word16  -> IntegralType Word16
  TypeWord32  :: IntegralDict Word32  -> IntegralType Word32
  TypeWord64  :: IntegralDict Word64  -> IntegralType Word64
  -- Integral types from "Foreign.C.Types"
  TypeCShort  :: IntegralDict CShort  -> IntegralType CShort
  TypeCUShort :: IntegralDict CUShort -> IntegralType CUShort
  TypeCInt    :: IntegralDict CInt    -> IntegralType CInt
  TypeCUInt   :: IntegralDict CUInt   -> IntegralType CUInt
  TypeCLong   :: IntegralDict CLong   -> IntegralType CLong
  TypeCULong  :: IntegralDict CULong  -> IntegralType CULong
  TypeCLLong  :: IntegralDict CLLong  -> IntegralType CLLong
  TypeCULLong :: IntegralDict CULLong -> IntegralType CULLong
-- | Value-level representation of the supported floating-point types.
--
-- Each constructor fixes the type index to one concrete type and carries the
-- 'FloatingDict' for that type.
--
data FloatingType a where
  TypeHalf    :: FloatingDict Half    -> FloatingType Half
  TypeFloat   :: FloatingDict Float   -> FloatingType Float
  TypeDouble  :: FloatingDict Double  -> FloatingType Double
  TypeCFloat  :: FloatingDict CFloat  -> FloatingType CFloat
  TypeCDouble :: FloatingDict CDouble -> FloatingType CDouble
-- | Value-level representation of the supported non-numeric types.
--
-- Each constructor fixes the type index to one concrete type and carries the
-- 'NonNumDict' for that type.
--
data NonNumType a where
  TypeBool   :: NonNumDict Bool   -> NonNumType Bool
  TypeChar   :: NonNumDict Char   -> NonNumType Char
  TypeCChar  :: NonNumDict CChar  -> NonNumType CChar
  TypeCSChar :: NonNumDict CSChar -> NonNumType CSChar
  TypeCUChar :: NonNumDict CUChar -> NonNumType CUChar
-- | Numeric types: either an integral or a floating-point type.
--
data NumType a where
  IntegralNumType :: IntegralType a -> NumType a
  FloatingNumType :: FloatingType a -> NumType a
-- | Bounded types: either an integral or a non-numeric type.
--
data BoundedType a where
  IntegralBoundedType :: IntegralType a -> BoundedType a
  NonNumBoundedType   :: NonNumType a   -> BoundedType a
-- | Scalar types: either a single value, or a vector @v a@ described by a
-- 'VectorType'.
--
data ScalarType a where
  SingleScalarType :: SingleType a       -> ScalarType a
  VectorScalarType :: VectorType (v a)   -> ScalarType (v a)
-- | Single-value types: either a numeric or a non-numeric type.
--
data SingleType a where
  NumSingleType    :: NumType a    -> SingleType a
  NonNumSingleType :: NonNumType a -> SingleType a
-- | Vector types: one constructor per supported width (2, 3, 4, 8 and 16),
-- each carrying the 'SingleType' of the vector's element type.
--
data VectorType v where
  Vector2Type  :: SingleType a -> VectorType (V2 a)
  Vector3Type  :: SingleType a -> VectorType (V3 a)
  Vector4Type  :: SingleType a -> VectorType (V4 a)
  Vector8Type  :: SingleType a -> VectorType (V8 a)
  Vector16Type :: SingleType a -> VectorType (V16 a)
-- | Shows the name of the represented integral type; the carried dictionary
-- is ignored.
--
instance Show (IntegralType a) where
  show ty =
    case ty of
      TypeInt{}     -> "Int"
      TypeInt8{}    -> "Int8"
      TypeInt16{}   -> "Int16"
      TypeInt32{}   -> "Int32"
      TypeInt64{}   -> "Int64"
      TypeWord{}    -> "Word"
      TypeWord8{}   -> "Word8"
      TypeWord16{}  -> "Word16"
      TypeWord32{}  -> "Word32"
      TypeWord64{}  -> "Word64"
      TypeCShort{}  -> "CShort"
      TypeCUShort{} -> "CUShort"
      TypeCInt{}    -> "CInt"
      TypeCUInt{}   -> "CUInt"
      TypeCLong{}   -> "CLong"
      TypeCULong{}  -> "CULong"
      TypeCLLong{}  -> "CLLong"
      TypeCULLong{} -> "CULLong"
-- | Shows the name of the represented floating-point type; the carried
-- dictionary is ignored.
--
instance Show (FloatingType a) where
  show ty =
    case ty of
      TypeHalf{}    -> "Half"
      TypeFloat{}   -> "Float"
      TypeDouble{}  -> "Double"
      TypeCFloat{}  -> "CFloat"
      TypeCDouble{} -> "CDouble"
-- | Shows the name of the represented non-numeric type; the carried
-- dictionary is ignored.
--
instance Show (NonNumType a) where
  show ty =
    case ty of
      TypeBool{}   -> "Bool"
      TypeChar{}   -> "Char"
      TypeCChar{}  -> "CChar"
      TypeCSChar{} -> "CSChar"
      TypeCUChar{} -> "CUChar"
-- | Delegates to the 'Show' instance of the wrapped integral or
-- floating-point type representation.
--
instance Show (NumType a) where
  show num =
    case num of
      IntegralNumType int -> show int
      FloatingNumType flt -> show flt
-- | Delegates to the 'Show' instance of the wrapped integral or non-numeric
-- type representation.
--
instance Show (BoundedType a) where
  show bounded =
    case bounded of
      IntegralBoundedType int -> show int
      NonNumBoundedType nonNum -> show nonNum
-- | Delegates to the 'Show' instance of the wrapped numeric or non-numeric
-- type representation.
--
instance Show (SingleType a) where
  show single =
    case single of
      NumSingleType num       -> show num
      NonNumSingleType nonNum -> show nonNum
-- | Shows the vector width followed by the element type, enclosed in angle
-- brackets; a two-wide vector whose element type shows as @T@ is rendered by
-- the format string @"<2 x %s>"@ applied to @T@.
--
instance Show (VectorType a) where
  show vec =
    case vec of
      Vector2Type elt  -> printf "<2 x %s>" (show elt)
      Vector3Type elt  -> printf "<3 x %s>" (show elt)
      Vector4Type elt  -> printf "<4 x %s>" (show elt)
      Vector8Type elt  -> printf "<8 x %s>" (show elt)
      Vector16Type elt -> printf "<16 x %s>" (show elt)
-- | Delegates to the 'Show' instance of the wrapped single or vector type
-- representation.
--
instance Show (ScalarType a) where
  show scalar =
    case scalar of
      SingleScalarType single -> show single
      VectorScalarType vec    -> show vec
-- | Types that have an 'IntegralType' representation. Every such type must
-- also be an instance of 'IsSingle', 'IsNum' and 'IsBounded'.
--
class ( IsSingle a
      , IsNum a
      , IsBounded a
      ) => IsIntegral a where
  -- | The value-level representation of the integral type @a@.
  integralType :: IntegralType a
-- | Types that have a 'FloatingType' representation. Every such type must
-- also be an instance of 'Floating', 'IsSingle' and 'IsNum'.
--
class ( Floating a
      , IsSingle a
      , IsNum a
      ) => IsFloating a where
  -- | The value-level representation of the floating-point type @a@.
  floatingType :: FloatingType a
-- | Types that have a 'NonNumType' representation.
--
class IsNonNum a where
  -- | The value-level representation of the non-numeric type @a@.
  nonNumType :: NonNumType a
-- | Types that have a 'NumType' representation. Every such type must also be
-- an instance of 'Num' and 'IsSingle'.
--
class ( Num a
      , IsSingle a
      ) => IsNum a where
  -- | The value-level representation of the numeric type @a@.
  numType :: NumType a
-- | Types that have a 'BoundedType' representation.
--
class IsBounded a where
  -- | The value-level representation of the bounded type @a@.
  boundedType :: BoundedType a
-- | Types that have a 'SingleType' representation. Every such type must also
-- be an instance of 'IsScalar'.
--
class IsScalar a => IsSingle a where
  -- | The value-level representation of the single-value type @a@.
  singleType :: SingleType a
-- | Types that have a 'ScalarType' representation. Every such type must also
-- be 'Typeable'.
--
class Typeable a => IsScalar a where
  -- | The value-level representation of the scalar type @a@.
  scalarType :: ScalarType a
-- | Extract the class dictionary carried by an integral type representation.
--
integralDict :: IntegralType a -> IntegralDict a
integralDict ty =
  case ty of
    TypeInt     d -> d
    TypeInt8    d -> d
    TypeInt16   d -> d
    TypeInt32   d -> d
    TypeInt64   d -> d
    TypeWord    d -> d
    TypeWord8   d -> d
    TypeWord16  d -> d
    TypeWord32  d -> d
    TypeWord64  d -> d
    TypeCShort  d -> d
    TypeCUShort d -> d
    TypeCInt    d -> d
    TypeCUInt   d -> d
    TypeCLong   d -> d
    TypeCULong  d -> d
    TypeCLLong  d -> d
    TypeCULLong d -> d
-- | Extract the class dictionary carried by a floating-point type
-- representation.
--
floatingDict :: FloatingType a -> FloatingDict a
floatingDict ty =
  case ty of
    TypeHalf    d -> d
    TypeFloat   d -> d
    TypeDouble  d -> d
    TypeCFloat  d -> d
    TypeCDouble d -> d
-- | Extract the class dictionary carried by a non-numeric type
-- representation.
--
nonNumDict :: NonNumType a -> NonNumDict a
nonNumDict ty =
  case ty of
    TypeBool   d -> d
    TypeChar   d -> d
    TypeCChar  d -> d
    TypeCSChar d -> d
    TypeCUChar d -> d
-- | Representation of tuple types built from unit, scalar types and pairs.
--
data TupleType a where
  -- | The unit type.
  TypeRunit   :: TupleType ()
  -- | A single scalar type.
  TypeRscalar :: ScalarType a -> TupleType a
  -- | A pair of two tuple types.
  TypeRpair   :: TupleType a -> TupleType b -> TupleType (a, b)
-- | Shows unit as @()@, a scalar via its own 'Show' instance, and a pair as
-- its two components in parentheses separated by a comma.
--
instance Show (TupleType a) where
  show tup =
    case tup of
      TypeRunit           -> "()"
      TypeRscalar scalar  -> show scalar
      TypeRpair left right -> printf "(%s,%s)" (show left) (show right)
-- | Constraint requiring that the types @a@ and @b@ have the same 'BitSize':
-- the type-level comparison of the two sizes must reduce to @'True@.
type BitSizeEq a b = (BitSize a == BitSize b) ~ 'True
-- | Open type family associating a type-level natural number with a type.
-- The instances are generated by the Template Haskell splice at the end of
-- this module, one per primitive type and one per vector of such a type.
type family BitSize a :: Nat
-- | A vector of exactly two elements; every field is strict.
data V2 a = V2 !a !a
  deriving (Typeable, Eq, Ord)

-- | A vector of exactly three elements; every field is strict.
data V3 a = V3 !a !a !a
  deriving (Typeable, Eq, Ord)

-- | A vector of exactly four elements; every field is strict.
data V4 a = V4 !a !a !a !a
  deriving (Typeable, Eq, Ord)

-- | A vector of exactly eight elements; every field is strict.
data V8 a = V8 !a !a !a !a !a !a !a !a
  deriving (Typeable, Eq, Ord)

-- | A vector of exactly sixteen elements; every field is strict.
data V16 a = V16 !a !a !a !a !a !a !a !a !a !a !a !a !a !a !a !a
  deriving (Typeable, Eq, Ord)

-- The 'Show' instances below all render a vector as its components, each
-- shown with the element's own 'show', separated by commas and enclosed in
-- angle brackets.

instance Show a => Show (V2 a) where
  show (V2 x0 x1) =
    printf "<%s,%s>" (show x0) (show x1)

instance Show a => Show (V3 a) where
  show (V3 x0 x1 x2) =
    printf "<%s,%s,%s>" (show x0) (show x1) (show x2)

instance Show a => Show (V4 a) where
  show (V4 x0 x1 x2 x3) =
    printf "<%s,%s,%s,%s>" (show x0) (show x1) (show x2) (show x3)

instance Show a => Show (V8 a) where
  show (V8 x0 x1 x2 x3 x4 x5 x6 x7) =
    printf "<%s,%s,%s,%s,%s,%s,%s,%s>"
      (show x0) (show x1) (show x2) (show x3)
      (show x4) (show x5) (show x6) (show x7)

instance Show a => Show (V16 a) where
  show (V16 x00 x01 x02 x03 x04 x05 x06 x07 x08 x09 x10 x11 x12 x13 x14 x15) =
    printf "<%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s>"
      (show x00) (show x01) (show x02) (show x03)
      (show x04) (show x05) (show x06) (show x07)
      (show x08) (show x09) (show x10) (show x11)
      (show x12) (show x13) (show x14) (show x15)
-- Generate the class instances and the 'BitSize' type family instances for
-- every primitive type listed below, plus 'IsScalar' and 'BitSize' instances
-- for the vectors V2, V3, V4, V8 and V16 over each of those types.
--
-- Each table pairs a type name with the bit width used for its 'BitSize'
-- instance. The data constructor for a type is looked up by name as "Type"
-- followed by the base name of the type (e.g. Int32 uses TypeInt32).
--
$( runQ $ do
    let
        -- Width in bits, read from the 'FiniteBits' instance at compile
        -- time; the argument is only used to select the instance.
        widthOf :: FiniteBits b => b -> Integer
        widthOf proxy = toInteger (finiteBitSize proxy)

        integralTypes :: [(Name, Integer)]
        integralTypes =
          [ (''Int,     widthOf (undefined::Int))
          , (''Int8,    8)
          , (''Int16,   16)
          , (''Int32,   32)
          , (''Int64,   64)
          , (''Word,    widthOf (undefined::Word))
          , (''Word8,   8)
          , (''Word16,  16)
          , (''Word32,  32)
          , (''Word64,  64)
          , (''CShort,  16)
          , (''CUShort, 16)
          , (''CInt,    32)
          , (''CUInt,   32)
          , (''CLong,   widthOf (undefined::CLong))
          , (''CULong,  widthOf (undefined::CULong))
          , (''CLLong,  64)
          , (''CULLong, 64)
          ]

        floatingTypes :: [(Name, Integer)]
        floatingTypes =
          [ (''Half,    16)
          , (''Float,   32)
          , (''Double,  64)
          , (''CFloat,  32)
          , (''CDouble, 64)
          ]

        nonNumTypes :: [(Name, Integer)]
        nonNumTypes =
          [ (''Bool,   8)
          , (''Char,   32)
          , (''CChar,  8)
          , (''CSChar, 8)
          , (''CUChar, 8)
          ]

        -- Instances for one integral type.
        mkIntegral :: Name -> Integer -> Q [Dec]
        mkIntegral ty width =
          [d| instance IsIntegral $(conT ty) where
                integralType = $(conE (mkName ("Type" ++ nameBase ty))) IntegralDict

              instance IsNum $(conT ty) where
                numType = IntegralNumType integralType

              instance IsBounded $(conT ty) where
                boundedType = IntegralBoundedType integralType

              instance IsSingle $(conT ty) where
                singleType = NumSingleType numType

              instance IsScalar $(conT ty) where
                scalarType = SingleScalarType singleType

              type instance BitSize $(conT ty) = $(litT (numTyLit width))
            |]

        -- Instances for one floating-point type.
        mkFloating :: Name -> Integer -> Q [Dec]
        mkFloating ty width =
          [d| instance IsFloating $(conT ty) where
                floatingType = $(conE (mkName ("Type" ++ nameBase ty))) FloatingDict

              instance IsNum $(conT ty) where
                numType = FloatingNumType floatingType

              instance IsSingle $(conT ty) where
                singleType = NumSingleType numType

              instance IsScalar $(conT ty) where
                scalarType = SingleScalarType singleType

              type instance BitSize $(conT ty) = $(litT (numTyLit width))
            |]

        -- Instances for one non-numeric type.
        mkNonNum :: Name -> Integer -> Q [Dec]
        mkNonNum ty width =
          [d| instance IsNonNum $(conT ty) where
                nonNumType = $(conE (mkName ("Type" ++ nameBase ty))) NonNumDict

              instance IsBounded $(conT ty) where
                boundedType = NonNumBoundedType nonNumType

              instance IsSingle $(conT ty) where
                singleType = NonNumSingleType nonNumType

              instance IsScalar $(conT ty) where
                scalarType = SingleScalarType singleType

              type instance BitSize $(conT ty) = $(litT (numTyLit width))
            |]

        -- Instances for the vectors over one element type; the bit size of
        -- a vector is the element width multiplied by the vector width.
        mkVector :: Name -> Integer -> Q [Dec]
        mkVector ty width =
          [d| instance IsScalar (V2 $(conT ty)) where
                scalarType = VectorScalarType (Vector2Type singleType)

              instance IsScalar (V3 $(conT ty)) where
                scalarType = VectorScalarType (Vector3Type singleType)

              instance IsScalar (V4 $(conT ty)) where
                scalarType = VectorScalarType (Vector4Type singleType)

              instance IsScalar (V8 $(conT ty)) where
                scalarType = VectorScalarType (Vector8Type singleType)

              instance IsScalar (V16 $(conT ty)) where
                scalarType = VectorScalarType (Vector16Type singleType)

              type instance BitSize (V2 $(conT ty))  = $(litT (numTyLit (2*width)))
              type instance BitSize (V3 $(conT ty))  = $(litT (numTyLit (3*width)))
              type instance BitSize (V4 $(conT ty))  = $(litT (numTyLit (4*width)))
              type instance BitSize (V8 $(conT ty))  = $(litT (numTyLit (8*width)))
              type instance BitSize (V16 $(conT ty)) = $(litT (numTyLit (16*width)))
            |]

        -- Run a generator over a whole table and flatten the results.
        declareAll :: (Name -> Integer -> Q [Dec]) -> [(Name, Integer)] -> Q [Dec]
        declareAll mk table = fmap concat (mapM (uncurry mk) table)
    --
    integralDecs <- declareAll mkIntegral integralTypes
    floatingDecs <- declareAll mkFloating floatingTypes
    nonNumDecs   <- declareAll mkNonNum   nonNumTypes
    vectorDecs   <- declareAll mkVector   (integralTypes ++ floatingTypes ++ nonNumTypes)
    return (integralDecs ++ floatingDecs ++ nonNumDecs ++ vectorDecs)
 )