-- | Types of Cmm values: value categories paired with machine widths
-- ('CmmType'), the supported widths ('Width'), vector types, and hints
-- attached to foreign calls ('ForeignHint').
--
-- 'CmmType' is exported without its constructor; values are built with
-- 'cmmBits', 'cmmFloat', 'gcWord', 'vec' and the predefined constants
-- ('b8', 'f64', ...).
module CmmType
    ( CmmType   
    , b8, b16, b32, b64, b128, b256, b512, f32, f64, bWord, bHalfWord, gcWord
    , cInt
    , cmmBits, cmmFloat
    , typeWidth, cmmEqType, cmmEqType_ignoring_ptrhood
    , isFloatType, isGcPtrType, isBitsType
    , isWord32, isWord64, isFloat64, isFloat32
    , Width(..)
    , widthInBits, widthInBytes, widthInLog, widthFromBytes
    , wordWidth, halfWordWidth, cIntWidth
    , halfWordMask
    , narrowU, narrowS
    , rEP_CostCentreStack_mem_alloc
    , rEP_CostCentreStack_scc_count
    , rEP_StgEntCounter_allocs
    , rEP_StgEntCounter_allocd
    , ForeignHint(..)
    , Length
    , vec, vec2, vec4, vec8, vec16
    , vec2f64, vec2b64, vec4f32, vec4b32, vec8b16, vec16b8
    , cmmVec
    , vecLength, vecElemType
    , isVecType
   )
where
import GhcPrelude
import DynFlags
import FastString
import Outputable
import Data.Word
import Data.Int
  
  
  
  
  
-- | The type of a Cmm value: a category ('CmmCat') paired with a 'Width'.
--
-- No 'Eq' instance is derived; compare types with 'cmmEqType' or
-- 'cmmEqType_ignoring_ptrhood'.
data CmmType    
  = CmmType CmmCat Width
-- | The category component of a 'CmmType'. Not exported from this module.
data CmmCat                
   = GcPtrCat              -- ^ category used by 'gcWord'; pretty-printed as @P@
   | BitsCat               -- ^ category used by 'cmmBits'; pretty-printed as @I@
   | FloatCat              -- ^ category used by 'cmmFloat'; pretty-printed as @F@
   | VecCat Length CmmCat  -- ^ vector: element count and element category
   deriving( Eq )
        
-- | A type is printed as its category tag immediately followed by its
-- width in bits.
instance Outputable CmmType where
  ppr ty = case ty of
    CmmType category width -> ppr category <> ppr (widthInBits width)
-- | Category tags: @P@ for 'GcPtrCat', @I@ for 'BitsCat', @F@ for
-- 'FloatCat'. A vector prints its element tag, then @x@, the element
-- count, and @V@.
instance Outputable CmmCat where
  ppr category = case category of
    GcPtrCat       -> text "P"
    BitsCat        -> text "I"
    FloatCat       -> text "F"
    VecCat len elt -> ppr elt <> text "x" <> ppr len <> text "V"
-- | Exact equality on types: the categories and the widths must both agree.
cmmEqType :: CmmType -> CmmType -> Bool
cmmEqType lhs rhs =
  case (lhs, rhs) of
    (CmmType catL widthL, CmmType catR widthR) ->
      catL == catR && widthL == widthR
-- | Equality on types that does not distinguish 'GcPtrCat' from 'BitsCat'.
--
-- Floats only match floats, vectors only match vectors of the same length
-- whose element categories match under the same relaxed rule, and the
-- widths must be equal in every case.
cmmEqType_ignoring_ptrhood :: CmmType -> CmmType -> Bool
cmmEqType_ignoring_ptrhood (CmmType catL widthL) (CmmType catR widthR)
   = sameModuloPtrhood catL catR && widthL == widthR
   where
     -- Category comparison in which the two non-float, non-vector
     -- categories are interchangeable.
     sameModuloPtrhood :: CmmCat -> CmmCat -> Bool
     sameModuloPtrhood a b = case (a, b) of
       (FloatCat,      FloatCat)      -> True
       (FloatCat,      _)             -> False
       (_,             FloatCat)      -> False
       (VecCat n1 e1,  VecCat n2 e2)  -> n1 == n2 && sameModuloPtrhood e1 e2
       (VecCat {},     _)             -> False
       (_,             VecCat {})     -> False
       _                              -> True
