-- | Atom code generation.
module Language.Atom.Code
  ( writeC
  , ruleComplexity
  ) where

import Data.Char
import Data.List
import Data.Maybe
import System.IO

import Language.Atom.Elaboration
import Language.Atom.Expressions

-- | Emits the C definition of a local state variable, initialised to its
-- declared initial value and followed by the variable's source name in a
-- comment. External variables produce no definition here.
declareMemory :: UV -> String
declareMemory (UV _ _ (External _)) = ""
declareMemory (UV i name (Local k)) =
  concat [cType (constType k), " ", e i, " = ", initial, ";  /* ", name, " */\n"]
  where
  -- C literal for the initial value; booleans are written as 0/1.
  initial = case k of
    CBool   b -> if b then "1" else "0"
    CInt8   x -> show x
    CInt16  x -> show x
    CInt32  x -> show x
    CInt64  x -> show x
    CWord8  x -> show x
    CWord16 x -> show x
    CWord32 x -> show x
    CWord64 x -> show x
    CFloat  x -> show x
    CDouble x -> show x

-- | Emits the C declaration of the local variable holding one subexpression.
--
-- The first argument is the indentation prefix; the pair is the expression
-- and its C name. Variable references need no declaration. Constants become
-- initialised @const@ locals: booleans as 0/1, other constants through their
-- 'Show' instance (NOTE(review): assumes that instance yields a valid C
-- literal). Any other expression is declared uninitialised.
declareUE :: String -> (UE, String) -> String
declareUE indent (ue, name) = case ue of
  UVRef _  -> ""
  UConst k -> indent ++ "const " ++ decl ++ " = " ++ literal k ++ ";\n"
  _        -> indent ++ decl ++ ";\n"
  where
  decl = cType (ueType ue) ++ " " ++ name
  literal (CBool True)  = "1"
  literal (CBool False) = "0"
  literal k             = show k

-- | The C type used to represent each Atom type.
--
-- NOTE(review): 'Int32' and 'Word32' map to @long@, which is 32 bits on
-- ILP32 targets but 64 bits on LP64 hosts (e.g. 64-bit Linux/macOS). The
-- fixed-width types from @<stdint.h>@ would make this mapping exact.
cType :: Type -> String
cType Bool   = "unsigned char"
cType Int8   = "signed char"
cType Int16  = "signed short int"
cType Int32  = "signed long int"
cType Int64  = "signed long long int"
cType Word8  = "unsigned char"
cType Word16 = "unsigned short int"
cType Word32 = "unsigned long int"
cType Word64 = "unsigned long long int"
cType Float  = "float"
cType Double = "double"

-- | Emits the C statement that computes one named subexpression.
--
-- @ues@ maps the rule's expressions to their C names (as produced by
-- 'topo'), @d@ is the indentation prefix, and @(ue, n)@ is the expression to
-- compute together with the name it is assigned to.
--
-- Constants and variable references produce no statement: constants are
-- initialised where 'declareUE' declares them, and variable references are
-- read by name directly. Every other expression is written in terms of its
-- operands' names.
codeUE :: [(UE, String)] -> String -> (UE, String) -> String
codeUE ues d (ue, n) = case ue of
  UConst _ -> ""
  UVRef _  -> ""
  _        -> d ++ n ++ " = " ++ rhs ++ ";\n"
  where
  -- C names of the operands, in the order 'ueUpstream' lists them.
  operands :: [String]
  operands = map nameOf $ ueUpstream ue

  nameOf :: UE -> String
  nameOf op = fromMaybe
    (error "Code.codeUE: operand has no C name (missing from the topological sort).")
    (lookup op ues)

  -- Positional operands; only forced by the constructors that have them.
  a, b, c :: String
  a = operands !! 0
  b = operands !! 1
  c = operands !! 2

  rhs :: String
  rhs = case ue of
    UVRef _    -> error "Code.ueStmt: should not get here."
    UConst _   -> error "Code.ueStmt: should not get here."
    UCast _ _  -> "(" ++ cType (ueType ue) ++ ") " ++ a
    UAdd _ _   -> a ++ " + " ++ b
    USub _ _   -> a ++ " - " ++ b
    UMul _ _   -> a ++ " * " ++ b
    UDiv _ _   -> a ++ " / " ++ b
    UMod _ _   -> a ++ " % " ++ b
    UNot _     -> "! " ++ a
    -- An empty conjunction is vacuously true; joining zero operands would
    -- otherwise leave an empty right-hand side (invalid C).
    UAnd _
      | null operands -> "1"
      | otherwise     -> intercalate " && " operands
    UBWNot _   -> "~ " ++ a
    UBWAnd _ _ -> a ++ " & " ++ b
    UBWOr  _ _ -> a ++ " | " ++ b
    -- Negative shift amounts mean a right shift.
    UShift _ k
      | k >= 0    -> a ++ " << " ++ show k
      | otherwise -> a ++ " >> " ++ show (negate k)
    UEq  _ _   -> a ++ " == " ++ b
    ULt  _ _   -> a ++ " < " ++ b
    UMux _ _ _ -> a ++ " ? " ++ b ++ " : " ++ c
    -- Bit-level reinterpretation through a pointer cast. The pointer types
    -- come from 'cType' so they stay in step with how the Word32/Word64/
    -- Float/Double results are declared.
    UF2B _     -> "*((" ++ cType Word32 ++ " *) &" ++ a ++ ")"
    UD2B _     -> "*((" ++ cType Word64 ++ " *) &" ++ a ++ ")"
    UB2F _     -> "*((" ++ cType Float  ++ " *) &" ++ a ++ ")"
    UB2D _     -> "*((" ++ cType Double ++ " *) &" ++ a ++ ")"

