-- | Atom C code generation.
module Language.Atom.Code
  ( Config (..)
  , writeC
  , ruleComplexity
  , defaults
  , cTypes
  , c99Types
  ) where

import Data.Char
import Data.List
import Data.Maybe
import Data.Word
import GHC.Float (castDoubleToWord64, castFloatToWord32)
import System.IO
import Unsafe.Coerce

import Language.Atom.Elaboration
import Language.Atom.Expressions

-- | C code configuration parameters.
--
-- Controls how 'writeC' names the generated entry point, which C type names
-- are emitted, what raw C text is pasted before and after the generated
-- code, and which instrumentation (rule coverage, assertion calls,
-- functional coverage calls) is generated.  'defaults' provides a complete
-- starting configuration whose fields can be overridden individually.
data Config = Config
  {
    cFuncName     :: String          -- ^ Alternative primary function name.  Leave empty to use the compile name.
  , cType         :: Type -> String  -- ^ C type naming rules (e.g. 'cTypes' or 'c99Types').
  , cPreCode      :: String          -- ^ C code to insert above the generated code (includes, macros, etc.).
  , cPostCode     :: String          -- ^ C code to insert below the generated code (main, etc.).
  , cRuleCoverage :: Bool            -- ^ Enable rule coverage tracking (the generated @__coverage@ bit array).
  , cAssert       :: Bool            -- ^ Enable assertion checking.
  , cAssertName   :: String          -- ^ Name of assertion function.  Type: void assert(char*, cType Bool);
  , cCover        :: Bool            -- ^ Enable functional coverage accumulation.
  , cCoverName    :: String          -- ^ Name of coverage function.  Type: void cover(char*, cType Bool);
  }

-- | Default C code configuration.
--
-- The primary function takes the compile name, no extra C text is inserted
-- before or after the generated code, and ANSI C type names ('cTypes') are
-- used.  Rule coverage, assertion checking (via @assert@) and functional
-- coverage (via @cover@) are all switched on.
defaults :: Config
defaults = Config
  { -- Naming and surrounding text.
    cFuncName     = ""
  , cPreCode      = ""
  , cPostCode     = ""
  , cType         = cTypes
    -- Instrumentation.
  , cRuleCoverage = True
  , cAssert       = True
  , cAssertName   = "assert"
  , cCover        = True
  , cCoverName    = "cover"
  }

-- | Render a constant as a C literal, as used on the right-hand side of a
-- generated assignment.
--
-- Booleans become @1@ / @0@.  32- and 64-bit integers carry @L@ / @LL@
-- (signed) or @UL@ / @ULL@ (unsigned) suffixes.  Floating-point values use
-- Haskell's 'show' text.
--
-- NOTE(review): 'show' renders non-finite values as @NaN@ / @Infinity@,
-- which are not valid C literals.
showConst :: Const -> String
showConst (CBool True)  = "1"
showConst (CBool False) = "0"
showConst (CInt8   x)   = show x
showConst (CInt16  x)   = show x
showConst (CInt32  x)   = show x ++ "L"
showConst (CInt64  x)   = show x ++ "LL"
showConst (CWord8  x)   = show x
showConst (CWord16 x)   = show x
showConst (CWord32 x)   = show x ++ "UL"
showConst (CWord64 x)   = show x ++ "ULL"
showConst (CFloat  x)   = show x
showConst (CDouble x)   = show x

-- | ANSI C type naming rules, using only built-in C type names.
--
-- Booleans are stored as @unsigned char@.  The 64-bit types use
-- @long long@, which C89 itself does not define.
cTypes :: Type -> String
cTypes Bool   = "unsigned char"
cTypes Int8   = "signed char"
cTypes Int16  = "signed short"
cTypes Int32  = "signed long"
cTypes Int64  = "signed long long"
cTypes Word8  = "unsigned char"
cTypes Word16 = "unsigned short"
cTypes Word32 = "unsigned long"
cTypes Word64 = "unsigned long long"
cTypes Float  = "float"
cTypes Double = "double"

-- | C99 type naming rules, using the fixed-width @<stdint.h>@ typedefs.
--
-- Booleans are stored as @uint8_t@.  The generated code does not include
-- @<stdint.h>@ itself; supply it via 'cPreCode'.
c99Types :: Type -> String
c99Types Bool   = "uint8_t"
c99Types Int8   = "int8_t"
c99Types Int16  = "int16_t"
c99Types Int32  = "int32_t"
c99Types Int64  = "int64_t"
c99Types Word8  = "uint8_t"
c99Types Word16 = "uint16_t"
c99Types Word32 = "uint32_t"
c99Types Word64 = "uint64_t"
c99Types Float  = "float"
c99Types Double = "double"