-- | Extract the width component of a type.
typeWidth :: CmmType -> Width
typeWidth ty = case ty of
  CmmType _ width -> width
-- | Build a 'BitsCat' (respectively 'FloatCat') type of the given width.
cmmBits, cmmFloat :: Width -> CmmType
cmmBits  width = CmmType BitsCat  width
cmmFloat width = CmmType FloatCat width
-- | Predefined types: @bN@ is the 'BitsCat' type of N bits and @fN@ is the
-- 'FloatCat' type of N bits.
b8, b16, b32, b64, b128, b256, b512, f32, f64 :: CmmType
b8     = CmmType BitsCat  W8
b16    = CmmType BitsCat  W16
b32    = CmmType BitsCat  W32
b64    = CmmType BitsCat  W64
b128   = CmmType BitsCat  W128
b256   = CmmType BitsCat  W256
b512   = CmmType BitsCat  W512
f32    = CmmType FloatCat W32
f64    = CmmType FloatCat W64
-- | The bits type whose width is 'wordWidth' for the given flags.
bWord :: DynFlags -> CmmType
bWord = cmmBits . wordWidth
-- | The bits type whose width is 'halfWordWidth' for the given flags.
bHalfWord :: DynFlags -> CmmType
bHalfWord = cmmBits . halfWordWidth
-- | The 'GcPtrCat' type whose width is 'wordWidth' for the given flags.
gcWord :: DynFlags -> CmmType
gcWord = CmmType GcPtrCat . wordWidth
-- | The bits type whose width is 'cIntWidth' for the given flags.
cInt :: DynFlags -> CmmType
cInt = cmmBits . cIntWidth
-- | Category predicates: each holds exactly when the type's category is
-- 'FloatCat', 'GcPtrCat' or 'BitsCat' respectively, whatever its width.
isFloatType, isGcPtrType, isBitsType :: CmmType -> Bool
isFloatType ty = case ty of
  CmmType FloatCat _ -> True
  _                  -> False

isGcPtrType ty = case ty of
  CmmType GcPtrCat _ -> True
  _                  -> False

isBitsType ty = case ty of
  CmmType BitsCat _ -> True
  _                 -> False
-- | Width-specific predicates.
--
-- 'isWord32' and 'isWord64' accept both 'BitsCat' and 'GcPtrCat' types of
-- the given width; 'isFloat32' and 'isFloat64' accept only 'FloatCat'.
isWord32, isWord64, isFloat32, isFloat64 :: CmmType -> Bool
isWord64 (CmmType cat width) = width == W64 && cat `elem` [BitsCat, GcPtrCat]

isWord32 (CmmType cat width) = width == W32 && cat `elem` [BitsCat, GcPtrCat]

isFloat32 ty = case ty of
  CmmType FloatCat W32 -> True
  _                    -> False

isFloat64 ty = case ty of
  CmmType FloatCat W64 -> True
  _                    -> False
-- | The supported widths of a value, named by their size in bits
-- (see 'widthInBits'). The derived 'Ord' instance follows declaration
-- order, i.e. increasing size.
data Width   = W8 | W16 | W32 | W64
             | W128
             | W256
             | W512
             deriving (Eq, Ord, Show)
-- | A width is printed using its constructor name (see 'mrStr').
instance Outputable Width where
   ppr = ptext . mrStr
-- | The printable name of a width; identical to its constructor name.
mrStr :: Width -> PtrString
mrStr width = case width of
  W8   -> sLit "W8"
  W16  -> sLit "W16"
  W32  -> sLit "W32"
  W64  -> sLit "W64"
  W128 -> sLit "W128"
  W256 -> sLit "W256"
  W512 -> sLit "W512"
-- | The width of a word, determined by 'wORD_SIZE' (in bytes) of the given
-- flags: 4 gives 'W32', 8 gives 'W64'.
--
-- Panics on any other word size. The message names this function and the
-- offending size, in the same style as 'cIntWidth'.
wordWidth :: DynFlags -> Width
wordWidth dflags
 | size == 4 = W32
 | size == 8 = W64
 | otherwise = panic ("wordWidth: Unknown word size: " ++ show size)
 where
   size = wORD_SIZE dflags
