module LinearScan.Hoopl.DSL where

import           Compiler.Hoopl as Hoopl hiding ((<*>))
import           Control.Applicative
import           Control.Arrow (first)
import           Control.Monad.Free
import           Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Free as TF
import           Control.Monad.Trans.Free hiding (FreeF(..), Free)
import           Control.Monad.Trans.State (StateT, evalStateT,
                                            gets, modify, get, put)
import qualified Data.Map as M
import           Data.Monoid
import           LinearScan

-- | Maps the textual name of a label to the Hoopl 'Label' allocated for it
--   (see 'getLabel', which fills in a table of this type on demand).
type Labels = M.Map String Label
-- | Associates a Hoopl 'Label' with an 'Int'.  Presumably the 'Int' is the
--   numeric id of the block that the label introduces; nothing in this part
--   of the file populates the map, so confirm against its users.
type BlockIds = M.Map Label Int

-- | Bookkeeping for the stack slots handed out by 'getStackSlot'.
data SpillStack = SpillStack
    { stackPtr      :: Int
      -- ^ Offset that the next newly allocated slot will receive.
    , stackSlotSize :: Int
      -- ^ Amount by which 'stackPtr' advances each time a slot is allocated.
    , stackSlots    :: M.Map (Maybe Int) Int
      -- ^ Offsets already handed out, keyed by the argument that was passed
      --   to 'getStackSlot' when the slot was created.
    }
    deriving (Eq, Show)

-- | The state threaded through the 'Env' monad.
data EnvState = EnvState
    { envLabels      :: Labels
      -- ^ Label names mapped to their Hoopl labels.
    , envBlockIds    :: BlockIds
      -- ^ Per-label integers; see 'BlockIds'.
    , envSpillStack  :: SpillStack
      -- ^ Stack slots allocated so far; read and updated by 'getStackSlot'.
    , envAssignments :: M.Map PhysReg VarId
      -- ^ The variable most recently recorded for each physical register by
      --   'setAssignment'; queried by 'getAssignment'.
    }

-- | A state monad carrying an 'EnvState', layered over Hoopl's
--   'SimpleUniqueMonad'.
type Env = StateT EnvState SimpleUniqueMonad

-- | Create a 'SpillStack' with no slots allocated yet.  The first slot handed
--   out will sit at @offset@, and every later one @slotSize@ further along.
newSpillStack :: Int -> Int -> SpillStack
newSpillStack offset slotSize =
    SpillStack
        { stackSlots    = M.empty
        , stackSlotSize = slotSize
        , stackPtr      = offset
        }

-- | Create an 'EnvState' in which every table is empty and the spill stack
--   is a fresh one built by 'newSpillStack' from the same two arguments.
newEnvState :: Int -> Int -> EnvState
newEnvState offset slotSize =
    EnvState
        { envSpillStack  = newSpillStack offset slotSize
        , envLabels      = M.empty
        , envBlockIds    = M.empty
        , envAssignments = M.empty
        }

-- | Return the stack offset associated with the given key.  If the key has
--   been seen before, the offset recorded for it is returned unchanged;
--   otherwise the current 'stackPtr' is recorded for the key and the pointer
--   is advanced by 'stackSlotSize'.
getStackSlot :: Maybe VarId -> Env Int
getStackSlot vid = do
    st <- get
    let stack = envSpillStack st
        slots = stackSlots stack
        fresh = stackPtr stack
        -- Claim the slot at the current pointer, then bump the pointer.
        allocate = do
            put st { envSpillStack = stack
                       { stackPtr   = fresh + stackSlotSize stack
                       , stackSlots = M.insert vid fresh slots
                       } }
            return fresh
    maybe allocate return (M.lookup vid slots)

-- | The 'Asm' monad lets us create labels by name and refer to them later.
--   Its state is the table of labels created so far (see 'getLabel'), and it
--   sits on top of 'SimpleUniqueMonad', from which fresh labels are drawn.
type Asm = StateT Labels SimpleUniqueMonad

-- | Resolve a label name to its 'Label'.  The first use of a name allocates
--   a fresh label and remembers it; every later use of the same name yields
--   that same label.
getLabel :: String -> Asm Label
getLabel str = do
    known <- gets (M.lookup str)
    maybe create return known
  where
    -- Allocate a new label and register it under the requested name.
    create = do
        fresh <- lift freshLabel
        modify (M.insert str fresh)
        return fresh

