-- | Symbolic STG Execution Engine
module SSTG.Core.Execution.Engine
    ( loadState
    , loadStateEntry
    , LoadResult (..)
    , RunFlags (..)
    , StepType (..)
    , execute
    , execute1
    ) where

import Data.List (mapAccumL)

import SSTG.Core.Language

import SSTG.Core.Execution.Stepping
import SSTG.Core.Execution.Support

-- | Outcome of loading a 'Program' into an initial execution 'State'.
data LoadResult
    = LoadOkay State
      -- ^ Exactly one binding group defines the requested entry point.
    | LoadGuess State [Binds]
      -- ^ Several binding groups define the entry point: the state is built
      -- from the first one and the remaining groups are returned as well.
    | LoadError String
      -- ^ No binding group defines the entry point; carries a message.
    deriving (Show, Eq, Read)

-- | Load a 'Program' using @"main"@ as the entry point.
--
-- The name @"main"@ is a guess that matched the entry point of the few
-- experimental programs tried so far; use 'loadStateEntry' to choose a
-- different entry point.
loadState :: Program -> LoadResult
loadState = loadStateEntry defaultEntry
  where
    -- Based on a few experimental programs.
    defaultEntry = "main"

-- | Load a 'Program' using the binding named @entry@ as the entry point.
--
-- Every top-level binding group is allocated on the heap and registered in
-- the globals. The groups that bind a variable whose occurrence name equals
-- @entry@ are then collected:
--
-- * none: 'LoadError';
-- * exactly one: 'LoadOkay' with the loaded state;
-- * several: 'LoadGuess' with the state built from the first group, plus
--   the remaining matching groups.
loadStateEntry :: String -> Program -> LoadResult
loadStateEntry entry (Program bindss) =
    case entryMatches entry binds_locs of
        [] -> LoadError ("No entry candidates found for: [" ++ entry ++ "]")
        (tgt_binds, tgt_loc) : others ->
            let state = buildState tgt_binds tgt_loc
            in if null others
                   then LoadOkay state
                   else LoadGuess state (map fst others)
  where
    -- Reserve heap slots (filled with placeholders) for every top-level
    -- binding group and register the resulting addresses as globals.
    (glist, heap1, binds_addrss) = initGlobals bindss empty_heap
    globals0 = insertGlobalsVals glist empty_globals
    -- Overwrite the placeholders with the real heap objects. Each group also
    -- yields the locals its objects were built against.
    (heap2, localss) = liftBindsAddrss binds_addrss globals0 heap1
    binds_locs = zip bindss localss

    -- Build the initial state around the chosen entry binding group.
    buildState tgt_binds tgt_loc =
        State { state_status = init_status
              , state_stack = empty_stack
              , state_heap = heap
              , state_globals = globals
              , state_code = code
              -- Gather information on all variables.
              , state_names = allNames (Program bindss)
              , state_path = empty_pathcons }
      where
        -- Cannot fail: 'entryMatches' only keeps groups for which
        -- 'lhsMatches' is non-empty.
        ((tgt_var, tgt_rhs) : _) = lhsMatches entry tgt_binds
        -- Code loading; completes heap and globals with symbolic injection.
        (code, globals, heap) =
            loadCode tgt_var tgt_rhs tgt_loc globals0 heap2

-- | Reserve one heap slot per binding in a 'Binds' group.
--
-- Every slot is filled with a 'Blackhole' placeholder. Callers pair the
-- returned addresses with the group's bindings positionally.
allocBinds :: Binds -> Heap -> (Heap, [MemAddr])
allocBinds (Binds _ kvs) heap =
    let placeholders = [Blackhole | _ <- kvs]
        (heapOut, slots) = allocHeapObjs placeholders heap
    in (heapOut, slots)

-- | Reserve heap slots for each 'Binds' group in turn.
--
-- The heap is threaded left to right: each group is allocated on the heap
-- produced by the previous group. Returns the final heap and, for each
-- group in order, the addresses reserved for it.
allocBindss :: [Binds] -> Heap -> (Heap, [[MemAddr]])
allocBindss bindss heap = mapAccumL (flip allocBinds) heap bindss