-- | The width of half a word, determined by 'wORD_SIZE' (in bytes) of the
-- given flags: 4 gives 'W16', 8 gives 'W32'.
--
-- Panics on any other word size. The message names this function and the
-- offending size, in the same style as 'cIntWidth'.
halfWordWidth :: DynFlags -> Width
halfWordWidth dflags
 | size == 4 = W16
 | size == 8 = W32
 | otherwise = panic ("halfWordWidth: Unknown word size: " ++ show size)
 where
   size = wORD_SIZE dflags
-- | A mask with the low half-word bits set: @0xFFFF@ when 'wORD_SIZE' is 4
-- and @0xFFFFFFFF@ when it is 8.
--
-- Panics on any other word size. The message names this function and the
-- offending size, in the same style as 'cIntWidth'.
halfWordMask :: DynFlags -> Integer
halfWordMask dflags
 | size == 4 = 0xFFFF
 | size == 8 = 0xFFFFFFFF
 | otherwise = panic ("halfWordMask: Unknown word size: " ++ show size)
 where
   size = wORD_SIZE dflags
-- | The width selected by 'cINT_SIZE' (in bytes) of the given flags:
-- 4 gives 'W32', 8 gives 'W64'. Panics on any other size.
cIntWidth :: DynFlags -> Width
cIntWidth dflags
  | size == 4 = W32
  | size == 8 = W64
  | otherwise = panic ("cIntWidth: Unknown cINT_SIZE: " ++ show size)
  where
    size = cINT_SIZE dflags
-- | The size of a width in bits.
widthInBits :: Width -> Int
widthInBits width = case width of
  W8   -> 8
  W16  -> 16
  W32  -> 32
  W64  -> 64
  W128 -> 128
  W256 -> 256
  W512 -> 512
-- | The size of a width in bytes (1 for 'W8' up to 64 for 'W512').
widthInBytes :: Width -> Int
widthInBytes width = widthInBits width `div` 8  -- every width is a whole number of bytes
-- | The inverse of 'widthInBytes': the width occupying the given number of
-- bytes. Panics (reporting the byte count) when no such width exists.
widthFromBytes :: Int -> Width
widthFromBytes nBytes = case nBytes of
  1  -> W8
  2  -> W16
  4  -> W32
  8  -> W64
  16 -> W128
  32 -> W256
  64 -> W512
  _  -> pprPanic "no width for given number of bytes" (ppr nBytes)
-- | The base-2 logarithm of the width in bytes (0 for 'W8', 6 for 'W512').
widthInLog :: Width -> Int
widthInLog width = case width of
  W8   -> 0
  W16  -> 1
  W32  -> 2
  W64  -> 3
  W128 -> 4
  W256 -> 5
  W512 -> 6
-- | Truncate an 'Integer' to the given width and read the result back as an
-- unsigned value, by round-tripping through the fixed-size @WordN@ type.
--
-- Only 'W8' to 'W64' are supported; wider widths panic. The panic names
-- this function and the offending width.
narrowU :: Width -> Integer -> Integer
narrowU W8  x = fromIntegral (fromIntegral x :: Word8)
narrowU W16 x = fromIntegral (fromIntegral x :: Word16)
narrowU W32 x = fromIntegral (fromIntegral x :: Word32)
narrowU W64 x = fromIntegral (fromIntegral x :: Word64)
narrowU w   _ = pprPanic "narrowU: unsupported width" (ppr w)
-- | Truncate an 'Integer' to the given width and read the result back as a
-- signed value, by round-tripping through the fixed-size @IntN@ type.
--
-- Only 'W8' to 'W64' are supported; wider widths panic. The panic names
-- this function and the offending width.
narrowS :: Width -> Integer -> Integer
narrowS W8  x = fromIntegral (fromIntegral x :: Int8)
narrowS W16 x = fromIntegral (fromIntegral x :: Int16)
narrowS W32 x = fromIntegral (fromIntegral x :: Int32)
narrowS W64 x = fromIntegral (fromIntegral x :: Int64)
narrowS w   _ = pprPanic "narrowS: unsupported width" (ppr w)
-- | The number of elements in a vector type (see 'VecCat' and 'vec').
type Length = Int
-- | Build a vector type with the given number of elements of the given
-- element type. The vector's width is the element width in bytes times the
-- element count; 'widthFromBytes' panics if that is not a supported width.
vec :: Length -> CmmType -> CmmType
vec len (CmmType eltCat eltWidth) =
  let totalBytes = len * widthInBytes eltWidth
  in  CmmType (VecCat len eltCat) (widthFromBytes totalBytes)
