-- | Atom C code generation.
module Language.Atom.Code
  ( Config (..)
  , writeC
  , ruleComplexity
  , defaults
  , cTypes
  , c99Types
  ) where

import Data.Char
import Data.List
import Data.Maybe
import Data.Word
import System.IO
import Unsafe.Coerce

import Language.Atom.Elaboration
import Language.Atom.Expressions

-- | C code configuration parameters.
data Config = Config
  { cFuncName     :: String         -- ^ Alternative primary function name.  Leave empty to use the compile name.
  , cType         :: Type -> String -- ^ C type naming rules.
  , cPreCode      :: String         -- ^ C code to insert above (includes, macros, etc.).
  , cPostCode     :: String         -- ^ C code to insert below (main, etc.).
  , cRuleCoverage :: Bool           -- ^ Enable rule coverage tracking.
  , cAssert       :: Bool           -- ^ Enable assertion checking.
  , cAssertName   :: String         -- ^ Name of the assertion function.  C type: void assert(char*, cType Bool);
  , cCover        :: Bool           -- ^ Enable functional coverage accumulation.
  , cCoverName    :: String         -- ^ Name of the coverage function.  C type: void cover(char*, cType Bool);
  }

-- | Default C code configuration parameters.
--
-- * The compile name is used as the primary function name.
-- * There is no pre/post code.
-- * ANSI C types ('cTypes') are used.
-- * Rule coverage, assertion checking (via @assert@) and functional coverage
--   (via @cover@) are all enabled.
defaults :: Config
defaults = Config
  { cFuncName     = ""
  , cType         = cTypes
  , cPreCode      = ""
  , cPostCode     = ""
  , cRuleCoverage = True
  , cAssert       = True
  , cAssertName   = "assert"
  , cCover        = True
  , cCoverName    = "cover"
  }

-- | Render a constant as a C literal.
--
-- 32- and 64-bit integers get the @L@/@LL@ (signed) or @UL@/@ULL@
-- (unsigned) suffixes; floating-point values are rendered with 'show'.
--
-- The most negative 32- and 64-bit values are written as @(MIN+1) - 1@.
-- C has no negative literals: @-2147483648L@ is the negation of
-- @2147483648L@, which does not fit in a 32-bit @long@.
-- @9223372036854775808LL@ fits in no signed type at all.
showConst :: Const -> String
showConst k = case k of
  CBool b -> if b then "1" else "0"
  CInt8 i -> show i
  CInt16 i -> show i
  CInt32 i
    | i == minBound -> "(-2147483647L - 1)"
    | otherwise     -> show i ++ "L"
  CInt64 i
    | i == minBound -> "(-9223372036854775807LL - 1)"
    | otherwise     -> show i ++ "LL"
  CWord8 w -> show w
  CWord16 w -> show w
  CWord32 w -> show w ++ "UL"
  CWord64 w -> show w ++ "ULL"
  CFloat f -> show f
  CDouble d -> show d

-- | ANSI C type naming rules.
cTypes :: Type -> String
cTypes t = case t of
  Bool   -> "unsigned char"
  Int8   -> "signed char"
  Int16  -> "signed short"
  Int32  -> "signed long"
  Int64  -> "signed long long"
  Word8  -> "unsigned char"
  Word16 -> "unsigned short"
  Word32 -> "unsigned long"
  Word64 -> "unsigned long long"
  Float  -> "float"
  Double -> "double"

-- | C99 type naming rules (fixed-width @stdint.h@ types).
c99Types :: Type -> String
c99Types t = case t of
  Bool   -> "uint8_t"
  Int8   -> "int8_t"
  Int16  -> "int16_t"
  Int32  -> "int32_t"
  Int64  -> "int64_t"
  Word8  -> "uint8_t"
  Word16 -> "uint16_t"
  Word32 -> "uint32_t"
  Word64 -> "uint64_t"
  Float  -> "float"
  Double -> "double"