-- | Generates the C implementation of a compiled design and its coverage map.
--
-- Writes two files:
--
--   * @name ++ ".c"@ — the @__clock@ counter, the @__coverage@ bit array,
--     definitions of the local variables, one function @rN@ per rule, and
--     the entry function @name@, which calls the rules scheduled for the
--     current @__clock@ value and then increments @__clock@;
--
--   * @CoverageData.hs@ — a Haskell module listing, for every rule, its name
--     and the (word index, bit) of its flag in @__coverage@.
--
-- @periods@ holds the rules of period 1, 2, ... in order, each split into
-- one list per cycle; the rules of cycle @k@ of period @p@ run when
-- @__clock % p == k@.
writeC :: Name -> [[[Rule]]] -> [UV] -> IO ()
writeC name periods uvs = do
  writeFile (name ++ ".c") c
  writeFile "CoverageData.hs" cov
  where
  c = unlines
    [ cType Word64 ++ " __clock = 0;"
    , "const " ++ cType Word32 ++ " __coverage_len = " ++ show covLen ++ ";"
    , cType Word32 ++ " __coverage[" ++ show covLen ++ "] = {" ++ drop 2 (concat $ replicate covLen ", 0") ++ "};"
    , cType Word32 ++ " __coverage_index = 0;"
    , concatMap declareMemory uvs
    , concatMap (codeRule topo') rules
    , "void " ++ name ++ "(void) {"
    , concatMap codePeriod $ zip [1..] periods
    , "  __clock = __clock + 1;"
    , "}"
    ]

  -- Every rule, across all periods and cycles.
  rules = concat $ concat periods

  cov = unlines
    [ "module CoverageData (coverageData) where"
    , ""
    , "-- | Encoding of rule coverage: (rule name, coverage array index, coverage bit)"
    , "coverageData :: [(String, (Int, Int))]"
    , "coverageData = " ++ show [ (ruleName r, (div (ruleId r) 32, mod (ruleId r) 32)) | r <- rules ]
    ]

  -- Subexpression names are numbered from one past the largest variable id,
  -- so they cannot clash with the eI names of local variables. A design with
  -- no variables starts numbering at 0 instead of failing in 'maximum'.
  topo' = topo firstFreeId
  firstFreeId = case [ i | UV i _ _ <- uvs ] of
    []  -> 0
    ids -> maximum ids + 1

  -- Number of coverage words (32 rule flags per word). Always at least one,
  -- so a design without rules still gets a non-empty C array.
  covLen = case map ruleId rules of
    []  -> 1
    ids -> 1 + div (maximum ids) 32

-- | Generates the C function @rN@ (N being the rule id) implementing one rule.
--
-- The function first declares and computes every subexpression the rule
-- uses, in the order given by @topo@. Then, if the enable condition holds,
-- it runs the rule's actions and sets the rule's bit in @__coverage@.
-- Finally it performs the rule's assignments. Every right-hand side has
-- already been computed at that point, so no assignment observes another's
-- new value.
--
-- NOTE(review): the assignments are emitted outside the @if@ on the enable
-- condition; presumably they are already gated upstream (e.g. muxed with
-- the enable) — confirm in Elaboration.
codeRule :: ([UE] -> [(UE, String)]) -> Rule -> String
codeRule topo rule =
  "/* " ++ show rule ++ " */\n" ++
  "void r" ++ show (ruleId rule) ++ "(void) {\n" ++
  concatMap (declareUE  "  ") ues ++
  concatMap (codeUE ues "  ") ues ++
  "  if (" ++ ref (ruleEnable rule) ++ ") {\n" ++
  concatMap codeAction (ruleActions rule) ++
  "    __coverage[" ++ covWord ++ "] = __coverage[" ++ covWord ++ "] | " ++ covMask ++ ";\n" ++
  "  }\n" ++
  concatMap codeAssign (ruleAssigns rule) ++
  "}\n\n"
  where
  -- All expressions the rule needs: enable, assigned values, action arguments.
  ues = topo $ ruleEnable rule : map snd (ruleAssigns rule) ++ concatMap snd (ruleActions rule)

  -- C name of an expression of this rule.
  ref :: UE -> String
  ref ue = fromJust $ lookup ue ues

  codeAction :: (([String] -> String), [UE]) -> String
  codeAction (f, args) = "    " ++ f (map ref args) ++ ";\n"

  codeAssign :: (UV, UE) -> String
  codeAssign (uv, ue) = "  " ++ v uv ++ " = " ++ ref ue ++ ";\n"

  covWord = show $ div (ruleId rule) 32
  -- Shift a 1 of the coverage array's own unsigned type. A bare C literal 1 is
  -- a signed int, so shifting it into bit 31 (or into bit 16 and above where
  -- int is 16 bits) is undefined behaviour.
  covMask = "((" ++ cType Word32 ++ ") 1 << " ++ show (mod (ruleId rule) 32) ++ ")"

-- | Emits the dispatch code for one period: the cycles are numbered from 0
-- and each one is rendered by 'codeCycle'.
codePeriod :: (Int, [[Rule]]) -> String
codePeriod (period, cycles) =
  concat [ codeCycle period phase | phase <- zip [0 ..] cycles ]

-- | Emits the block that calls the rules of one cycle of a period, guarded
-- by @__clock % period == cycle@. A cycle with no rules produces nothing.
-- A cycle index that is not below the period is an internal error.
codeCycle :: Int -> (Int, [Rule]) -> String
codeCycle period (phase, rs)
  | phase >= period = error "Code.codeCycle"
  | null rs         = ""
  | otherwise       =
      "  if (__clock % " ++ show period ++ " == " ++ show phase ++ ") {\n" ++
      concatMap callRule rs ++
      "  }\n"
  where
  callRule r = "    r" ++ show (ruleId r) ++ "();  /* " ++ show r ++ " */\n"

-- | The generated C identifier for index @i@: the letter @e@ followed by @i@.
e :: Int -> String
e i = 'e' : show i

-- | The C name of a variable: a generated @eI@ name (from its id) for a
-- local variable, its own name for an external one.
v :: UV -> String
v (UV i n loc) = case loc of
  Local    _ -> e i
  External _ -> n

-- | Topologically sorts a list of expressions and all their subexpressions,
-- pairing each distinct expression with the C name it is referred to by.
--
-- Operands are listed before the expressions that use them. A local variable
-- reference is named @eI@ after the variable's id and an external one by the
-- variable's own name; every other expression gets a fresh @eN@ name,
-- numbered upward from @start@. An expression already listed (by '==') is
-- not listed again.
topo :: Int -> [UE] -> [(UE, String)]
topo start roots = reverse visited
  where
  (_, visited) = foldl visit (start, []) roots
  -- Accumulator: next free number, and the named expressions so far with the
  -- most recently added first.
  visit :: (Int, [(UE, String)]) -> UE -> (Int, [(UE, String)])
  visit (next, seen) ue
    | ue `elem` map fst seen = (next, seen)
    | otherwise = case ue of
        UVRef (UV i _ (Local    _)) -> (next, (ue, e i) : seen)
        UVRef (UV _ a (External _)) -> (next, (ue, a)   : seen)
        _ ->
          -- Name the operands first, then take the next number for ue itself.
          let (next', seen') = foldl visit (next, seen) (ueUpstream ue)
          in  (next' + 1, (ue, e next') : seen')

-- | Number of distinct expressions 'topo' names for a rule: its enable
-- condition, assigned values and action arguments together with all their
-- subexpressions (variable references included).
ruleComplexity :: Rule -> Int
ruleComplexity rule = length (topo 0 roots)
  where
  roots = ruleEnable rule : (map snd (ruleAssigns rule) ++ concatMap snd (ruleActions rule))