-- | Atom rule scheduling.
module Language.Atom.Scheduling
  ( schedule
  ) where

import Control.Monad
import Data.List
import Language.Atom.Code
import Language.Atom.Elaboration
import System.IO
import Text.Printf

-- | Assign every rule to a cycle within its period, write a report of the
-- resulting schedule to @name ++ ".rpt"@, and return the schedule.
--
-- The result is indexed as @periods !! (p - 1) !! c@: the rules whose
-- 'rulePeriod' is @p@ that are placed in cycle @c@ (@0 <= c < p@).  Every
-- period from 1 up to the largest rule period is present, possibly with
-- empty cycles.  The rules of one period are spread across its cycles as
-- evenly as possible, keeping the order in which they appear in @rules@.
-- An empty rule list yields an empty schedule instead of crashing.
schedule :: Name -> [Rule] -> IO [[[Rule]]]
schedule name rules = do
  writeFile (name ++ ".rpt") $ reportSchedule periods
  return periods
  where

  periods = map spread $ zip [1..] $ map rulesWithPeriod [1 .. maxPeriod]  -- XXX No scheduling done.

  -- The leading 0 keeps this total: with no rules, [1 .. 0] is empty and
  -- no periods are produced.
  maxPeriod = maximum (0 : map rulePeriod rules)

  rulesWithPeriod :: Int -> [Rule]
  rulesWithPeriod p = [ r | r <- rules, rulePeriod r == p ]

  -- Split the rules of one period into exactly `period` consecutive cycles.
  -- Each cycle takes ceil(remaining / cyclesLeft) rules, so the final cycle
  -- (period 1) always takes everything that is left.
  spread :: (Int, [Rule]) -> [[Rule]]
  spread (0, []) = []
  spread (0, _)  = error "Scheduling.spread"  -- unreachable: the last cycle consumes all remaining rules
  spread (period, rs) = take n rs : spread (period - 1, drop n rs)
    where
    n = (length rs + period - 1) `div` period


-- | Render the full text report for a schedule: a table with one line per
-- rule (period, cycle, expression count, rule), followed by the summed
-- expression count and a hierarchical breakdown of expression counts by
-- rule-name component.
reportSchedule :: [[[Rule]]] -> String
reportSchedule s = concat
  [ "Rule Scheduling Report\n\n"
  , "  Period  Cycle  Exprs  Rule\n"
  , "  ------  -----  -----  ----\n"
  , concatMap reportPeriod (zip [1..] s)
  , "                 -----\n"
  , printf "                 %5i\n" totalExprs
  , "\n"
  , "Hierarchical Expression Count\n\n"
  , "  Total   Local     Rule\n"
  , "  ------  ------    ----\n"
  , reportUsage "" (usage allRules)
  , "\n"
  ]
  where
  allRules   = concat (concat s)
  totalExprs = sum (map ruleComplexity allRules)

-- | Render the report lines for every cycle of one period.  The argument
-- pairs the period number with that period's cycles; cycles are numbered
-- from 0 in list order.
reportPeriod :: (Int, [[Rule]]) -> String
reportPeriod (period, cycles) = concatMap (reportCycle period) (zip [0..] cycles)

-- | Render one report line for each rule placed in a given cycle of a
-- period.  The argument pairs the cycle number with the rules in it.
reportCycle :: Int -> (Int, [Rule]) -> String
reportCycle period (cycleNo, rulesInCycle) = concatMap (reportRule period cycleNo) rulesInCycle

-- | Format a single table row: period, cycle, the rule's expression
-- count ('ruleComplexity'), and the rule's 'show' form.
reportRule :: Int -> Int -> Rule -> String
reportRule periodNo cycleNo r =
  printf "  %6i  %5i  %5i  %s\n" periodNo cycleNo exprs (show r)
  where
  exprs = ruleComplexity r


-- | One node of the hierarchical expression-count tree: a single
-- dot-separated component of a rule name, the node's own (local)
-- expression count, and the child nodes beneath it.
data Usage = Usage String Int [Usage] deriving Eq

-- | Orders primarily by name, so sorting a list of siblings sorts it by
-- name.  Ties are broken on the remaining fields so that 'compare' agrees
-- with the derived 'Eq': comparing names alone would make unequal values
-- compare as 'EQ', violating the Eq/Ord consistency law.
instance Ord Usage where
  compare (Usage a i s) (Usage b j t) = compare (a, i, s) (b, j, t)

-- | Render a usage tree as report lines, parent before children.  Each
-- line shows the node's total complexity (itself plus all descendants),
-- its local complexity, and its name prefixed by @indent@; each level of
-- children is indented by two more spaces.
reportUsage :: String -> Usage -> String
reportUsage indent node@(Usage label local children) = line ++ rest
  where
  line :: String
  line = printf "  %6i  %6i    %s\n" (totalComplexity node) local (indent ++ label)
  rest = concatMap (reportUsage ("  " ++ indent)) children

-- | Total complexity of a usage subtree: the node's local complexity plus
-- the totals of all of its children, recursively.
totalComplexity :: Usage -> Int
totalComplexity (Usage _ local children) = foldr addSubtree local children
  where
  addSubtree child acc = totalComplexity child + acc

-- | Merge the name paths of all rules into one hierarchical usage tree.
--
-- Each rule contributes a single-path chain built by 'usage''; chains
-- sharing leading name components are merged by 'insertUsage'.  When all
-- rules share one root component, that root is returned directly.
-- Otherwise the roots are sorted and gathered under a synthetic node with
-- an empty name and zero local complexity, rather than silently dropping
-- all but the first root.  An empty rule list yields such a node with no
-- children instead of crashing.
usage :: [Rule] -> Usage
usage rules = case foldl' insertUsage [] (map usage' rules) of
  [root] -> root
  roots  -> Usage "" 0 (sort roots)

-- | Build the single-path usage chain for one rule.  Each dot-separated
-- component of the rule's name becomes one level of nesting.  Only the
-- last (leaf) component carries the rule's complexity; the intermediate
-- levels have a local complexity of 0.
usage' :: Rule -> Usage
usage' rule = f $ split $ ruleName rule
  where
  f :: [String] -> Usage
  -- Only reachable when the rule name is "" (split "" == []); the
  -- recursive case below never passes an empty list.
  f [] = error "Scheduling.usage': rule has an empty name"
  f [name] = Usage name (ruleComplexity rule) []
  f (name:names) = Usage name 0 [f names]

-- | Split a hierarchical name on '.' into its components.
-- An empty string yields no components.  A single trailing '.' is
-- dropped, while a leading or doubled '.' yields an empty component,
-- e.g. @split "a..b" == ["a", "", "b"]@.
split :: String -> [String]
split = unfoldr step
  where
  step "" = Nothing
  step xs = let (seg, rest) = break (== '.') xs in Just (seg, drop 1 rest)

-- | Insert one usage tree into a list of sibling trees.
--
-- If a sibling with the same name already exists, the two are merged:
-- * the local complexity becomes the larger of the two;
-- * the inserted tree's children are recursively inserted into the
--   existing children, which are then sorted.
-- Otherwise the tree is appended to the end of the list.
insertUsage :: [Usage] -> Usage -> [Usage]
insertUsage [] u = [u]
insertUsage (a@(Usage n1 i1 s1) : rest) b@(Usage n2 i2 s2)
  | n1 == n2  = Usage n1 (max i1 i2) (sort $ foldl' insertUsage s1 s2) : rest
  | otherwise = a : insertUsage rest b