-- | Emit one C local-variable declaration that computes a single expression.
--
-- * @ues@ maps already-emitted expressions to their C variable names.  The
--   operands of @ue@ ('ueUpstream') are looked up there, so they must be
--   present.  A list built by 'topo' always lists upstream expressions first.
-- * @d@ is a prefix (indentation) placed before the declaration.
-- * @var@ is the name of the variable being declared.
codeUE :: Config -> [(UE, String)] -> String -> (UE, String) -> String
codeUE config ues d (ue, var) =
  d ++ cType config (typeOf ue) ++ " " ++ var ++ " = " ++ rhs ++ ";\n"
  where
    ct = cType config
    -- C names of the operands, in 'ueUpstream' order.
    operands = map (fromJust . flip lookup ues) $ ueUpstream ue
    -- Only demanded by the alternatives that actually have such operands.
    a = head operands
    b = operands !! 1
    c = operands !! 2
    rhs = concat $ case ue of
      UVRef (UV (Array ua@(UA _ n _) _)) -> [arrayIndex config ua a, " /* ", n, " */ "]
      UVRef (UV (External n _))          -> [n]
      UCast _ _      -> ["(", ct (typeOf ue), ") ", a]
      UConst k       -> [showConst k]
      UAdd _ _       -> [a, " + ", b]
      USub _ _       -> [a, " - ", b]
      UMul _ _       -> [a, " * ", b]
      UDiv _ _       -> [a, " / ", b]
      UMod _ _       -> [a, " % ", b]
      UNot _         -> ["! ", a]
      UAnd _         -> intersperse " && " operands
      UBWNot _       -> ["~ ", a]
      UBWAnd _ _     -> [a, " & ", b]
      UBWOr _ _      -> [a, " | ", b]
      UShift _ n
        | n >= 0     -> [a, " << ", show n]
        | otherwise  -> [a, " >> ", show (negate n)]
      UEq _ _        -> [a, " == ", b]
      ULt _ _        -> [a, " < " , b]
      UMux _ _ _     -> [a, " ? " , b, " : ", c]
      -- Bit-level reinterpretation between floats and same-width words.
      UF2B _         -> ["*((", ct Word32, " *) &(", a, "))"]
      UD2B _         -> ["*((", ct Word64, " *) &(", a, "))"]
      UB2F _         -> ["*((", ct Float , " *) &(", a, "))"]
      UB2D _         -> ["*((", ct Double, " *) &(", a, "))"]