-- | 'vec' specialised to 2, 4, 8 and 16 elements.
vec2, vec4, vec8, vec16 :: CmmType -> CmmType
vec2  elt = vec 2  elt
vec4  elt = vec 4  elt
vec8  elt = vec 8  elt
vec16 elt = vec 16 elt
-- | Predefined vector types, named @vec<count><element type>@; each of
-- these is 16 bytes wide in total.
vec2f64, vec2b64, vec4f32, vec4b32, vec8b16, vec16b8 :: CmmType
vec2f64 = vec2  f64
vec2b64 = vec2  b64
vec4f32 = vec4  f32
vec4b32 = vec4  b32
vec8b16 = vec8  b16
vec16b8 = vec16 b8
-- | Build a vector type of @n@ elements of the given element type.
-- This is the same construction as 'vec'.
cmmVec :: Int -> CmmType -> CmmType
cmmVec n eltTy = vec n eltTy
-- | The number of elements of a vector type. Panics on a non-vector type.
vecLength :: CmmType -> Length
vecLength ty = case ty of
  CmmType (VecCat len _) _ -> len
  _                        -> panic "vecLength: not a vector"
-- | The element type of a vector type: the element category paired with
-- the vector's width in bytes divided by its element count.
-- Panics on a non-vector type.
vecElemType :: CmmType -> CmmType
vecElemType ty = case ty of
  CmmType (VecCat len eltCat) totalWidth ->
    let eltBytes = widthInBytes totalWidth `div` len
    in  CmmType eltCat (widthFromBytes eltBytes)
  _ -> panic "vecElemType: not a vector"
-- | Holds exactly when the type's category is a 'VecCat'.
isVecType :: CmmType -> Bool
isVecType ty = case ty of
  CmmType (VecCat {}) _ -> True
  _                     -> False
-- | A hint attached to a value passed to or returned from foreign code.
-- NOTE(review): the precise meaning of each hint is not visible in this
-- module; presumably 'AddrHint' marks an address and 'SignedHint' a signed
-- integer, with 'NoHint' as the default -- confirm against the users.
data ForeignHint
  = NoHint | AddrHint | SignedHint
  deriving( Eq )
        
        
-- | The bits type whose size in bytes is the platform constant
-- 'pc_REP_CostCentreStack_mem_alloc'.
rEP_CostCentreStack_mem_alloc :: DynFlags -> CmmType
rEP_CostCentreStack_mem_alloc dflags = cmmBits (widthFromBytes nBytes)
  where
    nBytes = pc_REP_CostCentreStack_mem_alloc (platformConstants dflags)
-- | The bits type whose size in bytes is the platform constant
-- 'pc_REP_CostCentreStack_scc_count'.
rEP_CostCentreStack_scc_count :: DynFlags -> CmmType
rEP_CostCentreStack_scc_count dflags = cmmBits (widthFromBytes nBytes)
  where
    nBytes = pc_REP_CostCentreStack_scc_count (platformConstants dflags)
-- | The bits type whose size in bytes is the platform constant
-- 'pc_REP_StgEntCounter_allocs'.
rEP_StgEntCounter_allocs :: DynFlags -> CmmType
rEP_StgEntCounter_allocs dflags = cmmBits (widthFromBytes nBytes)
  where
    nBytes = pc_REP_StgEntCounter_allocs (platformConstants dflags)
-- | The bits type whose size in bytes is the platform constant
-- 'pc_REP_StgEntCounter_allocd'.
rEP_StgEntCounter_allocd :: DynFlags -> CmmType
rEP_StgEntCounter_allocd dflags = cmmBits (widthFromBytes nBytes)
  where
    nBytes = pc_REP_StgEntCounter_allocd (platformConstants dflags)