-- | Record that the given physical register holds the given variable,
--   replacing whatever was recorded for that register before.
setAssignment :: PhysReg -> VarId -> Env ()
setAssignment reg vid = modify record
  where
    record env =
        env { envAssignments = M.insert reg vid (envAssignments env) }

-- | Look up the variable recorded for a physical register by
--   'setAssignment'.  Calls 'error' if nothing has been recorded for it.
getAssignment :: PhysReg -> Env VarId
getAssignment reg =
    gets (M.lookup reg . envAssignments) >>= maybe unassigned return
  where
    unassigned = error $ "No assignment for register: " ++ show reg

-- | A series of 'Nodes' is a set of assembly instructions that ends with some
--   kind of closing operation, such as a jump, branch or return.
--
--   It is represented as a free monad over pairs: every layer holds one
--   @n O O@ node together with the remainder of the series, and the final
--   'Pure' carries the result of type @a@.
type Nodes n a = Free ((,) (n O O)) a

-- | The 'Nodes' free monad is really just a convenient way to describe a list
--   that must result in a closing operation at the end.  This flattens it
--   into the final result paired with the nodes, in the order they were
--   written.
nodesToList :: Nodes n a -> (a, [n O O])
nodesToList nodes = case nodes of
    Pure result       -> (result, [])
    Free (node, rest) ->
        case nodesToList rest of
            (result, others) -> (result, node : others)

-- | A series of nodes that carries no result; 'bodyNode' builds one holding
--   a single instruction.
type BodyNode n = Nodes n ()

-- | Wrap a single @n O O@ node as a one-instruction 'BodyNode', so that it
--   can be sequenced with other nodes.
bodyNode :: n O O -> BodyNode n
bodyNode n = liftF (n, ())

-- | A series of nodes whose result is an 'Asm' action producing the @n O C@
--   node that closes the series.  The closing node is computed in 'Asm' so
--   that it can refer to labels by name.
type EndNode n = Nodes n (Asm (n O C))

-- | Turn an action that produces a closing node into an 'EndNode' with no
--   instructions in front of it.
endNode :: Asm (n O C) -> EndNode n
endNode closer = Pure closer

-- | One labelled block of a program: the 'Label' that opens the block, and
--   the 'EndNode' giving its instructions and its closing operation.  A
--   program is a series of these.
data ProgramF n = FreeBlock
    { labelEntry :: Label
    , labelBody  :: EndNode n
    }
-- | A program is a sequence of 'ProgramF' blocks, built in a free monad
--   transformer over 'Asm' so that label names can be resolved while the
--   blocks are being listed.
type Program n = FreeT ((,) (ProgramF n)) Asm ()

-- | Add a block to the program.  The name is resolved with 'getLabel'
--   (allocating a label on first use) and the resulting 'Label' is paired
--   with the given body as the next 'FreeBlock'.
label :: String -> EndNode n -> Program n
label str body =
    lift (getLabel str) >>= \entry ->
        liftF (FreeBlock entry body, ())

-- | Close a series of nodes with the node that 'mkBranchNode' builds for the
--   named label.  The name is resolved with 'getLabel', so the destination
--   may be a label that has not been introduced yet.
jump :: HooplNode n => String -> EndNode n
jump = endNode . fmap mkBranchNode . getLabel

-- | When we compile a program, the result is a closed Hoopl 'Graph' and the
--   'Label' corresponding to the requested entry label name.
--
--   Every block is assembled before the entry name is looked up, so the
--   name may be introduced anywhere in the program.  Calls 'error' if the
--   name was never mentioned.
compile :: (NonLocal n, HooplNode n)
        => String -> Program n -> SimpleUniqueMonad (Graph n C C, Label)
compile name prog = evalStateT assemble (mempty :: Labels)
  where
    -- Build the whole body first, then resolve the entry point against the
    -- labels that were registered along the way.
    assemble = do
        body  <- collect prog
        found <- gets (M.lookup name)
        maybe (error $ "Missing label: " ++ name)
              (\entry -> return (bodyGraph body, entry))
              found

    -- Unroll the program one block at a time, adding each compiled block to
    -- the body produced from the blocks that follow it.
    collect m = runFreeT m >>= \step -> case step of
        TF.Pure () -> return emptyBody
        TF.Free (blk, rest) -> do
            block  <- comp blk
            others <- collect rest
            return (addBlock block others)

    -- A block is its label node, its instructions, and the closing node
    -- obtained by running the block's 'Asm' action.
    comp (FreeBlock entry body) =
        fmap (BlockCC (mkLabelNode entry) (blockFromList middles)) closer
      where
        (closer, middles) = nodesToList body