{-# LANGUAGE BangPatterns #-}
module CSPM.Evaluator.Profiler (
    ProfilerOptions(..), defaultProfilerOptions,
    ProfilerState, initialProfilerState,
    maybeRegisterCall, registerCall,
    ProfilingData, getProfilingData,
    profilerActive,
) where

import Control.Concurrent.MVar
import Control.Exception (bracket)
import qualified Data.HashTable.IO as H
import Data.List ((\\), isPrefixOf, sort, sortBy, transpose)
import System.IO.Unsafe

import CSPM.DataStructures.Names
import CSPM.Evaluator.Monad
import Util.PrettyPrint

-- | Settings that control how the evaluator's profiler behaves.
data ProfilerOptions = ProfilerOptions {
        -- | Whether calls are recorded at all. 'maybeRegisterCall' only
        -- hands out 'registerCall' when this is 'True'.
        isActive :: Bool,
        -- | When 'True', 'registerCall' does not push a name onto the
        -- profiling stack if that name is already on top of the stack, so a
        -- directly self-recursive call does not deepen the stack.
        flattenRecursiveCalls :: Bool
    }

-- | The default options: profiling switched off, and direct recursion
-- flattened should profiling be switched on.
defaultProfilerOptions :: ProfilerOptions
defaultProfilerOptions =
    ProfilerOptions { flattenRecursiveCalls = True, isActive = False }

-- | A mutable table mapping a name to the number of calls recorded for it.
type CallCountTable = H.BasicHashTable Name Int

-- | The profiler's state: its options, the mutable count tables, and the
-- stack of names currently being profiled.
data ProfilerState = ProfilerState {
        -- | The options this profiler was created with.
        options :: ProfilerOptions,
        -- | Taken around every read and update of the two tables below (see
        -- 'incrementCounter' and 'extractProfilingData').
        tableLock :: MVar (),
        -- | Total number of registered calls per name, regardless of the
        -- stack they were made under.
        callCounts :: CallCountTable,
        -- | For each profiling stack, the counts of the names that were
        -- registered while that stack was the current one.
        hiearchicalCallCounts :: H.BasicHashTable [Name] CallCountTable,
        -- | Names pushed by 'registerCall', most recently pushed first.
        profilingStack :: [Name]
    }

-- | Reports whether the profiler held in the given evaluation state has been
-- switched on through its options.
profilerActive :: EvaluationState -> Bool
profilerActive = isActive . options . profilerState

-- | Builds a fresh profiler state for the given options: both count tables
-- start empty, the lock starts released and the profiling stack is empty.
initialProfilerState :: ProfilerOptions -> IO ProfilerState
initialProfilerState opts = do
    totals <- H.new
    perStack <- H.new
    mutex <- newMVar ()
    return $! ProfilerState {
        profilingStack = [],
        hiearchicalCallCounts = perStack,
        callCounts = totals,
        tableLock = mutex,
        options = opts
    }

-- | Chooses, once, the function used to wrap profiled calls: 'registerCall'
-- when the profiler is active, otherwise a wrapper that ignores the name and
-- returns the program unchanged.
maybeRegisterCall ::
    EvaluationMonad (Name -> EvaluationMonad a -> EvaluationMonad a)
maybeRegisterCall = do
    st <- gets profilerState
    let passThrough _ prog = prog
        wrapper
            | isActive (options st) = registerCall
            | otherwise = passThrough
    -- Force the choice now so the returned function carries no pending test.
    return $! wrapper

-- | Records a call to @n@ and wraps @prog@ with an updated profiling stack.
--
-- The counters are bumped by 'incrementCounter', which is run through
-- 'unsafePerformIO' and forced with 'seq' before the 'modify' expression is
-- evaluated. The counts are attributed to the stack as it was /before/ @n@ is
-- pushed, i.e. to the caller's stack.
--
-- NOTE(review): 'modify' comes from "CSPM.Evaluator.Monad" and is not visible
-- here; presumably it runs @prog@ under the updated evaluation state —
-- confirm against that module.
registerCall :: Name -> EvaluationMonad a -> EvaluationMonad a
registerCall n prog = do
    profilerState <- gets profilerState
    -- The IO action's result is (); 'seq' is used only to make the counter
    -- update happen when this expression is evaluated.
    unsafePerformIO (incrementCounter profilerState n) `seq`
        modify (\ st -> st {
            profilerState = profilerState {
                profilingStack = 
                    let stk = profilingStack profilerState
                    -- With flattening on, a name already on top of the stack
                    -- is not pushed again; otherwise it is always pushed.
                    in if flattenRecursiveCalls (options profilerState) then
                        case stk of
                            n':_ | n == n' -> stk
                            _ -> n : stk
                    else n : stk
            }
        }) prog
{-# NOINLINE registerCall #-}

-- | Adds one to the count stored for @n@ in the given table, starting from
-- one when the name has no entry yet.
incrementCallCounter :: CallCountTable -> Name -> IO ()
incrementCallCounter table n = do
    previous <- H.lookup table n
    -- Evaluate the new count before storing it so the table holds no thunk.
    let !updated = maybe 1 (+ 1) previous
    H.insert table n updated

-- | Records one call to @n@ while holding the table lock: once in the
-- overall table, and once in the table belonging to the current profiling
-- stack (which is created on first use).
incrementCounter :: ProfilerState -> Name -> IO ()
incrementCounter profilerState n =
    bracket (takeMVar mutex) (\ _ -> putMVar mutex ()) (\ _ -> recordCall)
    where
        mutex = tableLock profilerState
        perStackTables = hiearchicalCallCounts profilerState
        currentStack = profilingStack profilerState

        recordCall :: IO ()
        recordCall = do
            incrementCallCounter (callCounts profilerState) n
            existing <- H.lookup perStackTables currentStack
            stackTable <- maybe freshStackTable return existing
            incrementCallCounter stackTable n

        -- Creates and registers the table for a stack seen for the first time.
        freshStackTable :: IO CallCountTable
        freshStackTable = do
            created <- H.new
            H.insert perStackTables currentStack created
            return created
{-# NOINLINE incrementCounter #-}

-- | An immutable list of names paired with their call counts.
type CallCounts = [(Name, Int)]

-- | A tree of call counts. Each entry holds a name, the count recorded for
-- that name at this position in the tree, and the entries for calls made
-- beneath it. 'createHierarchicalCallCount' stores @-1@ as the count when it
-- finds no count for a name at that position.
data HierarchicalCallCount =
    HierarchicalCallCount [(Name, Int, HierarchicalCallCount)]
    deriving Show

-- | Turns the flat per-stack tables into a tree of call counts.
--
-- Each input entry pairs a profiling stack (most recent name first) with the
-- counts of the names that were registered under that stack. Every name that
-- appears in a stack becomes a node whose children are the calls recorded
-- beneath it; a node whose own count was never recorded gets @-1@.
createHierarchicalCallCount :: [([Name], CallCounts)] -> HierarchicalCallCount
createHierarchicalCallCount table = HierarchicalCallCount (build outermostFirst)
    where
        -- Stacks are flipped so the outermost call comes first and then
        -- sorted lexicographically, which places all stacks sharing a
        -- leading name next to each other.
        outermostFirst =
            sortBy (\ (stk1, _) (stk2, _) -> compare stk1 stk2)
                [(reverse stk, cs) | (stk, cs) <- table]

        -- Splits off the leading run of entries whose stacks begin with the
        -- same name as the first entry. Returns that name, the run with the
        -- name stripped from each stack, and the untouched remainder.
        splitLeadingGroup :: [([Name], CallCounts)] ->
            (Name, [([Name], CallCounts)], [([Name], CallCounts)])
        splitLeadingGroup [] = error "invalid list"
        splitLeadingGroup (([], _) : _) = error "invalid list"
        splitLeadingGroup ((hd : tl, cs) : others) =
            (hd, (tl, cs) : stripped, remainder)
            where
                (sameHead, remainder) =
                    span (\ (stk, _) -> [hd] `isPrefixOf` stk) others
                -- Every stack in sameHead starts with hd, so dropping one
                -- element removes exactly that leading name.
                stripped = [(drop 1 stk, cs') | (stk, cs') <- sameHead]

        -- Builds the nodes for one level of the tree.
        build :: [([Name], CallCounts)] -> [(Name, Int, HierarchicalCallCount)]
        build entries = map node groups ++ map leaf childless
            where
                -- Counts recorded directly at this level, i.e. under an
                -- empty remaining stack.
                directCounts = [(n, c) | ([], cs) <- entries, (n, c) <- cs]

                -- Entries that belong to some deeper level.
                deeper = [(stk, cs) | (stk, cs) <- entries, stk /= []]

                -- Partitions the deeper entries by the leading name of
                -- their stack.
                groupByHead :: [([Name], CallCounts)] ->
                    [(Name, [([Name], CallCounts)])]
                groupByHead [] = []
                groupByHead es = (hd, grp) : groupByHead remainder
                    where (hd, grp, remainder) = splitLeadingGroup es

                groups = groupByHead deeper

                -- -1 marks a name for which no count exists at this level.
                countFor n = maybe (-1) id (lookup n directCounts)

                node (n, sub) =
                    (n, countFor n, HierarchicalCallCount (build sub))
                leaf n = node (n, [])

                -- Names counted at this level that never head a deeper stack.
                childless =
                    sort (map fst directCounts) \\ sort (map fst groups)

-- | An immutable snapshot of the profiler's tables, as produced by
-- 'extractProfilingData'.
data ProfilingData = ProfilingData {
        -- | The contents of the overall call-count table.
        totalCallCounts :: CallCounts,
        -- | The per-stack tables arranged into a tree by
        -- 'createHierarchicalCallCount'.
        hierachicalTable :: HierarchicalCallCount
    }

-- | Snapshots the profiler tables held in the evaluation state. The tables
-- are read through 'unsafePerformIO', and the resulting value is forced
-- before it is returned.
getProfilingData :: EvaluationMonad ProfilingData
getProfilingData =
    gets profilerState >>= \ currentState ->
        return $! unsafePerformIO (extractProfilingData currentState)
{-# NOINLINE getProfilingData #-}

-- | Copies both count tables out of the profiler state while holding the
-- table lock, then arranges the per-stack counts into a tree.
extractProfilingData :: ProfilerState -> IO ProfilingData
extractProfilingData profilerState = do
    (totals, perStack) <-
        bracket (takeMVar mutex) (\ _ -> putMVar mutex ()) (\ _ -> snapshot)
    return $! ProfilingData totals (createHierarchicalCallCount perStack)
    where
        mutex = tableLock profilerState

        -- Reads the overall table and every per-stack table into lists.
        snapshot = do
            totals <- H.toList (callCounts profilerState)
            stackTables <- H.toList (hiearchicalCallCounts profilerState)
            perStack <- mapM dumpStackTable stackTables
            return (totals, perStack)

        dumpStackTable (stk, stackTable) = do
            counts <- H.toList stackTable
            return (stk, counts)

-- | Lays out a header and a list of rows as a table of space-separated,
-- left-aligned columns, with a line of dashes under the header.
--
-- Each cell is rendered with 'show' and every column is padded with spaces
-- to the width of its widest cell (the last column included). Rows are
-- expected to have as many cells as the header: 'transpose' does not pad
-- short rows, so a shorter row would shift cells between columns.
--
-- An empty header with no non-empty rows yields an empty document rather
-- than a pattern-match failure.
printInColumns :: [Doc] -> [[Doc]] -> Doc
printInColumns header rows =
    let columns = transpose (map (map show) (header : rows))
        -- The leading 0 keeps this total even for an empty column.
        requiredWidth xs = maximum (0 : map length xs)
        padColumn :: Char -> [String] -> [String]
        padColumn c xs = map (\ x -> x ++ replicate (len - length x) c) xs
            where len = requiredWidth xs
        -- Turn the padded columns back into lines, one space between cells.
        joinColumns :: [[String]] -> [String]
        joinColumns cs = map unwords (transpose cs)
        -- Underline the header line with dashes of the same length.
        insertSpacer :: [String] -> [String]
        insertSpacer [] = []
        insertSpacer (headerLine : rs) =
            headerLine : replicate (length headerLine) '-' : rs
    in vcat (map text (insertSpacer (joinColumns (map (padColumn ' ') columns))))

-- | Renders the overall call counts as a two-column table, with the most
-- frequently called names first.
prettyPrintOverallTable :: [(Name, Int)] -> Doc
prettyPrintOverallTable tbl =
    printInColumns [text "Name", text "Call Count"] rows
    where
        rows = map toRow $! sortBy mostCalledFirst tbl
        toRow (n, c) = [prettyPrint n, int c]
        mostCalledFirst (_, c1) (_, c2) = compare c2 c1

-- | Renders the call-count tree as a two-column table. Each level of the
-- tree is indented two further spaces, and the entries at every level are
-- ordered by descending count.
prettyPrintHierarchicalData :: HierarchicalCallCount -> Doc
prettyPrintHierarchicalData cs =
    printInColumns [text "Name", text "Call Count"] (rowsAt 0 cs)
    where
        -- One row for each entry, immediately followed by the rows of its
        -- children at the next indentation level.
        rowsAt :: Int -> HierarchicalCallCount -> [[Doc]]
        rowsAt indent (HierarchicalCallCount entries) =
            concat
                [ [text (replicate indent ' ') <> prettyPrint name, int count]
                    : rowsAt (indent + 2) children
                | (name, count, children) <- sortBy mostCalledFirst entries
                ]

        mostCalledFirst (_, c1, _) (_, c2, _) = compare c2 c1

-- | Prints the overall call counts followed by the hierarchical call counts,
-- each under an underlined title, with a blank-looking line in between.
instance PrettyPrintable ProfilingData where
    prettyPrint (ProfilingData overall nested) =
        section "Total Function Call Counts" "--------------------------"
            (prettyPrintOverallTable overall)
        $$ text " "
        $$ section "Hierarchical Call Counts" "------------------------"
            (prettyPrintHierarchicalData nested)
        where
            -- A title, its underline, and the indented body beneath them.
            section title rule body =
                text title $$ text rule $$ tabIndent body