{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
module CmmExpr
    ( CmmExpr(..), cmmExprType, cmmExprWidth, cmmExprAlignment, maybeInvertCmmExpr
    , CmmReg(..), cmmRegType, cmmRegWidth
    , CmmLit(..), cmmLitType
    , LocalReg(..), localRegType
    , GlobalReg(..), isArgReg, globalRegType
    , spReg, hpReg, spLimReg, hpLimReg, nodeReg
    , currentTSOReg, currentNurseryReg, hpAllocReg, cccsReg
    , node, baseReg
    , VGcPtr(..)
    , DefinerOfRegs, UserOfRegs
    , foldRegsDefd, foldRegsUsed
    , foldLocalRegsDefd, foldLocalRegsUsed
    , RegSet, LocalRegSet, GlobalRegSet
    , emptyRegSet, elemRegSet, extendRegSet, deleteFromRegSet, mkRegSet
    , plusRegSet, minusRegSet, timesRegSet, sizeRegSet, nullRegSet
    , regSetToList
    , Area(..)
    , module CmmMachOp
    , module CmmType
    )
where
import GhcPrelude
import BlockId
import CLabel
import CmmMachOp
import CmmType
import DynFlags
import Outputable (panic)
import Unique
import Data.Set (Set)
import qualified Data.Set as Set
import BasicTypes (Alignment, mkAlignment, alignmentOf)
-- | A Cmm expression.
--
-- NOTE(review): the constructor descriptions below are derived from how this
-- module consumes them ('cmmExprType', the 'UserOfRegs' instance); confirm
-- the exact semantics against the code generator.
data CmmExpr
  = CmmLit CmmLit               -- a literal value
  | CmmLoad !CmmExpr !CmmType   -- load from an address expression; the
                                -- CmmType is the type of the loaded value
  | CmmReg !CmmReg              -- the contents of a register
  | CmmMachOp MachOp [CmmExpr]  -- a machine operation applied to arguments
  | CmmStackSlot Area {-# UNPACK #-} !Int
                                -- a stack slot given by an Area and an Int
                                -- (presumably an offset — TODO confirm);
                                -- it has word type and uses no registers
  | CmmRegOff !CmmReg Int
        -- A register together with an Int offset; presumably shorthand for
        -- "register + offset" — TODO confirm. Its type is the register's
        -- type (see 'cmmExprType').
-- | Structural equality on expressions.
--
-- Note: two 'CmmLoad's are equal when their address expressions are equal;
-- the loaded 'CmmType' is not compared. Expressions built with different
-- constructors are never equal.
instance Eq CmmExpr where
  lhs == rhs = case (lhs, rhs) of
    (CmmLit a,         CmmLit b)         -> a == b
    (CmmLoad a _,      CmmLoad b _)      -> a == b
    (CmmReg a,         CmmReg b)         -> a == b
    (CmmRegOff a i,    CmmRegOff b j)    -> a == b && i == j
    (CmmMachOp oa xs,  CmmMachOp ob ys)  -> oa == ob && xs == ys
    (CmmStackSlot a i, CmmStackSlot b j) -> a == b && i == j
    _                                    -> False
-- | A register: either a function-local 'LocalReg' or a 'GlobalReg'.
-- The derived 'Ord' places every 'CmmLocal' before every 'CmmGlobal'.
data CmmReg
  = CmmLocal  {-# UNPACK #-} !LocalReg
  | CmmGlobal GlobalReg
  deriving( Eq, Ord )
-- | A stack area, as referenced by 'CmmStackSlot'.
--
-- NOTE(review): what 'Old' and 'Young' areas denote is defined by the stack
-- layout code, not in this module — confirm there.
data Area
  = Old
  | Young {-# UNPACK #-} !BlockId
                   -- an area identified by a BlockId
  deriving (Eq, Ord)
-- | Cmm literals. The type of each literal is computed by 'cmmLitType'.
data CmmLit
  = CmmInt !Integer  Width
        -- An integer literal; its type is 'cmmBits' of the given Width.
        -- NOTE(review): how values that do not fit in the Width are
        -- interpreted is not established here — confirm with producers.
  | CmmFloat  Rational Width
        -- A floating-point literal; its type is 'cmmFloat' of the Width.
  | CmmVec [CmmLit]
        -- A vector of literals. 'cmmLitType' panics if the list is empty or
        -- if the elements do not all have the same type.
  | CmmLabel    CLabel
        -- A label (presumably its address). Word-typed: a GC-pointer word
        -- when 'isGcPtrLabel' holds for the label, a plain word otherwise.
  | CmmLabelOff CLabel Int
        -- A label with an Int offset (presumably label + offset — TODO
        -- confirm); typed exactly like 'CmmLabel'.
  | CmmLabelDiffOff CLabel CLabel Int Width
        -- Two labels, an Int offset and a Width (presumably the difference
        -- of the labels plus the offset — TODO confirm). Its type is
        -- 'cmmBits' of the Width.
  | CmmBlock {-# UNPACK #-} !BlockId
        -- A reference to a block, typed as a plain word.
  | CmmHighStackMark
        -- A word-typed constant. NOTE(review): its value is not determined
        -- in this module — confirm where it is resolved.
  deriving Eq
-- | The 'CmmType' of an expression.
--
-- Literals and registers report their own type. A load reports the type
-- recorded in it. A machine operation reports its result type for the types
-- of its arguments. A register-with-offset has the register's type, and a
-- stack slot is a plain word.
cmmExprType :: DynFlags -> CmmExpr -> CmmType
cmmExprType dflags expr = case expr of
  CmmLit lit        -> cmmLitType dflags lit
  CmmLoad _ loadTy  -> loadTy
  CmmReg reg        -> cmmRegType dflags reg
  CmmMachOp op args ->
    machOpResultType dflags op (map (cmmExprType dflags) args)
  CmmRegOff reg _   -> cmmRegType dflags reg
  CmmStackSlot _ _  -> bWord dflags
-- | The 'CmmType' of a literal.
--
-- * Integer and float literals take their type from their 'Width'.
-- * Labels (with or without offset) are a GC-pointer word when
--   'isGcPtrLabel' holds, otherwise a plain word.
-- * Label differences are 'cmmBits' of their 'Width'.
-- * Block references and 'CmmHighStackMark' are plain words.
--
-- Panics on an empty 'CmmVec', and on a 'CmmVec' whose elements do not all
-- share the type of the first element.
cmmLitType :: DynFlags -> CmmLit -> CmmType
cmmLitType _      (CmmInt _ width)     = cmmBits  width
cmmLitType _      (CmmFloat _ width)   = cmmFloat width
cmmLitType _      (CmmVec [])          = panic "cmmLitType: CmmVec []"
cmmLitType dflags (CmmVec (l:ls))
  -- Every element must have the same type as the first; the vector's type
  -- is that element type, once per element.
  | all ((`cmmEqType` ty) . cmmLitType dflags) ls = cmmVec (1 + length ls) ty
  | otherwise                                     = panic "cmmLitType: CmmVec"
  where
    ty = cmmLitType dflags l
cmmLitType dflags (CmmLabel lbl)       = cmmLabelType dflags lbl
cmmLitType dflags (CmmLabelOff lbl _)  = cmmLabelType dflags lbl
cmmLitType _      (CmmLabelDiffOff _ _ _ width) = cmmBits width
cmmLitType dflags (CmmBlock _)         = bWord dflags
cmmLitType dflags (CmmHighStackMark)   = bWord dflags
-- | The type of a label literal: a GC-pointer word when 'isGcPtrLabel'
-- holds for the label, otherwise a plain word.
cmmLabelType :: DynFlags -> CLabel -> CmmType
cmmLabelType dflags lbl =
  if isGcPtrLabel lbl then gcWord dflags else bWord dflags
-- | The 'Width' of an expression's type (see 'cmmExprType').
cmmExprWidth :: DynFlags -> CmmExpr -> Width
cmmExprWidth dflags = typeWidth . cmmExprType dflags
-- | The alignment implied by an expression used as an offset.
--
-- An integer literal gets the alignment 'alignmentOf' derives from its
-- value. Any other expression is only known to be 1-aligned.
cmmExprAlignment :: CmmExpr -> Alignment
cmmExprAlignment expr = case expr of
  CmmLit (CmmInt off _) -> alignmentOf (fromInteger off)
  _                     -> mkAlignment 1
-- | Invert a comparison expression.
--
-- For a 'CmmMachOp', replaces the operator with the one returned by
-- 'maybeInvertComparison' and keeps the arguments. Returns 'Nothing' when
-- 'maybeInvertComparison' does, or when the expression is not a 'CmmMachOp'.
maybeInvertCmmExpr :: CmmExpr -> Maybe CmmExpr
maybeInvertCmmExpr expr = case expr of
  CmmMachOp op args ->
    (\inverted -> CmmMachOp inverted args) <$> maybeInvertComparison op
  _ -> Nothing
-- | A local register: a 'Unique' name plus the 'CmmType' of the value it
-- holds ('localRegType' returns that type). The 'Eq' and 'Ord' instances
-- in this module compare only the 'Unique'.
data LocalReg
  = LocalReg {-# UNPACK #-} !Unique CmmType
-- Local registers are identified purely by their 'Unique'; the 'CmmType'
-- component takes no part in equality or ordering.
instance Eq LocalReg where
  r1 == r2 = getUnique r1 == getUnique r2

-- NOTE(review): ordering goes through 'nonDetCmpUnique', so (per its name)
-- it need not be deterministic across compiler runs — confirm before
-- relying on the iteration order of a 'LocalRegSet'.
instance Ord LocalReg where
  compare r1 r2 = nonDetCmpUnique (getUnique r1) (getUnique r2)

instance Uniquable LocalReg where
  getUnique reg = case reg of
    LocalReg uniq _ -> uniq
-- | The type of a register: a local register's recorded type, or the fixed
-- type 'globalRegType' assigns to a global register.
cmmRegType :: DynFlags -> CmmReg -> CmmType
cmmRegType dflags reg = case reg of
  CmmLocal  local  -> localRegType local
  CmmGlobal global -> globalRegType dflags global
-- | The 'Width' of a register's type (see 'cmmRegType').
cmmRegWidth :: DynFlags -> CmmReg -> Width
cmmRegWidth dflags reg = typeWidth (cmmRegType dflags reg)
-- | The 'CmmType' stored in a local register.
localRegType :: LocalReg -> CmmType
localRegType reg = case reg of
  LocalReg _ ty -> ty
-- | Register sets are plain 'Set's. The operations below are thin wrappers
-- over "Data.Set". 'extendRegSet' and 'deleteFromRegSet' take the set
-- first and the register second.
type RegSet r     = Set r
type LocalRegSet  = RegSet LocalReg
type GlobalRegSet = RegSet GlobalReg

emptyRegSet :: RegSet r
emptyRegSet = Set.empty

nullRegSet :: RegSet r -> Bool
nullRegSet = Set.null

elemRegSet :: Ord r => r -> RegSet r -> Bool
elemRegSet = Set.member

extendRegSet :: Ord r => RegSet r -> r -> RegSet r
extendRegSet set reg = Set.insert reg set

deleteFromRegSet :: Ord r => RegSet r -> r -> RegSet r
deleteFromRegSet set reg = Set.delete reg set

mkRegSet :: Ord r => [r] -> RegSet r
mkRegSet = Set.fromList

-- | Set difference, union and intersection respectively.
minusRegSet, plusRegSet, timesRegSet :: Ord r => RegSet r -> RegSet r -> RegSet r
minusRegSet = Set.difference
plusRegSet  = Set.union
timesRegSet = Set.intersection

sizeRegSet :: RegSet r -> Int
sizeRegSet = Set.size

-- | Elements in ascending order (as 'Set.toList' produces them).
regSetToList :: RegSet r -> [r]
regSetToList = Set.toList
-- | Things of type @a@ that mention (use) registers of type @r@.
-- @foldRegsUsed dflags f z x@ folds @f@ over those registers, starting
-- from the accumulator @z@.
class Ord r => UserOfRegs r a where
  foldRegsUsed :: DynFlags -> (b -> r -> b) -> b -> a -> b
-- | 'foldRegsUsed' restricted to 'LocalReg's.
foldLocalRegsUsed :: UserOfRegs LocalReg a
                  => DynFlags -> (b -> LocalReg -> b) -> b -> a -> b
foldLocalRegsUsed dflags step acc thing = foldRegsUsed dflags step acc thing
-- | Things of type @a@ that define (assign) registers of type @r@.
-- @foldRegsDefd dflags f z x@ folds @f@ over those registers, starting
-- from the accumulator @z@.
class Ord r => DefinerOfRegs r a where
  foldRegsDefd :: DynFlags -> (b -> r -> b) -> b -> a -> b
-- | 'foldRegsDefd' restricted to 'LocalReg's.
foldLocalRegsDefd :: DefinerOfRegs LocalReg a
                  => DynFlags -> (b -> LocalReg -> b) -> b -> a -> b
foldLocalRegsDefd dflags step acc thing = foldRegsDefd dflags step acc thing
-- A 'CmmReg' uses its local register, if it is one; a global register
-- contributes nothing to a fold over local registers.
instance UserOfRegs LocalReg CmmReg where
    foldRegsUsed _ step acc reg = case reg of
      CmmLocal local -> step acc local
      CmmGlobal _    -> acc
-- A 'CmmReg' defines its local register, if it is one; a global register
-- contributes nothing to a fold over local registers.
instance DefinerOfRegs LocalReg CmmReg where
    foldRegsDefd _ step acc reg = case reg of
      CmmLocal local -> step acc local
      CmmGlobal _    -> acc
-- A 'CmmReg' uses its global register, if it is one; a local register
-- contributes nothing to a fold over global registers.
instance UserOfRegs GlobalReg CmmReg where
    foldRegsUsed _ step acc reg = case reg of
      CmmLocal _       -> acc
      CmmGlobal global -> step acc global
-- A 'CmmReg' defines its global register, if it is one; a local register
-- contributes nothing to a fold over global registers.
instance DefinerOfRegs GlobalReg CmmReg where
    foldRegsDefd _ step acc reg = case reg of
      CmmLocal _       -> acc
      CmmGlobal global -> step acc global
-- A register trivially uses itself: the fold visits exactly that register.
instance Ord r => UserOfRegs r r where
    foldRegsUsed _ step acc reg = acc `step` reg
-- A register trivially defines itself: the fold visits exactly that register.
instance Ord r => DefinerOfRegs r r where
    foldRegsDefd _ step acc reg = acc `step` reg
-- | Registers used by an expression: those named directly ('CmmReg',
-- 'CmmRegOff') plus those used by a load's address and by machine-op
-- arguments. Literals and stack slots use no registers.
--
-- NOTE(review): the explicit @Ord r@ in the context is also implied by the
-- superclass of @UserOfRegs r CmmReg@; it is kept as in the original —
-- confirm why before removing it.
instance (Ord r, UserOfRegs r CmmReg) => UserOfRegs r CmmExpr where
  -- The accumulator is forced on entry (bang pattern).
  foldRegsUsed dflags f !acc expr = case expr of
    CmmLit _         -> acc
    CmmLoad addr _   -> foldRegsUsed dflags f acc addr
    CmmReg reg       -> foldRegsUsed dflags f acc reg
    CmmMachOp _ args -> foldRegsUsed dflags f acc args
    CmmRegOff reg _  -> foldRegsUsed dflags f acc reg
    CmmStackSlot _ _ -> acc
-- | A list uses every register used by any of its elements. Elements are
-- visited left to right with a strict accumulator ('foldl'').
instance UserOfRegs r a => UserOfRegs r [a] where
  foldRegsUsed dflags f acc xs = foldl' step acc xs
    where step = foldRegsUsed dflags f
  {-# INLINABLE foldRegsUsed #-}
-- | A list defines every register defined by any of its elements. Elements
-- are visited left to right with a strict accumulator ('foldl'').
instance DefinerOfRegs r a => DefinerOfRegs r [a] where
  foldRegsDefd dflags f acc xs = foldl' step acc xs
    where step = foldRegsDefd dflags f
  {-# INLINABLE foldRegsDefd #-}
-- | Flag on a 'VanillaReg': 'globalRegType' gives a 'VGcPtr' register the
-- GC-pointer word type and a 'VNonGcPtr' register the plain word type.
data VGcPtr = VGcPtr | VNonGcPtr deriving( Eq, Show )
-- | Global registers.
--
-- The indexed registers ('VanillaReg' through 'ZmmReg') carry an 'Int'
-- index, and 'isArgReg' returns 'True' exactly for them. The remaining
-- constructors are single named registers. 'globalRegType' gives each
-- register's 'CmmType'. The hand-written 'Eq' and 'Ord' instances in this
-- module ignore a 'VanillaReg''s 'VGcPtr' flag.
data GlobalReg
  -- Indexed registers.
  = VanillaReg                  -- word-sized; type chosen by the VGcPtr flag
        {-# UNPACK #-} !Int     -- index
        VGcPtr
  | FloatReg            -- type: cmmFloat W32
        {-# UNPACK #-} !Int     -- index
  | DoubleReg           -- type: cmmFloat W64
        {-# UNPACK #-} !Int     -- index
  | LongReg             -- type: cmmBits W64
        {-# UNPACK #-} !Int     -- index
  | XmmReg                      -- type: vector of 4 x 32 bits
        {-# UNPACK #-} !Int     -- index
  | YmmReg                      -- type: vector of 8 x 32 bits
        {-# UNPACK #-} !Int     -- index
  | ZmmReg                      -- type: vector of 16 x 32 bits
        {-# UNPACK #-} !Int     -- index

  -- Named registers. 'globalRegType' types Hp as a GC-pointer word and all
  -- other named registers as plain words. Several of them have CmmReg
  -- shorthands (spReg, hpReg, spLimReg, hpLimReg, cccsReg, ...).
  | Sp
  | SpLim
  | Hp
  | HpLim
  | CCCS
  | CurrentTSO
  | CurrentNursery
  | HpAlloc
  | EagerBlackholeInfo
  | GCEnter1
  | GCFun
  | BaseReg
  | MachSp
  | UnwindReturnReg
  | PicBaseReg
  deriving( Show )
-- | Two global registers are equal when they use the same constructor and,
-- for the indexed constructors, carry the same index. The 'VGcPtr' flag of
-- a 'VanillaReg' is not compared.
instance Eq GlobalReg where
   r1 == r2 = case (r1, r2) of
     (VanillaReg i _, VanillaReg j _)         -> i == j
     (FloatReg i, FloatReg j)                 -> i == j
     (DoubleReg i, DoubleReg j)               -> i == j
     (LongReg i, LongReg j)                   -> i == j
     (XmmReg i, XmmReg j)                     -> i == j
     (YmmReg i, YmmReg j)                     -> i == j
     (ZmmReg i, ZmmReg j)                     -> i == j
     (Sp, Sp)                                 -> True
     (SpLim, SpLim)                           -> True
     (Hp, Hp)                                 -> True
     (HpLim, HpLim)                           -> True
     (CCCS, CCCS)                             -> True
     (CurrentTSO, CurrentTSO)                 -> True
     (CurrentNursery, CurrentNursery)         -> True
     (HpAlloc, HpAlloc)                       -> True
     (EagerBlackholeInfo, EagerBlackholeInfo) -> True
     (GCEnter1, GCEnter1)                     -> True
     (GCFun, GCFun)                           -> True
     (BaseReg, BaseReg)                       -> True
     (MachSp, MachSp)                         -> True
     (UnwindReturnReg, UnwindReturnReg)       -> True
     (PicBaseReg, PicBaseReg)                 -> True
     _                                        -> False
-- | A total order on global registers.
--
-- Registers built with the same indexed constructor compare by index. As
-- in 'Eq', the 'VGcPtr' flag of a 'VanillaReg' is ignored. Registers built
-- with different constructors are ordered by the sequence of fallback
-- clauses below, which yields:
--
--   VanillaReg < FloatReg < DoubleReg < LongReg < XmmReg < YmmReg < ZmmReg
--   < Sp < SpLim < Hp < HpLim < CCCS < CurrentTSO < CurrentNursery
--   < HpAlloc < GCEnter1 < GCFun < BaseReg < MachSp < UnwindReturnReg
--   < EagerBlackholeInfo < PicBaseReg
--
-- This is not the declaration order: 'EagerBlackholeInfo' is placed near
-- the end. 'PicBaseReg' has no fallback pair of its own. It is the greatest
-- register because every other constructor is dispatched by an earlier
-- @compare X _ = LT@ / @compare _ X = GT@ pair.
--
-- Clause order is significant. The same-constructor cases must precede the
-- fallbacks. A newly added constructor needs its own EQ clause and fallback
-- pair, otherwise some comparisons fail with a pattern-match error.
instance Ord GlobalReg where
   -- Same indexed constructor: compare indices.
   compare (VanillaReg i _) (VanillaReg j _) = compare i j
     -- (the VGcPtr flag is deliberately not compared)
   compare (FloatReg i)  (FloatReg  j) = compare i j
   compare (DoubleReg i) (DoubleReg j) = compare i j
   compare (LongReg i)   (LongReg   j) = compare i j
   compare (XmmReg i)    (XmmReg    j) = compare i j
   compare (YmmReg i)    (YmmReg    j) = compare i j
   compare (ZmmReg i)    (ZmmReg    j) = compare i j
   -- Same named register.
   compare Sp Sp = EQ
   compare SpLim SpLim = EQ
   compare Hp Hp = EQ
   compare HpLim HpLim = EQ
   compare CCCS CCCS = EQ
   compare CurrentTSO CurrentTSO = EQ
   compare CurrentNursery CurrentNursery = EQ
   compare HpAlloc HpAlloc = EQ
   compare EagerBlackholeInfo EagerBlackholeInfo = EQ
   compare GCEnter1 GCEnter1 = EQ
   compare GCFun GCFun = EQ
   compare BaseReg BaseReg = EQ
   compare MachSp MachSp = EQ
   compare UnwindReturnReg UnwindReturnReg = EQ
   compare PicBaseReg PicBaseReg = EQ
   -- Different constructors: the first constructor matched by a pair below
   -- is the smaller one. The pairs' order defines the overall ranking.
   compare (VanillaReg _ _) _ = LT
   compare _ (VanillaReg _ _) = GT
   compare (FloatReg _) _     = LT
   compare _ (FloatReg _)     = GT
   compare (DoubleReg _) _    = LT
   compare _ (DoubleReg _)    = GT
   compare (LongReg _) _      = LT
   compare _ (LongReg _)      = GT
   compare (XmmReg _) _       = LT
   compare _ (XmmReg _)       = GT
   compare (YmmReg _) _       = LT
   compare _ (YmmReg _)       = GT
   compare (ZmmReg _) _       = LT
   compare _ (ZmmReg _)       = GT
   compare Sp _ = LT
   compare _ Sp = GT
   compare SpLim _ = LT
   compare _ SpLim = GT
   compare Hp _ = LT
   compare _ Hp = GT
   compare HpLim _ = LT
   compare _ HpLim = GT
   compare CCCS _ = LT
   compare _ CCCS = GT
   compare CurrentTSO _ = LT
   compare _ CurrentTSO = GT
   compare CurrentNursery _ = LT
   compare _ CurrentNursery = GT
   compare HpAlloc _ = LT
   compare _ HpAlloc = GT
   compare GCEnter1 _ = LT
   compare _ GCEnter1 = GT
   compare GCFun _ = LT
   compare _ GCFun = GT
   compare BaseReg _ = LT
   compare _ BaseReg = GT
   compare MachSp _ = LT
   compare _ MachSp = GT
   compare UnwindReturnReg _ = LT
   compare _ UnwindReturnReg = GT
   compare EagerBlackholeInfo _ = LT
   compare _ EagerBlackholeInfo = GT
-- | 'CmmReg' shorthands for frequently referenced global registers.
baseReg, spReg, hpReg, spLimReg, hpLimReg, nodeReg,
  currentTSOReg, currentNurseryReg, hpAllocReg, cccsReg  :: CmmReg
baseReg           = CmmGlobal BaseReg
spReg             = CmmGlobal Sp
hpReg             = CmmGlobal Hp
spLimReg          = CmmGlobal SpLim
hpLimReg          = CmmGlobal HpLim
nodeReg           = CmmGlobal node
currentTSOReg     = CmmGlobal CurrentTSO
currentNurseryReg = CmmGlobal CurrentNursery
hpAllocReg        = CmmGlobal HpAlloc
cccsReg           = CmmGlobal CCCS

-- | The global register behind 'nodeReg': vanilla register 1, flagged as
-- holding a GC pointer.
node :: GlobalReg
node = VanillaReg (1 :: Int) VGcPtr
-- | The 'CmmType' of each global register.
--
-- A vanilla register's type follows its 'VGcPtr' flag. Float, double, long
-- and vector registers have fixed widths. 'Hp' is typed as a GC-pointer
-- word, and every other named register as a plain word.
globalRegType :: DynFlags -> GlobalReg -> CmmType
globalRegType dflags reg = case reg of
  VanillaReg _ VGcPtr    -> gcWord dflags
  VanillaReg _ VNonGcPtr -> bWord dflags
  FloatReg _             -> cmmFloat W32
  DoubleReg _            -> cmmFloat W64
  LongReg _              -> cmmBits W64
  XmmReg _               -> cmmVec 4 (cmmBits W32)
  YmmReg _               -> cmmVec 8 (cmmBits W32)
  ZmmReg _               -> cmmVec 16 (cmmBits W32)
  Hp                     -> gcWord dflags
  _                      -> bWord dflags
-- | 'True' for the indexed register families (vanilla, float, double,
-- long and the three vector widths); 'False' for every named register.
isArgReg :: GlobalReg -> Bool
isArgReg reg = case reg of
  VanillaReg {} -> True
  FloatReg {}   -> True
  DoubleReg {}  -> True
  LongReg {}    -> True
  XmmReg {}     -> True
  YmmReg {}     -> True
  ZmmReg {}     -> True
  _             -> False