-- | Generate C code for an elaborated design and write it to @name.c@.
--
-- Also writes the Haskell module @NameCoverageData.hs@, where @Name@ is
-- @name@ with its first letter capitalized.  That module exports
-- @coverageData@, which maps each rule name to its (word index, bit index)
-- in the generated @__coverage@ array.
--
-- @periods@ is indexed by period, starting at period 1.  Each period is a
-- list of cycles (phases, starting at 0), and each cycle holds the rules it
-- schedules.  The four constant lists are the initial contents of the
-- Word8, Word16, Word32 and Word64 memory arrays.
writeC :: Name -> Config -> [[[Rule]]] -> ([Const], [Const], [Const], [Const]) -> IO ()
writeC name config periods (init8, init16, init32, init64) = do
  writeFile (name ++ ".c") c
  writeFile (name' ++ "CoverageData.hs") cov
  where
    -- Coverage module name: the first letter capitalized.  An empty name
    -- stays empty instead of crashing on 'head'.
    name' = case name of
      []       -> []
      (x : xs) -> toUpper x : xs
    rules = concat (concat periods)
    mi = memoryInit config
    c = unlines
      [ cPreCode config
      , "static " ++ cType config Word64 ++ " __clock = 0;"
      , ruleCoverage config $ "static const " ++ cType config Word32 ++ " __coverage_len = " ++ show covLen ++ ";"
      , ruleCoverage config $ "static " ++ cType config Word32 ++ " __coverage[" ++ show covLen ++ "] = {" ++ intercalate ", " (replicate covLen "0") ++ "};"
      , ruleCoverage config $ "static " ++ cType config Word32 ++ " __coverage_index = 0;"
      , mi Word8 init8 ++ mi Word16 init16 ++ mi Word32 init32 ++ mi Word64 init64
      , concatMap (codeRule config (topo 0)) rules
      , "void " ++ (if null (cFuncName config) then name else cFuncName config) ++ "(void) {"
      , concatMap codePeriod (zip [1 ..] periods)
      , " __clock = __clock + 1;"
      , "}"
      , cPostCode config
      ]
    cov = unlines
      [ "module " ++ name' ++ "CoverageData (coverageData) where"
      , ""
      , "-- | Encoding of rule coverage: (rule name, coverage array index, coverage bit)"
      , "coverageData :: [(String, (Int, Int))]"
      , "coverageData = " ++ show [ (ruleName r, (div (ruleId r) 32, mod (ruleId r) 32)) | r <- rules ]
      ]
    -- Number of 32-bit coverage words: one bit per rule id.  With no rules,
    -- still emit a one-word array rather than failing on 'maximum' [].
    covLen = case map ruleId rules of
      []  -> 1
      ids -> 1 + div (maximum ids) 32

-- | Keep @s@ only when rule coverage is enabled.
ruleCoverage :: Config -> String -> String
ruleCoverage config s = if cRuleCoverage config then s else ""

-- | Keep @s@ only when functional coverage is enabled.
coverage :: Config -> String -> String
coverage config s = if cCover config then s else ""

-- | Keep @s@ only when assertion checking is enabled.
assertion :: Config -> String -> String
assertion config s = if cAssert config then s else ""

-- | Declare and initialize the static memory array for type @t@.  An empty
-- initializer list emits nothing.
--
-- Float and double constants are stored as their raw IEEE bit patterns.
-- They get the same @UL@/@ULL@ suffixes as 'CWord32'/'CWord64' constants.
-- Without a suffix, a double with its sign bit set becomes a decimal
-- literal larger than any signed C type can hold.
memoryInit :: Config -> Type -> [Const] -> String
memoryInit _ _ [] = ""
memoryInit config t consts =
  "static " ++ cType config t ++ " " ++ memory t ++ "[" ++ show (length consts)
    ++ "] = {" ++ intercalate ", " (map format consts) ++ "};\n"
  where
    format :: Const -> String
    format (CFloat a)  = show (floatBits a) ++ "UL"
    format (CDouble a) = show (doubleBits a) ++ "ULL"
    format a           = showConst a

-- | Raw bit pattern of a 'Float'.
--
-- NOTE(review): this reinterprets the boxed value with 'unsafeCoerce'.  It
-- relies on GHC laying out 'Float' and 'Word32' (and 'Double' and 'Word64')
-- identically on the heap.  @GHC.Float.castFloatToWord32@ and
-- @castDoubleToWord64@ (base >= 4.10) are the safe replacements once the
-- supported compiler range allows them.
floatBits :: Float -> Word32
floatBits = unsafeCoerce

-- | Raw bit pattern of a 'Double'; see the note on 'floatBits'.
doubleBits :: Double -> Word64
doubleBits = unsafeCoerce

-- | Name of the memory array for values of the given width.  Width-1 values
-- (presumably 'Bool') share the 8-bit array.
memory :: Width a => a -> String
memory a = "__memory" ++ show (if width a == 1 then 8 else width a)

-- | Emit the C function @__r\<ruleId\>@ implementing one rule.
--
-- The function first declares one local per subexpression, numbered by
-- @topo@.  Then, inside @if (enable) { ... }@, it:
--
-- * runs the rule's actions;
-- * sets the rule's bit in @__coverage@, when rule coverage is enabled;
-- * calls the configured assert/cover functions, when those are enabled.
--
-- The rule's assignments are emitted after that block.
--
-- NOTE(review): the assignments are not guarded by the enable condition in
-- the generated code.  Presumably the elaborator already folds the enable
-- into each assigned expression; confirm in "Language.Atom.Elaboration".
codeRule :: Config -> ([UE] -> [(UE, String)]) -> Rule -> String
codeRule config topo rule =
  "/* " ++ show rule ++ " */\n"
    ++ "static void __r" ++ show (ruleId rule) ++ "(void) {\n"
    ++ concatMap (codeUE config ues " ") ues
    ++ " if (" ++ ref (ruleEnable rule) ++ ") {\n"
    ++ concatMap codeAction (ruleActions rule)
    ++ ruleCoverage config
         (" __coverage[" ++ covWord ++ "] = __coverage[" ++ covWord ++ "] | (" ++ covMask ++ ");\n")
    ++ concat
         [ assertion config (" " ++ cAssertName config ++ "(" ++ cStringLit name ++ ", " ++ ref check ++ ");\n")
         | (name, check) <- ruleAsserts rule
         ]
    ++ concat
         [ coverage config (" " ++ cCoverName config ++ "(" ++ cStringLit name ++ ", " ++ ref check ++ ");\n")
         | (name, check) <- ruleCovers rule
         ]
    ++ " }\n"
    ++ concatMap codeAssign (ruleAssigns rule)
    ++ "}\n\n"
  where
    ues = topo $ allUEs rule
    -- C variable holding the value of one of this rule's expressions.
    ref ue = fromJust $ lookup ue ues
    codeAction :: ([String] -> String, [UE]) -> String
    codeAction (f, args) = " " ++ f (map ref args) ++ ";\n"
    covWord = show $ div (ruleId rule) 32
    covBit  = show $ mod (ruleId rule) 32
    -- Build the mask in the coverage word's own unsigned type.  With a plain
    -- int, @1 << 31@ is undefined behaviour in C, and so is any shift past
    -- the width of int on targets with 16-bit int.
    covMask = "(" ++ cType config Word32 ++ ") 1 << " ++ covBit
    codeAssign :: (UV, UE) -> String
    codeAssign (UV (Array ua@(UA _ n _) i), ue) =
      " " ++ arrayIndex config ua (ref i) ++ " = " ++ ref ue ++ "; /* " ++ n ++ " */\n"
    codeAssign (UV (External n _), ue) =
      " " ++ n ++ " = " ++ ref ue ++ ";\n"

-- | Render a Haskell string as a C string literal.
--
-- Printable ASCII is emitted verbatim, with @\"@ and @\\@ escaped, so for
-- such strings the result is identical to 'show'.  Every other character
-- is UTF-8 encoded, and each byte is written as a three-digit octal escape.
-- C reads these with the intended meaning.  'show' instead produces named
-- or decimal escapes such as @\\DEL@ or @\\1234@, which C rejects or
-- misreads.
cStringLit :: String -> String
cStringLit s = "\"" ++ concatMap escape s ++ "\""
  where
    escape '"'  = "\\\""
    escape '\\' = "\\\\"
    escape ch
      | ord ch >= 0x20 && ord ch < 0x7f = [ch]
      | otherwise                       = concatMap octal (utf8 (ord ch))
    -- Always three digits, so a following digit is never absorbed.
    octal byte = ['\\', intToDigit (div byte 64), intToDigit (mod (div byte 8) 8), intToDigit (mod byte 8)]
    utf8 cp
      | cp < 0x80    = [cp]
      | cp < 0x800   = [0xc0 + div cp 0x40, cont cp]
      | cp < 0x10000 = [0xe0 + div cp 0x1000, cont (div cp 0x40), cont cp]
      | otherwise    = [0xf0 + div cp 0x40000, cont (div cp 0x1000), cont (div cp 0x40), cont cp]
    cont x = 0x80 + mod x 0x40

-- | C lvalue for element @i@ of an array at offset @addr@ in the memory array
-- for its element type.  The element type is taken from the first entry of
-- the 'UA' constant list (presumably the array's initial values).
arrayIndex :: Config -> UA -> String -> String
arrayIndex config (UA addr _ c) i =
  "((" ++ cType config (typeOf (head c)) ++ " *) (" ++ memory (head c) ++ " + " ++ show addr ++ "))[" ++ i ++ "]"

-- | Emit the dispatch code for every cycle of one period.
codePeriod :: (Int, [[Rule]]) -> String
codePeriod (period, cycles) = concatMap (codeCycle period) $ zip [0 ..] cycles

-- | Call each rule of one cycle when @__clock % period == phase@.  Empty
-- cycles emit nothing; a phase outside the period is an internal error.
codeCycle :: Int -> (Int, [Rule]) -> String
codeCycle period (phase, rules)
  | phase >= period = error "Code.codeCycle"
  | null rules = ""
  | otherwise =
      " if (__clock % " ++ show period ++ " == " ++ show phase ++ ") {\n"
        ++ concatMap (\ r -> " __r" ++ show (ruleId r) ++ "(); /* " ++ show r ++ " */\n") rules
        ++ " }\n"

-- | Name of the @i@-th generated temporary.
e :: Int -> String
e i = "__" ++ show i

-- | Topologically sorts a list of expressions and subexpressions.
--
-- Every expression appears once, after all of its upstream operands, paired
-- with a fresh temporary name.  Numbering starts at @start@.
topo :: Int -> [UE] -> [(UE, String)]
topo start ues = reverse ues'
  where
    (_, ues') = foldl' collect (start, []) ues
    collect :: (Int, [(UE, String)]) -> UE -> (Int, [(UE, String)])
    collect (n, done) ue
      | any ((== ue) . fst) done = (n, done)
      | otherwise = (n' + 1, (ue, e n') : done')
      where
        (n', done') = foldl' collect (n, done) (ueUpstream ue)

-- | Number of UE's computed in rule.
ruleComplexity :: Rule -> Int
ruleComplexity = length . topo 0 . allUEs