{-# LANGUAGE GADTs #-}
module CmmImplementSwitchPlans
  ( cmmImplementSwitchPlans
  )
where

import GhcPrelude

import Hoopl.Block
import BlockId
import Cmm
import CmmUtils
import CmmSwitch
import UniqSupply
import DynFlags

--
-- This module replaces Switch statements as generated by the Stg -> Cmm
-- transformation, which might be huge and sparse and hence unsuitable for
-- assembly code, by proper constructs (if-then-else trees, dense jump tables).
--
-- The actual, abstract strategy is determined by createSwitchPlan in
-- CmmSwitch and returned as a SwitchPlan; this module merely implements that
-- plan in terms of Cmm code. See Note [Cmm Switches, the general plan] in
-- CmmSwitch.
--
-- This division into separate modules serves both to clearly separate concerns
-- and because createSwitchPlan needs access to the constructors of
-- SwitchTargets, a data type that CmmSwitch exports abstractly.
--

-- | Traverses the 'CmmGraph', making sure that every 'CmmSwitch' is suitable
-- for code generation.
--
-- If the target backend handles 'CmmSwitch' itself ('targetSupportsSwitch'),
-- the graph is returned unchanged. Otherwise every block is passed through
-- 'visitSwitches', which replaces a terminating 'CmmSwitch' by the code for
-- its switch plan and may add new blocks to the graph.
cmmImplementSwitchPlans :: DynFlags -> CmmGraph -> UniqSM CmmGraph
cmmImplementSwitchPlans dflags g
  | targetSupportsSwitch (hscTarget dflags) = return g
  | otherwise = do
      blocks' <- concat <$> mapM (visitSwitches dflags) (toBlockList g)
      -- The entry label stays valid: visitSwitches keeps each original
      -- block's entry node and only adds freshly labelled blocks.
      return $ ofBlockList (g_entry g) blocks'

-- | Implements the switch plan of a single block.
--
-- A block whose last node is a 'CmmSwitch' is split into its entry node, its
-- middle part and the switch. The switch is replaced by the code that
-- 'implementSwitchPlan' generates for the plan chosen by 'createSwitchPlan'.
-- The rewritten block comes first in the result, followed by any new blocks
-- the plan required. Any other block is returned unchanged.
visitSwitches :: DynFlags -> CmmBlock -> UniqSM [CmmBlock]
visitSwitches dflags block
  | (entry@(CmmEntry _ scope), middle, CmmSwitch vanillaExpr ids) <- blockSplit block
  = do
      let plan = createSwitchPlan ids

      -- A search-tree plan compares the scrutinee once per IfLT/IfEqual node.
      -- Evaluate it a single time into a local register, so that an
      -- expensive scrutinee (e.g. a memory load or a division) is not
      -- duplicated into every comparison.
      (assignScrut, scrut) <- floatScrutinee vanillaExpr

      (newTail, newBlocks) <- implementSwitchPlan dflags scope scrut plan

      let block' = entry `blockJoinHead` middle
                         `blockAppend` assignScrut
                         `blockAppend` newTail

      return (block' : newBlocks)

  | otherwise
  = return [block]
  where
    -- A register is already cheap to re-read, so it is used as is. Any other
    -- expression is assigned to a fresh local register of the same type, and
    -- that register becomes the new scrutinee.
    floatScrutinee :: CmmExpr -> UniqSM (Block CmmNode O O, CmmExpr)
    floatScrutinee reg@(CmmReg {}) = return (emptyBlock, reg)
    floatScrutinee expr = do
      u <- getUniqueM
      let tmp = LocalReg u (cmmExprType dflags expr)
      return (BMiddle (CmmAssign (CmmLocal tmp) expr), CmmReg (CmmLocal tmp))


-- | Implements a 'SwitchPlan' as Cmm code.
--
-- Returns two things:
--
--   * the tail (a block without an entry node) that replaces the original
--     'CmmSwitch' node;
--   * any extra blocks the plan needed.
--
-- Every extra block gets a fresh label and the given tick scope.
implementSwitchPlan
  :: DynFlags
  -> CmmTickScope  -- ^ tick scope given to every freshly created block
  -> CmmExpr       -- ^ scrutinee; it is copied into every comparison node
  -> SwitchPlan
  -> UniqSM (Block CmmNode O C, [CmmBlock])
implementSwitchPlan dflags scope expr = go
  where
    -- Build the tail implementing a plan, plus the blocks it branches into.
    go :: SwitchPlan -> UniqSM (Block CmmNode O C, [CmmBlock])
    go (Unconditionally l)
      = return (emptyBlock `blockJoinTail` CmmBranch l, [])
    go (JumpTable ids)
      = return (emptyBlock `blockJoinTail` CmmSwitch expr ids, [])
    go (IfLT signed i ids1 ids2)
      = do
          (bid1, newBlocks1) <- go' ids1
          (bid2, newBlocks2) <- go' ids2
          let lt | signed    = cmmSLtWord
                 | otherwise = cmmULtWord
              cond = lt dflags expr (CmmLit (mkWordCLit dflags i))
          -- expr < i selects the first sub-plan, otherwise the second one.
          return (condTail cond bid1 bid2, newBlocks1 ++ newBlocks2)
    go (IfEqual i l ids2)
      = do
          (bid2, newBlocks2) <- go' ids2
          let cond = cmmNeWord dflags expr (CmmLit (mkWordCLit dflags i))
          -- expr /= i continues with the remaining plan; equality jumps to l.
          return (condTail cond bid2 l, newBlocks2)

    -- A tail consisting of a single conditional branch.
    condTail :: CmmExpr -> BlockId -> BlockId -> Block CmmNode O C
    condTail cond onTrue onFalse
      = emptyBlock `blockJoinTail` CmmCondBranch cond onTrue onFalse Nothing

    -- Like go, but returns a label to branch to. An unconditional plan
    -- reuses its target label directly. Any other plan gets a new block with
    -- a fresh label.
    go' :: SwitchPlan -> UniqSM (BlockId, [CmmBlock])
    go' (Unconditionally l)
      = return (l, [])
    go' p
      = do
          bid <- mkBlockId <$> getUniqueM
          (tail', newBlocks) <- go p
          let newBlock = CmmEntry bid scope `blockJoinHead` tail'
          return (bid, newBlock : newBlocks)