-- | Emit the C declaration that computes one intermediate expression.
--
-- Arguments:
--
-- * @ues@ maps every already-numbered expression to its C variable name.
-- * @d@ is prefixed to the declaration (the callers pass indentation).
-- * @(ue, n)@ is the expression to emit and the variable it is assigned to.
--
-- Operands are referenced through the names found in @ues@, so every
-- upstream expression of @ue@ must already be present there; otherwise
-- 'fromJust' fails.
codeUE :: Config -> [(UE, String)] -> String -> (UE, String) -> String
codeUE config ues d (ue, n) =
  concat [d, ct (typeOf ue), " ", n, " = ", concat rhs, ";\n"]
  where
  ct = cType config

  -- C variable names of this expression's operands, in upstream order.
  args = [ fromJust (lookup u ues) | u <- ueUpstream ue ]
  a = head args
  b = args !! 1
  c = args !! 2

  -- Infix binary operator applied to the first two operands.
  infixOp op = [a, op, b]

  -- Reinterpret the bits of the first operand as type @t@ through a
  -- pointer cast.
  reinterpret t = ["*((", ct t, " *) &(", a, "))"]

  rhs = case ue of
    UVRef (UV (Array ua@(UA _ arrName _) _)) -> [arrayIndex config ua a, " /* ", arrName, " */ "]
    UVRef (UV (External extName _))          -> [extName]
    UCast _ _                                -> ["(", ct (typeOf ue), ") ", a]
    UConst k                                 -> [showConst k]
    UAdd _ _                                 -> infixOp " + "
    USub _ _                                 -> infixOp " - "
    UMul _ _                                 -> infixOp " * "
    UDiv _ _                                 -> infixOp " / "
    UMod _ _                                 -> infixOp " % "
    UNot _                                   -> ["! ", a]
    UAnd _                                   -> intersperse " && " args
    UBWNot _                                 -> ["~ ", a]
    UBWAnd _ _                               -> infixOp " & "
    UBWOr  _ _                               -> infixOp " | "
    UShift _ k
      | k >= 0                               -> [a, " << ", show k]
      | otherwise                            -> [a, " >> ", show (negate k)]
    UEq  _ _                                 -> infixOp " == "
    ULt  _ _                                 -> infixOp " < "
    UMux _ _ _                               -> [a, " ? ", b, " : ", c]
    UF2B _                                   -> reinterpret Word32
    UD2B _                                   -> reinterpret Word64
    UB2F _                                   -> reinterpret Float
    UB2D _                                   -> reinterpret Double

