-- | Atom rule scheduling.
module Language.Atom.Scheduling
  ( schedule
  , Schedule
  , reportSchedule
  ) where

import Data.List
import Language.Atom.Analysis
import Language.Atom.Elaboration
import Text.Printf

-- | A rule schedule: a list of (period, phase, rules) entries.  For each
-- distinct period, 'schedule' splits that period's rules into consecutive
-- chunks and numbers them as phases 0, 1, ... .
type Schedule = [(Int, Int, [Rule])]  -- (period, phase, rules)

-- | Assign every rule a (period, phase) slot.  Rules are grouped by
-- 'rulePeriod'; each group is then cut into equally sized chunks (the last
-- one possibly shorter), and the chunks are numbered as phases.
schedule :: [Rule] -> Schedule
schedule rules = concatMap spread buckets
  where

  -- Group the rules by period.  Buckets appear in order of each period's
  -- first occurrence; because every rule is consed onto the front of its
  -- bucket, a bucket lists its rules in reverse input order.
  buckets :: [(Int, [Rule])]
  buckets = foldl addRule [] rules

  addRule :: [(Int, [Rule])] -> Rule -> [(Int, [Rule])]
  addRule existing rule = place existing
    where
    period = rulePeriod rule
    place [] = [(period, [rule])]
    place (bucket@(p, members) : others)
      | p == period = (p, rule : members) : others
      | otherwise   = bucket : place others

  -- Cut one bucket into chunks of ceiling(length / period) rules and number
  -- them 0, 1, ... as phases (so, for a positive period, there are at most
  -- 'period' phases).
  spread :: (Int, [Rule]) -> Schedule
  spread (period, members) =
    zipWith (\phase chunk -> (period, phase, chunk)) [0 ..] (unconcat chunkSize members)
    where
    chunkSize = case length members `divMod` period of
      (whole, extra) | extra > 0 -> whole + 1
                     | otherwise -> whole

-- | Break a list into consecutive chunks of the given size; the final chunk
-- holds whatever is left over.  The empty list yields no chunks.
unconcat :: Int -> [a] -> [[a]]
unconcat size = go
  where
  go [] = []
  go xs = let (chunk, rest) = splitAt size xs in chunk : go rest

-- | Render a human-readable report for a schedule.  The report has two
-- sections:
--
--   * a table of every scheduled rule with its period, phase and expression
--     count, followed by the total expression count;
--
--   * the hierarchical expression-count tree built by 'usage'.
reportSchedule :: Schedule -> String
reportSchedule sched = concat [timingTable, hierarchy]
  where
  allRules :: [Rule]
  allRules = concat [ rs | (_, _, rs) <- sched ]

  exprTotal :: Int
  exprTotal = sum (map ruleComplexity allRules)

  timingTable :: String
  timingTable = concat
    [ "Rule Scheduling Report\n\n"
    , "Period  Phase  Exprs  Rule\n"
    , "------  -----  -----  ----\n"
    , concatMap reportPeriod sched
    , "               -----\n"
    , printf "               %5i\n" exprTotal
    , "\n"
    ]

  hierarchy :: String
  hierarchy = concat
    [ "Hierarchical Expression Count\n\n"
    , "  Total   Local     Rule\n"
    , "  ------  ------    ----\n"
    , reportUsage "" (usage allRules)
    , "\n"
    ]


-- | Render one schedule entry as one report line per rule: period, phase,
-- the rule's expression count, and the rule itself (via 'show').
reportPeriod :: (Int, Int, [Rule]) -> String
reportPeriod (period, phase, rules) = concat [ line r | r <- rules ]
  where
  line :: Rule -> String
  line r = printf "%6i  %5i  %5i  %s\n" period phase (ruleComplexity r) (show r)


-- | A node of the hierarchical expression-count tree built from dotted rule
-- names.  It holds three fields, in order:
--
--   * the name segment;
--
--   * the node's own (local) expression count;
--
--   * the child nodes.
--
-- 'Ord' is derived so that it agrees with the derived 'Eq'.  Nodes are
-- ordered by name first, which is what the 'sort' in 'insertUsage' needs.
-- The local count and the children only break ties, so @compare x y == EQ@
-- holds exactly when @x == y@.
data Usage = Usage String Int [Usage] deriving (Eq, Ord)

-- | Render a 'Usage' node and, recursively, all of its children.  Each line
-- shows three columns:
--
--   * the node's total count (its own plus all descendants');
--
--   * its local count;
--
--   * its name, prefixed by the given indentation.
--
-- Children are indented two spaces further than their parent.
reportUsage :: String -> Usage -> String
reportUsage indent node@(Usage name local children) =
  thisLine ++ concatMap (reportUsage ("  " ++ indent)) children
  where
  thisLine :: String
  thisLine = printf "  %6i  %6i    %s\n" (totalComplexity node) local (indent ++ name)

-- | A node's own expression count plus the totals of all of its descendants.
totalComplexity :: Usage -> Int
totalComplexity (Usage _ own children) =
  foldr (\child acc -> totalComplexity child + acc) own children

-- | Build the hierarchical expression-count tree for a set of rules.  Each
-- rule's name path (from 'usage'') is merged into a shared forest with
-- 'insertUsage'.
--
-- The result depends on how many roots the merged forest has:
--
--   * Exactly one root (every rule name starts with the same first
--     segment): that root is returned unchanged.
--
--   * Several roots: they are sorted and placed under an unnamed synthetic
--     root (local count 0), so no rule is silently dropped.
--
--   * No rules: the result is an unnamed, empty root rather than a crash in
--     'head'.
usage :: [Rule] -> Usage
usage rules = case foldl' insertUsage [] (map usage' rules) of
  [root] -> root
  roots  -> Usage "" 0 (sort roots)

-- | Turn one rule into a single-path 'Usage' chain following its
-- dot-separated name.  Every enclosing name segment gets a local count of 0,
-- and the last segment carries the rule's 'ruleComplexity'.
usage' :: Rule -> Usage
usage' rule = chain $ split $ ruleName rule
  where
  chain :: [String] -> Usage
  -- An empty rule name splits into no segments.  Report it as a single
  -- unnamed leaf rather than crashing on 'undefined'.
  chain []             = Usage "" (ruleComplexity rule) []
  chain [name]         = Usage name (ruleComplexity rule) []
  chain (name : names) = Usage name 0 [chain names]

-- | Split a name on '.' characters.  The empty string yields no segments.
-- A single trailing dot does not produce a trailing empty segment, but
-- leading and repeated interior dots do produce empty segments.
split :: String -> [String]
split [] = []
split str = case break (== '.') str of
  (segment, [])       -> [segment]
  (segment, _ : more) -> segment : split more

-- | Merge a node into a list of sibling nodes.  If a sibling has the same
-- name, the two are combined:
--
--   * the merged local count is the larger of the two;
--
--   * the children are merged recursively and then sorted.
--
-- Otherwise the node is appended after the existing siblings.
insertUsage :: [Usage] -> Usage -> [Usage]
insertUsage siblings new = go siblings
  where
  -- Lazy pattern binding: the new node is only inspected once a sibling is
  -- compared against it.
  Usage newName newLocal newSubs = new

  go [] = [new]
  go (old@(Usage oldName oldLocal oldSubs) : others)
    | oldName == newName =
        Usage oldName (max oldLocal newLocal) (sort (foldl insertUsage oldSubs newSubs)) : others
    | otherwise = old : go others