{-# LANGUAGE ConstraintKinds      #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE DeriveDataTypeable   #-}
{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE TemplateHaskell      #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Type
-- Copyright   : [2008..2018] Manuel M T Chakravarty, Gabriele Keller
--               [2009..2018] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
--  /Scalar types supported in array computations/
--
--  Integral types:
--    * Int
--    * Int8
--    * Int16
--    * Int32
--    * Int64
--    * Word
--    * Word8
--    * Word16
--    * Word32
--    * Word64
--    * CShort
--    * CUShort
--    * CInt
--    * CUInt
--    * CLong
--    * CULong
--    * CLLong
--    * CULLong
--
--  Floating types:
--    * Half
--    * Float
--    * Double
--    * CFloat
--    * CDouble
--
--  Non-numeric types:
--    * Bool
--    * Char
--    * CChar
--    * CSChar
--    * CUChar
--
--  SIMD vector types:
--    * V2
--    * V3
--    * V4
--    * V8
--    * V16
--
-- Note that 'Int' has the same bit width as in plain Haskell computations.
-- 'Half', 'Float' and 'Double' represent IEEE half, single and double
-- precision floating-point numbers, respectively.
--

module Data.Array.Accelerate.Type (
  Half(..), Float, Double, Char, Bool(..),
  module Data.Int,
  module Data.Word,
  module Foreign.C.Types,
  module Data.Array.Accelerate.Type
) where


import Data.Orphans ()    -- orphan instances for 8-tuples and beyond

-- standard libraries
import Data.Bits
import Data.Int
import Data.Type.Equality
import Data.Typeable
import Data.Word
import GHC.TypeLits
import Language.Haskell.TH
import Numeric.Half
import Text.Printf
import Foreign.Storable
import Foreign.C.Types
    (CChar, CSChar, CUChar, CShort, CUShort, CInt, CUInt, CLong, CULong, CLLong, CULLong, CFloat, CDouble)


-- Scalar types
-- ------------

-- Reified dictionaries
--

-- | Reified class dictionary for an integral scalar type. Pattern matching
-- on 'IntegralDict' brings every listed constraint into scope.
data IntegralDict e where
  IntegralDict
    :: ( Integral e, Real e, Num e
       , Bits e, FiniteBits e, Bounded e
       , Ord e, Eq e, Show e, Storable e )
    => IntegralDict e

-- | Reified class dictionary for a floating-point scalar type. Pattern
-- matching on 'FloatingDict' brings every listed constraint into scope.
data FloatingDict e where
  FloatingDict
    :: ( RealFloat e, RealFrac e, Real e
       , Floating e, Fractional e, Num e
       , Ord e, Eq e, Show e, Storable e )
    => FloatingDict e

-- | Reified class dictionary for a non-numeric scalar type: comparison,
-- bounds, printing and storage only.
data NonNumDict e where
  NonNumDict
    :: ( Bounded e, Ord e, Eq e, Show e, Storable e )
    => NonNumDict e


-- Scalar type representation
--

-- | Integral types supported in array computations.
--
-- Each constructor carries the reified 'IntegralDict' for its type, so
-- pattern matching on a constructor brings that type's class instances
-- (Integral, Bits, Storable, ...) into scope.
--
data IntegralType a where
  TypeInt     :: IntegralDict Int     -> IntegralType Int
  TypeInt8    :: IntegralDict Int8    -> IntegralType Int8
  TypeInt16   :: IntegralDict Int16   -> IntegralType Int16
  TypeInt32   :: IntegralDict Int32   -> IntegralType Int32
  TypeInt64   :: IntegralDict Int64   -> IntegralType Int64
  TypeWord    :: IntegralDict Word    -> IntegralType Word
  TypeWord8   :: IntegralDict Word8   -> IntegralType Word8
  TypeWord16  :: IntegralDict Word16  -> IntegralType Word16
  TypeWord32  :: IntegralDict Word32  -> IntegralType Word32
  TypeWord64  :: IntegralDict Word64  -> IntegralType Word64
  TypeCShort  :: IntegralDict CShort  -> IntegralType CShort
  TypeCUShort :: IntegralDict CUShort -> IntegralType CUShort
  TypeCInt    :: IntegralDict CInt    -> IntegralType CInt
  TypeCUInt   :: IntegralDict CUInt   -> IntegralType CUInt
  TypeCLong   :: IntegralDict CLong   -> IntegralType CLong
  TypeCULong  :: IntegralDict CULong  -> IntegralType CULong
  TypeCLLong  :: IntegralDict CLLong  -> IntegralType CLLong
  TypeCULLong :: IntegralDict CULLong -> IntegralType CULLong

-- | Floating-point types supported in array computations.
--
-- Each constructor carries the reified 'FloatingDict' for its type.
--
data FloatingType a where
  TypeHalf    :: FloatingDict Half    -> FloatingType Half
  TypeFloat   :: FloatingDict Float   -> FloatingType Float
  TypeDouble  :: FloatingDict Double  -> FloatingType Double
  TypeCFloat  :: FloatingDict CFloat  -> FloatingType CFloat
  TypeCDouble :: FloatingDict CDouble -> FloatingType CDouble

-- | Non-numeric types supported in array computations.
--
-- Each constructor carries the reified 'NonNumDict' for its type (bounds,
-- comparison, printing and storage, but no arithmetic).
--
data NonNumType a where
  TypeBool    :: NonNumDict Bool      -> NonNumType Bool   --  marshalled to Word8
  TypeChar    :: NonNumDict Char      -> NonNumType Char
  TypeCChar   :: NonNumDict CChar     -> NonNumType CChar
  TypeCSChar  :: NonNumDict CSChar    -> NonNumType CSChar
  TypeCUChar  :: NonNumDict CUChar    -> NonNumType CUChar

-- | Numeric element types implement Num & Real
--
-- A numeric type is either an integral or a floating-point type; both of
-- their reified dictionaries include 'Num' and 'Real'.
--
data NumType a where
  IntegralNumType :: IntegralType a -> NumType a
  FloatingNumType :: FloatingType a -> NumType a

-- | Bounded element types implement Bounded
--
-- Only the integral and non-numeric types are bounded; floating-point types
-- have no constructor here.
--
data BoundedType a where
  IntegralBoundedType :: IntegralType a -> BoundedType a
  NonNumBoundedType   :: NonNumType a   -> BoundedType a

-- | All scalar element types implement Eq & Ord
--
-- A scalar is either a single value or a SIMD vector of single values.
--
data ScalarType a where
  SingleScalarType :: SingleType a     -> ScalarType a
  VectorScalarType :: VectorType (v a) -> ScalarType (v a)

-- | Single-value (non-vector) element types: either numeric or non-numeric.
--
data SingleType a where
  NumSingleType    :: NumType a    -> SingleType a
  NonNumSingleType :: NonNumType a -> SingleType a

-- | SIMD vector types with 2, 3, 4, 8 or 16 lanes. The lane type is a
-- 'SingleType', so vectors of vectors cannot be expressed.
--
data VectorType v where
  Vector2Type   :: SingleType a -> VectorType (V2 a)
  Vector3Type   :: SingleType a -> VectorType (V3 a)
  Vector4Type   :: SingleType a -> VectorType (V4 a)
  Vector8Type   :: SingleType a -> VectorType (V8 a)
  Vector16Type  :: SingleType a -> VectorType (V16 a)

-- Showing type names
--

-- Each base representation is shown as the Haskell name of the type it
-- describes; the reified dictionary it carries is ignored.

instance Show (IntegralType a) where
  show ty =
    case ty of
      TypeInt{}     -> "Int"
      TypeInt8{}    -> "Int8"
      TypeInt16{}   -> "Int16"
      TypeInt32{}   -> "Int32"
      TypeInt64{}   -> "Int64"
      TypeWord{}    -> "Word"
      TypeWord8{}   -> "Word8"
      TypeWord16{}  -> "Word16"
      TypeWord32{}  -> "Word32"
      TypeWord64{}  -> "Word64"
      TypeCShort{}  -> "CShort"
      TypeCUShort{} -> "CUShort"
      TypeCInt{}    -> "CInt"
      TypeCUInt{}   -> "CUInt"
      TypeCLong{}   -> "CLong"
      TypeCULong{}  -> "CULong"
      TypeCLLong{}  -> "CLLong"
      TypeCULLong{} -> "CULLong"

instance Show (FloatingType a) where
  show ty =
    case ty of
      TypeHalf{}    -> "Half"
      TypeFloat{}   -> "Float"
      TypeDouble{}  -> "Double"
      TypeCFloat{}  -> "CFloat"
      TypeCDouble{} -> "CDouble"

instance Show (NonNumType a) where
  show ty =
    case ty of
      TypeBool{}   -> "Bool"
      TypeChar{}   -> "Char"
      TypeCChar{}  -> "CChar"
      TypeCSChar{} -> "CSChar"
      TypeCUChar{} -> "CUChar"

-- Wrapper representations are shown as the type they wrap.

instance Show (NumType a) where
  show ty =
    case ty of
      IntegralNumType t -> show t
      FloatingNumType t -> show t

instance Show (BoundedType a) where
  show ty =
    case ty of
      IntegralBoundedType t -> show t
      NonNumBoundedType t   -> show t

instance Show (SingleType a) where
  show ty =
    case ty of
      NumSingleType t    -> show t
      NonNumSingleType t -> show t

-- A SIMD vector is shown as @<lanes x element>@, e.g. @<4 x Float>@.
instance Show (VectorType a) where
  show ty =
    case ty of
      Vector2Type  t -> lanes 2  (show t)
      Vector3Type  t -> lanes 3  (show t)
      Vector4Type  t -> lanes 4  (show t)
      Vector8Type  t -> lanes 8  (show t)
      Vector16Type t -> lanes 16 (show t)
    where
      lanes :: Int -> String -> String
      lanes n elt = "<" ++ show n ++ " x " ++ elt ++ ">"

instance Show (ScalarType a) where
  show ty =
    case ty of
      SingleScalarType t -> show t
      VectorScalarType t -> show t


-- Querying scalar type representations
--

-- | Integral types
--
-- Types with an 'IntegralType' representation. Every instance is also a
-- single-value, numeric and bounded type (see the superclasses).
--
class (IsSingle a, IsNum a, IsBounded a) => IsIntegral a where
  integralType :: IntegralType a

-- | Floating types
--
-- Types with a 'FloatingType' representation. Instances must also be
-- Prelude 'Floating' as well as single-value numeric types.
--
class (Floating a, IsSingle a, IsNum a) => IsFloating a where
  floatingType :: FloatingType a

-- | Non-numeric types
--
-- Types with a 'NonNumType' representation. Unlike the classes above, this
-- class declares no superclasses.
--
class IsNonNum a where
  nonNumType :: NonNumType a

-- | Numeric types
--
-- Types with a 'NumType' representation; instances must also be Prelude
-- 'Num'.
--
class (Num a, IsSingle a) => IsNum a where
  numType :: NumType a

-- | Bounded types
--
-- Types with a 'BoundedType' representation.
--
class IsBounded a where
  boundedType :: BoundedType a

-- | All single value types
--
-- Types with a 'SingleType' representation (scalars that are not vectors).
--
class IsScalar a => IsSingle a where
  singleType :: SingleType a

-- | All scalar types
--
-- Types with a 'ScalarType' representation: single values and SIMD vectors.
-- A 'Typeable' instance is required.
--
class Typeable a => IsScalar a where
  scalarType :: ScalarType a


-- Extract reified dictionaries
--

-- | Recover the reified dictionary carried by an 'IntegralType'.
integralDict :: IntegralType a -> IntegralDict a
integralDict ty =
  case ty of
    TypeInt     d -> d
    TypeInt8    d -> d
    TypeInt16   d -> d
    TypeInt32   d -> d
    TypeInt64   d -> d
    TypeWord    d -> d
    TypeWord8   d -> d
    TypeWord16  d -> d
    TypeWord32  d -> d
    TypeWord64  d -> d
    TypeCShort  d -> d
    TypeCUShort d -> d
    TypeCInt    d -> d
    TypeCUInt   d -> d
    TypeCLong   d -> d
    TypeCULong  d -> d
    TypeCLLong  d -> d
    TypeCULLong d -> d

-- | Recover the reified dictionary carried by a 'FloatingType'.
floatingDict :: FloatingType a -> FloatingDict a
floatingDict ty =
  case ty of
    TypeHalf    d -> d
    TypeFloat   d -> d
    TypeDouble  d -> d
    TypeCFloat  d -> d
    TypeCDouble d -> d

-- | Recover the reified dictionary carried by a 'NonNumType'.
nonNumDict :: NonNumType a -> NonNumDict a
nonNumDict ty =
  case ty of
    TypeBool   d -> d
    TypeChar   d -> d
    TypeCChar  d -> d
    TypeCSChar d -> d
    TypeCUChar d -> d


-- Type representation
-- -------------------
--
-- Representation of product types, consisting of:
--
--   * unit (void)
--
--   * scalar types: values which go in registers. These may be single value
--     types such as int and float, or SIMD vectors of single value types such
--     as <4 x float>. Vectors of vectors are not allowed.
--
--   * pairs: representing compound values (i.e. tuples) where each component
--     will be stored in a separate array.
--
data TupleType t where
  TypeRunit   ::                               TupleType ()
  TypeRscalar :: ScalarType s               -> TupleType s
  TypeRpair   :: TupleType l -> TupleType r -> TupleType (l, r)

-- Unit is shown as @()@, a scalar as its type name, and a pair as @(l,r)@
-- with no space after the comma.
instance Show (TupleType t) where
  show ty =
    case ty of
      TypeRunit     -> "()"
      TypeRscalar s -> show s
      TypeRpair l r -> "(" ++ show l ++ "," ++ show r ++ ")"


-- Type-level bit sizes
-- --------------------

-- |Constraint that values of these two types have the same bit width
--
-- Both widths are looked up in 'BitSize' and compared with the type-level
-- equality test '==' from "Data.Type.Equality".
--
type BitSizeEq a b = (BitSize a == BitSize b) ~ 'True

-- | Width in bits of a type, as a type-level natural. The family is open;
-- in this module its instances are generated per scalar type by Template
-- Haskell.
type family BitSize a :: Nat


-- SIMD vector types
-- -----------------

-- | Two-lane SIMD vector with strict lanes.
data V2 e  = V2 !e !e
  deriving (Eq, Ord, Typeable)

-- | Three-lane SIMD vector with strict lanes.
data V3 e  = V3 !e !e !e
  deriving (Eq, Ord, Typeable)

-- | Four-lane SIMD vector with strict lanes.
data V4 e  = V4 !e !e !e !e
  deriving (Eq, Ord, Typeable)

-- | Eight-lane SIMD vector with strict lanes.
data V8 e  = V8 !e !e !e !e !e !e !e !e
  deriving (Eq, Ord, Typeable)

-- | Sixteen-lane SIMD vector with strict lanes.
data V16 e = V16 !e !e !e !e !e !e !e !e !e !e !e !e !e !e !e !e
  deriving (Eq, Ord, Typeable)

-- Every vector is shown as its lanes separated by commas (no spaces) inside
-- angle brackets, e.g. @<1,2,3,4>@. Each lane list is non-empty, so
-- @drop 1@ only removes the leading comma produced by the concatMap.

instance Show e => Show (V2 e) where
  show (V2 x0 x1) = '<' : drop 1 (concatMap ((',' :) . show) lanes) ++ ">"
    where
      lanes = [x0, x1]

instance Show e => Show (V3 e) where
  show (V3 x0 x1 x2) = '<' : drop 1 (concatMap ((',' :) . show) lanes) ++ ">"
    where
      lanes = [x0, x1, x2]

instance Show e => Show (V4 e) where
  show (V4 x0 x1 x2 x3) = '<' : drop 1 (concatMap ((',' :) . show) lanes) ++ ">"
    where
      lanes = [x0, x1, x2, x3]

instance Show e => Show (V8 e) where
  show (V8 x0 x1 x2 x3 x4 x5 x6 x7) =
    '<' : drop 1 (concatMap ((',' :) . show) lanes) ++ ">"
    where
      lanes = [x0, x1, x2, x3, x4, x5, x6, x7]

instance Show e => Show (V16 e) where
  show (V16 x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 x11 x12 x13 x14 x15) =
    '<' : drop 1 (concatMap ((',' :) . show) lanes) ++ ">"
    where
      lanes = [ x0, x1, x2,  x3,  x4,  x5,  x6,  x7
              , x8, x9, x10, x11, x12, x13, x14, x15 ]


-- Instances
-- ---------
--
-- Generate instances for the IsX classes. It would be preferable to do this
-- automatically based on the members of the IntegralType (etc.) representations
-- (see for example FromIntegral.hs) but TH phase restrictions would require us
-- to split this into a separate module.
--

$( runQ $ do
  let
      -- Width in bits of a type with a 'FiniteBits' instance. Only the type of
      -- the argument matters; every caller passes 'undefined' as a proxy.
      bits :: FiniteBits b => b -> Integer
      bits = toInteger . finiteBitSize

      -- Width in bits of a type's 'Storable' representation, for types with
      -- no 'FiniteBits' instance (the floating-point types and Char). Only the
      -- type of the argument matters; every caller passes 'undefined'.
      storableBits :: Storable b => b -> Integer
      storableBits = (8 *) . toInteger . sizeOf

      -- Every width below is queried from the type itself instead of being
      -- written as a literal, so the generated 'BitSize' instances cannot
      -- drift from the actual representation of the platform-defined C
      -- types. NOTE(review): the query runs where Template Haskell runs
      -- (the compiling machine), as was already the case for Int/Word/CLong.
      integralTypes :: [(Name, Integer)]
      integralTypes =
        [ (''Int,     bits (undefined :: Int))
        , (''Int8,    bits (undefined :: Int8))
        , (''Int16,   bits (undefined :: Int16))
        , (''Int32,   bits (undefined :: Int32))
        , (''Int64,   bits (undefined :: Int64))
        , (''Word,    bits (undefined :: Word))
        , (''Word8,   bits (undefined :: Word8))
        , (''Word16,  bits (undefined :: Word16))
        , (''Word32,  bits (undefined :: Word32))
        , (''Word64,  bits (undefined :: Word64))
        , (''CShort,  bits (undefined :: CShort))
        , (''CUShort, bits (undefined :: CUShort))
        , (''CInt,    bits (undefined :: CInt))
        , (''CUInt,   bits (undefined :: CUInt))
        , (''CLong,   bits (undefined :: CLong))
        , (''CULong,  bits (undefined :: CULong))
        , (''CLLong,  bits (undefined :: CLLong))
        , (''CULLong, bits (undefined :: CULLong))
        ]

      floatingTypes :: [(Name, Integer)]
      floatingTypes =
        [ (''Half,    storableBits (undefined :: Half))
        , (''Float,   storableBits (undefined :: Float))
        , (''Double,  storableBits (undefined :: Double))
        , (''CFloat,  storableBits (undefined :: CFloat))
        , (''CDouble, storableBits (undefined :: CDouble))
        ]

      nonNumTypes :: [(Name, Integer)]
      nonNumTypes =
        [ (''Bool,   bits (undefined :: Word8))   -- Bool is marshalled to Word8
        , (''Char,   storableBits (undefined :: Char))
        , (''CChar,  bits (undefined :: CChar))
        , (''CSChar, bits (undefined :: CSChar))
        , (''CUChar, bits (undefined :: CUChar))
        ]

      -- Instances for an integral type @t@ of width @n@ bits. The
      -- representation constructor is found by name: type T maps to
      -- constructor TypeT of 'IntegralType'.
      mkIntegral :: Name -> Integer -> Q [Dec]
      mkIntegral t n =
        [d| instance IsIntegral $(conT t) where
              integralType = $(conE (mkName ("Type" ++ nameBase t))) IntegralDict

            instance IsNum $(conT t) where
              numType = IntegralNumType integralType

            instance IsBounded $(conT t) where
              boundedType = IntegralBoundedType integralType

            instance IsSingle $(conT t) where
              singleType = NumSingleType numType

            instance IsScalar $(conT t) where
              scalarType = SingleScalarType singleType

            type instance BitSize $(conT t) = $(litT (numTyLit n))
          |]

      -- Instances for a floating-point type @t@ of width @n@ bits; the
      -- representation constructor is TypeT of 'FloatingType'.
      mkFloating :: Name -> Integer -> Q [Dec]
      mkFloating t n =
        [d| instance IsFloating $(conT t) where
              floatingType = $(conE (mkName ("Type" ++ nameBase t))) FloatingDict

            instance IsNum $(conT t) where
              numType = FloatingNumType floatingType

            instance IsSingle $(conT t) where
              singleType = NumSingleType numType

            instance IsScalar $(conT t) where
              scalarType = SingleScalarType singleType

            type instance BitSize $(conT t) = $(litT (numTyLit n))
          |]

      -- Instances for a non-numeric type @t@ of width @n@ bits; the
      -- representation constructor is TypeT of 'NonNumType'.
      mkNonNum :: Name -> Integer -> Q [Dec]
      mkNonNum t n =
        [d| instance IsNonNum $(conT t) where
              nonNumType = $(conE (mkName ("Type" ++ nameBase t))) NonNumDict

            instance IsBounded $(conT t) where
              boundedType = NonNumBoundedType nonNumType

            instance IsSingle $(conT t) where
              singleType = NonNumSingleType nonNumType

            instance IsScalar $(conT t) where
              scalarType = SingleScalarType singleType

            type instance BitSize $(conT t) = $(litT (numTyLit n))
          |]

      -- 'IsScalar' and 'BitSize' instances for every SIMD vector width whose
      -- lanes have type @t@ (of width @n@ bits); a k-lane vector is k*n bits.
      mkVector :: Name -> Integer -> Q [Dec]
      mkVector t n =
        [d| instance IsScalar (V2 $(conT t)) where
              scalarType = VectorScalarType (Vector2Type singleType)

            instance IsScalar (V3 $(conT t)) where
              scalarType = VectorScalarType (Vector3Type singleType)

            instance IsScalar (V4 $(conT t)) where
              scalarType = VectorScalarType (Vector4Type singleType)

            instance IsScalar (V8 $(conT t)) where
              scalarType = VectorScalarType (Vector8Type singleType)

            instance IsScalar (V16 $(conT t)) where
              scalarType = VectorScalarType (Vector16Type singleType)

            type instance BitSize (V2 $(conT t))  = $(litT (numTyLit (2*n)))
            type instance BitSize (V3 $(conT t))  = $(litT (numTyLit (3*n)))
            type instance BitSize (V4 $(conT t))  = $(litT (numTyLit (4*n)))
            type instance BitSize (V8 $(conT t))  = $(litT (numTyLit (8*n)))
            type instance BitSize (V16 $(conT t)) = $(litT (numTyLit (16*n)))
          |]
      --
  is <- mapM (uncurry mkIntegral) integralTypes
  fs <- mapM (uncurry mkFloating) floatingTypes
  ns <- mapM (uncurry mkNonNum)   nonNumTypes
  -- every single-value type also gets vector instances
  vs <- mapM (uncurry mkVector)  (integralTypes ++ floatingTypes ++ nonNumTypes)
  --
  return (concat is ++ concat fs ++ concat ns ++ concat vs)
 )