-- | Pair every variable bound in a 'Binds' group with its reserved heap
-- address, wrapped as a 'MemVal'. Pairing stops at the shorter of the
-- bindings and the addresses.
bindsAddrsToVarVals :: (Binds, [MemAddr]) -> [(Var, Val)]
bindsAddrsToVarVals (Binds _ kvs, addrs) =
    zipWith (\(var, _) addr -> (var, MemVal addr)) kvs addrs

-- | Reserve heap slots for all top-level binding groups.
--
-- Returns three things: the variable/address pairs (which 'loadStateEntry'
-- inserts into the globals), the heap after allocation, and each group
-- paired with the addresses reserved for it.
initGlobals :: [Binds] -> Heap -> ([(Var, Val)], Heap, [(Binds, [MemAddr])])
initGlobals bindss heap =
    let (allocated, addrss) = allocBindss bindss heap
        grouped = zip bindss addrss
    in (concatMap bindsAddrsToVarVals grouped, allocated, grouped)

-- | Resolve an 'Atom' to a 'Val'.
--
-- A literal resolves to itself. A variable is looked up with 'lookupVal' in
-- the given locals and globals. An unbound variable is really an error, but
-- this deliberately returns @LitVal BlankAddr@ instead of crashing.
forceLookupVal :: Atom -> Locals -> Globals -> Val
forceLookupVal atom locals globals = case atom of
    LitAtom lit -> LitVal lit
    VarAtom var -> maybe (LitVal BlankAddr) id (lookupVal var locals globals)

-- | Build the heap object for a binding's right-hand side.
--
-- A function form becomes a 'FunObj' that captures the given locals. A
-- constructor form becomes a 'ConObj' whose arguments are resolved with
-- 'forceLookupVal'. The globals are only consulted for constructor forms.
forceRhsObj :: BindRhs -> Locals -> Globals -> HeapObj
forceRhsObj rhs locals globals = case rhs of
    FunForm prms expr -> FunObj prms expr locals
    ConForm dcon args ->
        ConObj dcon [forceLookupVal arg locals globals | arg <- args]

-- | Fill the heap slots reserved for a 'Binds' group with real objects.
--
-- For a recursive group the objects are built against locals that map each
-- bound variable to its own slot, so the bindings can refer to one another.
-- A non-recursive group uses empty locals. Returns the updated heap and the
-- locals that were used.
liftBindsAddrs :: (Binds, [MemAddr]) -> Globals -> Heap -> (Heap, Locals)
liftBindsAddrs (Binds recFlag kvs, addrs) globals heap =
    (insertHeapObjs (zip addrs objs) heap, scope)
  where
    selfRefs = zipWith (\(var, _) addr -> (var, MemVal addr)) kvs addrs
    scope = case recFlag of
        Rec -> insertLocalsVals selfRefs empty_locals
        NonRec -> empty_locals
    objs = [forceRhsObj rhs scope globals | (_, rhs) <- kvs]

-- | Fill the reserved heap slots of every binding group in turn.
--
-- The heap is threaded left to right through 'liftBindsAddrs' with the same
-- globals for every group. Returns the final heap and, for each group in
-- order, the locals its objects were built against.
liftBindsAddrss :: [(Binds, [MemAddr])] -> Globals -> Heap -> (Heap, [Locals])
liftBindsAddrss binds_addrss globals heap =
    mapAccumL liftOne heap binds_addrss
  where
    liftOne h binds_addrs = liftBindsAddrs binds_addrs globals h

-- | Keep only the binding groups that bind a variable whose occurrence name
-- is @entry@. The original order is preserved.
entryMatches :: String -> [(Binds, Locals)] -> [(Binds, Locals)]
entryMatches entry = filter (isEntryBinds entry)

-- | Whether a binding group binds a variable whose occurrence name is
-- @entry@. The paired locals are ignored.
isEntryBinds :: String -> (Binds, Locals) -> Bool
isEntryBinds entry (binds, _) = not (null (lhsMatches entry binds))