-- | Write the generated C code to @\<name\>.c@ and a Haskell module
-- describing the rule coverage encoding to @\<Name\>CoverageData.hs@.
--
-- The C file contains:
--
-- * the pre-code, @__clock@, and the optional coverage arrays;
-- * the initialised state memories;
-- * one static function per rule;
-- * the primary function, which dispatches rules by period and phase and
--   then increments @__clock@;
-- * the post-code.
--
-- @periods@ is indexed by period (the first entry is period 1) and then by
-- phase (from 0).  The 4-tuple holds the initial contents of the 8-, 16-,
-- 32- and 64-bit memories.
writeC :: Name -> Config -> [[[Rule]]] -> ([Const], [Const], [Const], [Const]) -> IO ()
writeC name config periods (init8, init16, init32, init64) = do
  writeFile (name ++ ".c") c
  writeFile (name' ++ "CoverageData.hs") cov
  where
  -- Haskell module names must start with an upper-case letter.
  -- Pattern match instead of 'head' so an empty name cannot crash.
  name' = case name of
    []       -> []
    (x : xs) -> toUpper x : xs

  c = unlines
    [ cPreCode config
    , "static " ++ cType config Word64 ++ " __clock = 0;"
    , ruleCoverage config $ "static const " ++ cType config Word32 ++ " __coverage_len = " ++ show covLen ++ ";"
    , ruleCoverage config $ "static " ++ cType config Word32 ++ " __coverage[" ++ show covLen ++ "] = {" ++ intercalate ", " (replicate covLen "0") ++ "};"
    , ruleCoverage config $ "static " ++ cType config Word32 ++ " __coverage_index = 0;"
    , mi Word8 init8 ++ mi Word16 init16 ++ mi Word32 init32 ++ mi Word64 init64
    , concatMap (codeRule config topo') rules
    , "void " ++ (if null (cFuncName config) then name else cFuncName config) ++ "(void) {"
    , concatMap codePeriod $ zip [1..] periods
    , "  __clock = __clock + 1;"
    , "}"
    , cPostCode config
    ]

  mi = memoryInit config

  rules = concat $ concat periods

  cov = unlines
    [ "module " ++ name' ++ "CoverageData (coverageData) where"
    , ""
    , "-- | Encoding of rule coverage: (rule name, coverage array index, coverage bit)"
    , "coverageData :: [(String, (Int, Int))]"
    , "coverageData = " ++ show [ (ruleName r, (div (ruleId r) 32, mod (ruleId r) 32)) | r <- rules ]
    ]

  topo' = topo 0

  -- One 32-bit coverage word per 32 rule ids.  The 0 seed keeps this total
  -- when there are no rules: 'maximum' of an empty list would crash.  One
  -- word is still emitted, because C forbids zero-length arrays.
  covLen = 1 + div (maximum (0 : map ruleId rules)) 32

-- | Return the given C text when rule coverage is enabled, otherwise
-- nothing.
ruleCoverage :: Config -> String -> String
ruleCoverage config s
  | cRuleCoverage config = s
  | otherwise            = ""

-- | Return the given C text when functional coverage is enabled, otherwise
-- nothing.
coverage :: Config -> String -> String
coverage config s
  | cCover config = s
  | otherwise     = ""

-- | Return the given C text when assertion checking is enabled, otherwise
-- nothing.
assertion :: Config -> String -> String
assertion config s
  | cAssert config = s
  | otherwise      = ""

-- | Emit the static, initialised C array backing one state memory.
--
-- The array is declared with element type @t@ (the caller passes the
-- integer type of the memory) and named by 'memory'.  An empty initial
-- list produces no declaration.
--
-- The memories are integer arrays, so floating-point initial values are
-- written as their raw IEEE-754 bit patterns.  These patterns carry @UL@ /
-- @ULL@ suffixes.  Without them, a negative double's pattern (>= 2^63) is
-- an unsuffixed decimal that no signed C integer type can represent.
memoryInit :: Config -> Type -> [Const] -> String
memoryInit _ _ [] = ""
memoryInit config t consts = "static " ++ cType config t ++ " " ++ memory t ++ "[" ++ show (length consts) ++ "] = {" ++ intercalate ", " (map format consts) ++ "};\n"
  where
  format :: Const -> String
  format c = case c of
    CFloat  a -> show (floatBits  a) ++ "UL"
    CDouble a -> show (doubleBits a) ++ "ULL"
    -- Integers and booleans render exactly as in expressions.
    _         -> showConst c

-- | Raw IEEE-754 bit pattern of a 'Float'.
--
-- Uses 'castFloatToWord32' rather than @unsafeCoerce@.  The "Unsafe.Coerce"
-- documentation excludes coercions between floating-point and integral
-- types, and 'Float' and 'Word32' have different boxed representations, so
-- @unsafeCoerce@ is not guaranteed to yield the bit pattern.
floatBits :: Float -> Word32
floatBits = castFloatToWord32

-- | Raw IEEE-754 bit pattern of a 'Double'.
--
-- Uses 'castDoubleToWord64' rather than @unsafeCoerce@.  The
-- "Unsafe.Coerce" documentation excludes coercions between floating-point
-- and integral types, and 'Double' and 'Word64' have different boxed
-- representations, so @unsafeCoerce@ is not guaranteed to yield the bit
-- pattern.
doubleBits :: Double -> Word64
doubleBits = castDoubleToWord64

-- | C name of the memory array that stores values of the given width:
-- @__memory8@, @__memory16@, and so on.  Width-1 values (presumably 'Bool')
-- share the 8-bit memory.
memory :: Width a => a -> String
memory a = "__memory" ++ show bits
  where
  w = width a
  bits
    | w == 1    = 8
    | otherwise = w

-- | Emit the static C function @__r\<id\>@ that implements one rule.
--
-- Every expression the rule uses is first computed into a local variable,
-- named and ordered by @topo@.  When the enable expression is true:
--
-- * the rule's actions run;
-- * its bit in @__coverage@ is set, if rule coverage is enabled;
-- * its assertion and cover checks are reported, if those are enabled.
--
-- The assignments are emitted after the enable block, i.e. unconditionally
-- in C.  NOTE(review): presumably the assigned expressions already account
-- for the enable condition; confirm in the elaborator.
codeRule :: Config -> ([UE] -> [(UE, String)]) -> Rule -> String
codeRule config topo rule =
  "/* " ++ show rule ++ " */\n" ++
  "static void __r" ++ show (ruleId rule) ++ "(void) {\n" ++
  concatMap (codeUE    config ues "  ") ues ++
  "  if (" ++ var (ruleEnable rule) ++ ") {\n" ++
  concatMap codeAction (ruleActions rule) ++
  -- 1UL, not 1: shifting a plain int by up to 31 is undefined behaviour in C
  -- (bit 31 with 32-bit int, bit 15 and up on 16-bit-int targets).
  ruleCoverage config ("    __coverage[" ++ covWord ++ "] = __coverage[" ++ covWord ++ "] | (1UL << " ++ covBit ++ ");\n") ++
  -- NOTE(review): 'show' yields a Haskell string literal.  It matches C for
  -- printable ASCII, but escapes for other characters differ.
  concat [ assertion config ("    " ++ cAssertName config ++ "(" ++ show name ++ ", " ++ var check ++ ");\n") | (name, check) <- ruleAsserts rule ] ++
  concat [ coverage  config ("    " ++ cCoverName  config ++ "(" ++ show name ++ ", " ++ var check ++ ");\n") | (name, check) <- ruleCovers  rule ] ++
  "  }\n" ++
  concatMap codeAssign (ruleAssigns rule) ++
  "}\n\n"
  where
  ues = topo $ allUEs rule

  -- C variable holding the value of an expression.  Renamed from 'id' to
  -- avoid shadowing 'Prelude.id'.
  var ue = fromJust $ lookup ue ues

  codeAction :: (([String] -> String), [UE]) -> String
  codeAction (f, args) = "    " ++ f (map var args) ++ ";\n"

  -- Coverage word index and bit position; 32 rules per word.
  covWord = show $ div (ruleId rule) 32
  covBit  = show $ mod (ruleId rule) 32

  codeAssign :: (UV, UE) -> String
  codeAssign (UV (Array ua@(UA _ n _) i), ue) = "  " ++ arrayIndex config ua (var i) ++ " = " ++ var ue ++ ";  /* " ++ n ++ " */\n"
  codeAssign (UV (External n _), ue) = "  " ++ n ++ " = " ++ var ue ++ ";\n"

-- | C lvalue for element @i@ of an array.
--
-- The backing memory (see 'memory') is offset by @addr@ elements, cast to a
-- pointer to the array's element type, and indexed by @i@.  The element
-- type and memory come from the first entry of the array's constant list
-- (presumably its initial values), so that list must be non-empty.
arrayIndex :: Config -> UA -> String -> String
arrayIndex config (UA addr _ c) i =
  concat ["((", elemType, " *) (", base, " + ", show addr, "))[", i, "]"]
  where
  sample   = head c
  elemType = cType config (typeOf sample)
  base     = memory sample

-- | Emit the dispatch code for every phase of one period.  Phases are
-- numbered from 0 in list order.
codePeriod :: (Int, [[Rule]]) -> String
codePeriod (period, cycles) = concat (zipWith phase [0 ..] cycles)
  where
  phase i rs = codeCycle period (i, rs)

-- | Emit the dispatch for one phase of a period.
--
-- When @__clock % period@ equals the phase, every rule scheduled in that
-- phase is called.  An empty phase emits nothing.  A phase at or beyond the
-- period is an internal error.
codeCycle :: Int -> (Int, [Rule]) -> String
codeCycle period (phase, rules)
  | phase >= period = error "Code.codeCycle"
  | null rules      = ""
  | otherwise       = concat
      [ "  if (__clock % ", show period, " == ", show phase, ") {\n"
      , concatMap callRule rules
      , "  }\n"
      ]
  where
  callRule r = "    __r" ++ show (ruleId r) ++ "();  /* " ++ show r ++ " */\n"

-- | C variable name for the expression numbered @i@, e.g. @__3@.
e :: Int -> String
e = ("__" ++) . show

-- | Topologically sort a list of expressions together with all of their
-- subexpressions.
--
-- Each distinct expression appears once, after all of its upstream
-- operands.  Each is paired with a fresh C variable name (see 'e'),
-- numbered upward from @start@.
topo :: Int -> [UE] -> [(UE, String)]
topo start roots = reverse ordered
  where
  (_, ordered) = foldl' visit (start, []) roots

  -- Depth-first post-order visit: name the operands first, then the node.
  -- The accumulated list is newest-first, hence the final 'reverse'.
  visit :: (Int, [(UE, String)]) -> UE -> (Int, [(UE, String)])
  visit acc@(next, seen) ue
    | any ((== ue) . fst) seen = acc
    | otherwise                = (next' + 1, (ue, e next') : seen')
    where
    (next', seen') = foldl' visit (next, seen) (ueUpstream ue)

-- | Number of distinct expressions, including subexpressions, that a rule
-- computes.  This is the number of entries 'topo' produces for the rule.
ruleComplexity :: Rule -> Int
ruleComplexity rule = length (topo 0 (allUEs rule))