-- | All bindings in a group whose variable's occurrence name equals @st@,
-- in the order they appear in the group.
lhsMatches :: String -> Binds -> [(Var, BindRhs)]
lhsMatches st (Binds _ kvs) =
    [kv | kv@(var, _) <- kvs, nameOccStr (varName var) == st]

-- | Produce the initial 'Code' for the entry binding.
--
-- * A constructor form is evaluated as a plain variable atom.
-- * A function form is applied to symbolic arguments. For each parameter
--   chosen by 'traceArgs', a variable is built from the names returned by
--   'freshSeededNames' (seeded with the parameter names) and the parameter's
--   type. The variable is backed by a 'SymObj' on the heap and bound in the
--   locals before the application is evaluated.
--
-- The globals are returned unchanged in both cases.
loadCode :: Var -> BindRhs -> Locals -> Globals -> Heap -> (Code,Globals,Heap)
loadCode ent rhs locals globals heap = case rhs of
    ConForm _ _ -> (Evaluate (Atom (VarAtom ent)) locals, globals, heap)
    FunForm params expr ->
        let formals = traceArgs params expr locals globals heap
            seeds = map varName formals
            sym_vars = zipWith Var (freshSeededNames seeds seeds)
                                   (map typeOf formals)
            -- One symbolic heap object per argument.
            (heap', slots) =
                allocHeapObjs [SymObj (Symbol v Nothing) | v <- sym_vars] heap
            -- Bind each argument variable to its symbolic heap slot.
            locals' = insertLocalsVals (zip sym_vars (map MemVal slots)) locals
            call = FunApp ent (map VarAtom sym_vars)
        in (Evaluate call locals', globals, heap')

-- | Choose the parameters that symbolic arguments are created for.
--
-- Normally this is @base@, the entry function's own parameters. An entry
-- can instead be compiled as a parameterless binding (a thunk) whose body
-- is just @FunApp var []@. In that case, if @var@ resolves on the heap to a
-- 'FunObj' with at least one parameter, that function's parameters are
-- used instead. Otherwise @base@ is returned unchanged.
traceArgs :: [Var] -> Expr -> Locals -> Globals -> Heap -> [Var]
traceArgs base expr locals globals heap
  | null base
  , FunApp var [] <- expr
  , Just (_, FunObj params _ _) <- vlookupHeap var locals globals heap
  , not (null params) = params
  | otherwise = base

-- | Flags controlling a run of 'execute'.
data RunFlags = RunFlags
    { flag_step_count :: Int
      -- ^ Step bound handed to the selected stepping function.
    , flag_step_type :: StepType
      -- ^ Which stepping strategy 'execute' uses.
    , flag_dump_dir :: Maybe FilePath
      -- ^ Optional dump directory; not consulted by 'execute'.
    } deriving (Show, Eq, Read)

-- | Stepping strategy, selected through 'flag_step_type'.
data StepType
    = BFS        -- ^ Bounded breadth-first stepping; one result.
    | DFS        -- ^ Bounded depth-first stepping; one result.
    | BFSLogged  -- ^ Bounded breadth-first; the stepper's whole result list.
    | DFSLogged  -- ^ Bounded depth-first; the stepper's whole result list.
    deriving (Show, Eq, Read)

-- | Run a 'State' under the given 'RunFlags'.
--
-- The unlogged strategies produce a single result, wrapped in a singleton
-- list. The logged strategies return the stepper's list of results as is.
-- Every strategy is bounded by 'flag_step_count'.
execute :: RunFlags -> State -> [([LiveState], [DeadState])]
execute flags state =
    case flag_step_type flags of
        BFS -> [runBoundedBFS bound state]
        DFS -> [runBoundedDFS bound state]
        BFSLogged -> runBoundedBFSLogged bound state
        DFSLogged -> runBoundedDFSLogged bound state
  where
    bound = flag_step_count flags

-- | Bounded breadth-first execution of a 'State' for @n@ steps.
--
-- When @n < 1@ no stepping happens: the state is returned untouched as the
-- only live state (with an empty first component), and there are no dead
-- states.
execute1 :: Int -> State -> ([LiveState], [DeadState])
execute1 n state =
    if n >= 1
        then runBoundedBFS n state
        else ([([], state)], [])