-- | ImProve: an embedded language for building simple imperative programs in
-- Haskell and generating C code from them.
--
-- A program is a 'Stmt' computation that declares variables ('bool', 'int',
-- 'float', 'input'), assigns expressions to them ('<=='), branches ('ifelse',
-- 'if_') and states named checks ('assert', 'assume').  'compile' writes the
-- program out as a C source file and a matching header file.
module Language.ImProve
  (
  -- * Types
    E
  , V
  , AllE
  , NumE
  , Name
  -- * Expressions
  -- ** Constants and variable references
  , true
  , false
  , constant
  , ref
  -- ** Logical operations
  , not_
  , (&&.)
  , (||.)
  , and_
  , or_
  , any_
  , all_
  , imply
  -- ** Comparisons
  , (==.)
  , (/=.)
  , (<.)
  , (<=.)
  , (>.)
  , (>=.)
  -- ** Arithmetic
  , min_
  , minimum_
  , max_
  , maximum_
  , limit
  , (*.)
  , (/.)
  , div_
  , mod_
  -- ** Selection and interpolation
  , mux
  , linear
  -- * Statements
  -- ** Hierarchical scoping
  , scope
  -- ** Variable declarations
  , bool
  , bool'
  , int
  , int'
  , float
  , float'
  , input
  -- ** Statement construction
  , Stmt
  , (<==)
  , ifelse
  , if_
  , incr
  , decr
  , assert
  , assume
  -- * Code generation
  , compile
  ) where
import Control.Applicative (Applicative (..))
import Control.Monad
import Data.List
import Data.Ratio
-- Operator precedences.  Scaling, division and modulus (7) bind like the
-- Prelude's (*) and (/).  Comparisons (4) bind tighter than (&&.) (3), which
-- in turn binds tighter than (||.) (2).  Assignment (<==) has the lowest
-- precedence (1), so a whole arithmetic expression can appear on its
-- right-hand side without parentheses, e.g. @a <== ref a + 1@.
infixl 7 *., /., `div_`, `mod_`
infix 4 ==., /=., <., <=., >., >=.
infixl 3 &&.
infixl 2 ||.
infixr 1 <==
-- | Identifiers used for variables, scopes, assertions and compiled programs.
type Name = String

-- | A variable holding values of type @a@.  A variable is identified by its
-- path: the names of its enclosing 'scope's followed by its own name.
data V a
  = V [Name] a   -- ^ Declared variable together with its initial value.
  | VIn [Name]   -- ^ Input variable created by 'input'; carries no initial value.
-- | Types whose values can appear in ImProve expressions and variables.
class AllE a where
  -- | Render a constant as C source text.
  showConst :: a -> String
  -- | Name of the C type used to store values of this type.
  showType :: a -> String
  -- | A default value of the type.  The argument (a declaration function
  -- such as 'bool' or 'int') only selects the instance; it is ignored.
  zero :: (Name -> a -> Stmt (V a)) -> a

-- Booleans are stored in C ints as 0 or 1.
instance AllE Bool where
  showConst True = "1"
  showConst False = "0"
  showType = const "int"
  zero _ = False

instance AllE Int where
  showConst n = show n
  showType = const "int"
  zero _ = 0

-- NOTE(review): Haskell's show on Float renders non-finite values as
-- "Infinity" / "NaN", which are not C literals.
instance AllE Float where
  showConst x = show x
  showType = const "float"
  zero _ = 0
-- | Numeric expression types: those usable with the arithmetic constructors
-- ('Add', 'Sub', 'Mul', 'Div') and the ordering comparisons ('Lt', 'Gt',
-- 'Le', 'Ge').
class AllE a => NumE a
instance NumE Int
instance NumE Float
-- | Expressions producing values of type @a@.  Scaling, division and modulus
-- take a constant (not an expression) as their right operand, because general
-- multiplication and division of two expressions is not supported.
data E a where
  -- Leaves: variable references and constants.
  Ref   :: AllE a => V a -> E a
  Const :: AllE a => a -> E a
  -- Arithmetic.
  Add, Sub :: NumE a => E a -> E a -> E a
  Mul, Div :: NumE a => E a -> a -> E a
  Mod      :: E Int -> Int -> E Int
  -- Boolean logic.
  Not     :: E Bool -> E Bool
  And, Or :: E Bool -> E Bool -> E Bool
  -- Comparisons.
  Eq             :: AllE a => E a -> E a -> E Bool
  Lt, Gt, Le, Ge :: NumE a => E a -> E a -> E Bool
  -- Conditional selection (rendered as C's ?: operator).
  Mux :: AllE a => E Bool -> E a -> E a -> E a
-- | Renders an expression as the C code generated for it, which is far more
-- useful when debugging than the previous 'undefined'.
-- (These instances presumably exist to satisfy the Show/Eq superclasses that
-- 'Num' had in older versions of base.)
instance Show (E a) where
  show = codeExpr

-- | Structural equality of expressions is not supported.  Use '(==.)' to
-- build an equality test in the generated program instead.
instance Eq (E a) where
  _ == _ = error "(==) is not supported on E expressions, use (==.)"
-- | Numeric literals and arithmetic on expressions.
--
-- Addition and subtraction build expression nodes.  General multiplication
-- of two expressions is not supported; use '(*.)' to scale by a constant.
-- 'abs' and 'signum' are built with 'mux', so they are evaluated by the
-- generated C code.
instance (Num a, AllE a, NumE a) => Num (E a) where
  (+) = Add
  (-) = Sub
  (*) = error "general multiplication not supported, use (*.)"
  negate a = 0 - a
  abs a = mux (a <. 0) (negate a) a
  signum a = mux (a ==. 0) 0 $ mux (a <. 0) (-1) 1
  fromInteger = Const . fromInteger
-- | Floating-point literals in expressions.  As with '(*)', general division
-- of two expressions is not supported; use '(/.)' to divide by a constant.
instance Fractional (E Float) where
  (/) = error "general division not supported, use (/.)"
  -- A reciprocal needs general division, so it reports the same error.
  recip a = 1 / a
  -- Convert the exact rational straight to Float so it is rounded only once.
  -- Converting the numerator and denominator separately double-rounds.  It
  -- also turns small literals such as 1e-40 into 0, because the
  -- denominator 10^40 overflows Float to Infinity.
  fromRational = Const . fromRational
-- | Embed a constant value in an expression.
constant :: AllE a => a -> E a
constant value = Const value

-- | The boolean constant true (rendered as @1@ in C).
true :: E Bool
true = constant True

-- | The boolean constant false (rendered as @0@ in C).
false :: E Bool
false = constant False
-- | Logical negation.
not_ :: E Bool -> E Bool
not_ e = Not e

-- | Logical conjunction.
(&&.) :: E Bool -> E Bool -> E Bool
a &&. b = And a b

-- | Logical disjunction.
(||.) :: E Bool -> E Bool -> E Bool
a ||. b = Or a b

-- | Conjunction of a list, folded from the left starting at 'true'
-- (so @and_ []@ is 'true').
and_ :: [E Bool] -> E Bool
and_ = foldl' (&&.) true

-- | Disjunction of a list, folded from the left starting at 'false'
-- (so @or_ []@ is 'false').
or_ :: [E Bool] -> E Bool
or_ = foldl' (||.) false

-- | @all_ p xs@ holds when every element of @xs@ satisfies @p@.
all_ :: (a -> E Bool) -> [a] -> E Bool
all_ p = and_ . map p

-- | @any_ p xs@ holds when some element of @xs@ satisfies @p@.
any_ :: (a -> E Bool) -> [a] -> E Bool
any_ p = or_ . map p

-- | Logical implication: @imply a b@ is @not_ a ||. b@.
imply :: E Bool -> E Bool -> E Bool
imply a b = Or (Not a) b
-- | Equality of two expressions of the same type.
(==.) :: AllE a => E a -> E a -> E Bool
a ==. b = Eq a b

-- | Inequality, expressed as the negation of an equality test.
(/=.) :: AllE a => E a -> E a -> E Bool
a /=. b = Not (Eq a b)

-- | Strictly less than.
(<.) :: NumE a => E a -> E a -> E Bool
a <. b = Lt a b

-- | Strictly greater than.
(>.) :: NumE a => E a -> E a -> E Bool
a >. b = Gt a b

-- | Less than or equal.
(<=.) :: NumE a => E a -> E a -> E Bool
a <=. b = Le a b

-- | Greater than or equal.
(>=.) :: NumE a => E a -> E a -> E Bool
a >=. b = Ge a b
-- | Minimum of two expressions (selects @a@ when @a <= b@).
min_ :: NumE a => E a -> E a -> E a
min_ a b = mux (a <=. b) a b

-- | Minimum of a non-empty list, folded from the left.
-- Calling it on an empty list is an error.
minimum_ :: NumE a => [E a] -> E a
minimum_ [] = error "minimum_: empty list"
minimum_ xs = foldl1 min_ xs

-- | Maximum of two expressions (selects @a@ when @a >= b@).
max_ :: NumE a => E a -> E a -> E a
max_ a b = mux (a >=. b) a b

-- | Maximum of a non-empty list, folded from the left.
-- Calling it on an empty list is an error.
maximum_ :: NumE a => [E a] -> E a
maximum_ [] = error "maximum_: empty list"
maximum_ xs = foldl1 max_ xs

-- | @limit a b i@ clamps @i@ to the range spanned by @a@ and @b@.  The bounds
-- may be given in either order.
limit :: NumE a => E a -> E a -> E a -> E a
limit a b i = max_ lo $ min_ hi i
  where
    -- Named lo/hi rather than min/max to avoid shadowing the Prelude.
    lo = min_ a b
    hi = max_ a b
-- | Multiply an expression by a constant.
(*.) :: NumE a => E a -> a -> E a
e *. k = Mul e k

-- | Divide a float expression by a constant.  A zero divisor is reported
-- with 'error' rather than being emitted into the generated code.
(/.) :: E Float -> Float -> E Float
e /. k
  | k == 0 = error "divide by zero (/.)"
  | otherwise = Div e k

-- | Divide an integer expression by a nonzero constant (rendered with C's @/@).
div_ :: E Int -> Int -> E Int
div_ e k
  | k == 0 = error "divide by zero (div_)"
  | otherwise = Div e k

-- | Remainder of an integer expression by a nonzero constant (rendered with
-- C's @%@).
mod_ :: E Int -> Int -> E Int
mod_ e k
  | k == 0 = error "divide by zero (mod_)"
  | otherwise = Mod e k
-- | @linear (x1, y1) (x2, y2) a@ maps @a@ through the straight line that
-- passes through the two given points.
--
-- The slope @(y2 - y1) / (x2 - x1)@ and the intercept @y1 - slope * x1@ are
-- computed in Haskell, so the generated code contains only one scaling and
-- one addition.  The two points must have distinct x coordinates.  Otherwise
-- the slope would be infinite or NaN, and it would be emitted as an invalid
-- C literal.
linear :: (Float, Float) -> (Float, Float) -> E Float -> E Float
linear (x1, y1) (x2, y2) a
  | x1 == x2 = error "linear: the two points must have different x coordinates"
  | otherwise = a *. slope + constant inter
  where
    slope = (y2 - y1) / (x2 - x1)
    inter = y1 - slope * x1
-- | Read the current value of a variable.
ref :: AllE a => V a -> E a
ref v = Ref v

-- | @mux c t f@ selects @t@ when @c@ holds and @f@ otherwise
-- (rendered as C's @c ? t : f@).
mux :: AllE a => E Bool -> E a -> E a -> E a
mux c t f = Mux c t f
-- | Run a program fragment inside a named scope.
--
-- The scope's name is appended to the path of everything declared within.
-- The declarations made inside are collected into a nested 'Scope', which
-- is rendered as a nested struct in the generated C.
scope :: Name -> Stmt a -> Stmt a
scope name (Stmt inner) = Stmt $ \ (path, items, stmt) ->
  let (result, (_, nested, stmt')) = inner (path ++ [name], [], stmt)
  in  (result, (path, Scope name nested : items, stmt'))

-- Read the whole builder state: the current scope path, the declarations
-- made so far in the current scope, and the statement accumulated so far.
get :: Stmt ([Name], [Scope], Statement)
get = Stmt (\ s -> (s, s))

-- The current scope path.
getPath :: Stmt [Name]
getPath = get >>= \ (path, _, _) -> return path

-- Replace the whole builder state.
put :: ([Name], [Scope], Statement) -> Stmt ()
put s = Stmt (const ((), s))
-- Declare a variable with an initial value in the current scope and return
-- a handle to it (identified by the current path plus its name).
var :: AllE a => Name -> a -> Stmt (V a)
var name value = do
  (path, items, stmt) <- get
  let decl = Variable name (showType value) (showConst value)
  put (path, decl : items, stmt)
  return (V (path ++ [name]) value)

-- | Declare an input variable and return an expression that reads it.
--
-- The first argument ('bool', 'int' or 'float') only selects the type.
-- The variable is declared in the current scope with that type's 'zero'
-- value.  No 'V' handle is returned, so the program itself cannot assign
-- the variable.
input :: AllE a => (Name -> a -> Stmt (V a)) -> Name -> Stmt (E a)
input kind name = do
  (path, items, stmt) <- get
  let def = zero kind
  put (path, Variable name (showType def) (showConst def) : items, stmt)
  return (ref (VIn (path ++ [name])))
-- | Declare a boolean variable with an initial value.
bool :: Name -> Bool -> Stmt (V Bool)
bool = var

-- | Declare a boolean variable, assign it @value@ at this point in the
-- program, and return an expression reading it.
bool' :: Name -> E Bool -> Stmt (E Bool)
bool' = declareAssigned bool False

-- | Declare an integer variable with an initial value.
int :: Name -> Int -> Stmt (V Int)
int = var

-- | Declare an integer variable, assign it @value@ at this point in the
-- program, and return an expression reading it.
int' :: Name -> E Int -> Stmt (E Int)
int' = declareAssigned int 0

-- | Declare a float variable with an initial value.
float :: Name -> Float -> Stmt (V Float)
float = var

-- | Declare a float variable, assign it @value@ at this point in the
-- program, and return an expression reading it.
float' :: Name -> E Float -> Stmt (E Float)
float' = declareAssigned float 0

-- Shared implementation of bool', int' and float'.  It declares a variable
-- with the given initial value, assigns it the expression, and returns a
-- reference to the variable.
declareAssigned
  :: (AllE a, Assign a)
  => (Name -> a -> Stmt (V a)) -> a -> Name -> E a -> Stmt (E a)
declareAssigned declare initial name value = do
  v <- declare name initial
  v <== value
  return (ref v)

-- | Increment an integer variable by one.
incr :: V Int -> Stmt ()
incr a = a <== ref a + 1

-- | Decrement an integer variable by one.
decr :: V Int -> Stmt ()
decr a = a <== ref a - 1
-- | The statement tree accumulated while a 'Stmt' program is run.
data Statement
  = AssignBool (V Bool ) (E Bool )      -- ^ @v = e;@ for a boolean variable.
  | AssignInt (V Int ) (E Int )         -- ^ @v = e;@ for an integer variable.
  | AssignFloat (V Float) (E Float)     -- ^ @v = e;@ for a float variable.
  | Branch (E Bool) Statement Statement -- ^ @if (c) { ... } else { ... }@.
  | Sequence Statement Statement        -- ^ First statement, then the second.
  | Assert [Name] (E Bool)              -- ^ Named assertion, rendered as a C @assert@.
  | Assume [Name] (E Bool)              -- ^ Named assumption, also rendered as a C @assert@.
  | Null                                -- ^ The empty statement.
-- | Programs under construction.
--
-- A 'Stmt' is a state transformer over the builder state.  The state holds
-- the current scope path, the declarations made so far in the current
-- scope, and the statement accumulated so far.
data Stmt a = Stmt (([Name], [Scope], Statement) -> (a, ([Name], [Scope], Statement)))

-- Functor and Applicative are superclasses of Monad since GHC 7.10
-- (the Applicative-Monad proposal); without them the Monad instance is
-- rejected.
instance Functor Stmt where
  fmap f (Stmt run) = Stmt $ \ s -> let (a, s') = run s in (f a, s')

instance Applicative Stmt where
  pure a = Stmt $ \ s -> (a, s)
  (<*>) = ap

-- Bind threads the state through both computations.  The intermediate
-- results are bound with lazy where-patterns.
instance Monad Stmt where
  return = pure
  Stmt f1 >>= f2 = Stmt f3
    where
      f3 s1 = f4 s2
        where
          (a, s2) = f1 s1
          Stmt f4 = f2 a
-- Append a statement after the one accumulated so far.
statement :: Statement -> Stmt ()
statement s = Stmt $ \ (path, items, acc) -> ((), (path, items, Sequence acc s))

-- Run a program from the given scope path and declarations, starting with
-- an empty statement.  Returns the final builder state; the program's own
-- result is discarded.
evalStmt :: [Name] -> [Scope] -> Stmt () -> ([Name], [Scope], Statement)
evalStmt path items (Stmt run) = case run (path, items, Null) of
  (_, final) -> final
-- | Types whose variables can be assigned.
class Assign a where
  -- | Assign an expression to a variable (rendered as @v = e;@).
  (<==) :: V a -> E a -> Stmt ()

instance Assign Bool where
  v <== e = statement (AssignBool v e)

instance Assign Int where
  v <== e = statement (AssignInt v e)

instance Assign Float where
  v <== e = statement (AssignFloat v e)
-- | Named assertion checked at this point in the program.
--
-- The name is qualified with the current scope path.  It is rendered as a
-- @// assert@ comment followed by a C @assert@.
assert :: Name -> E Bool -> Stmt ()
assert name cond = getPath >>= \ path -> statement (Assert (path ++ [name]) cond)

-- | Named assumption at this point in the program.
--
-- The name is qualified with the current scope path.  It is rendered as a
-- @// assume@ comment followed by a C @assert@.
assume :: Name -> E Bool -> Stmt ()
assume name cond = getPath >>= \ path -> statement (Assume (path ++ [name]) cond)
-- | @ifelse c t f@ runs @t@ when @c@ holds and @f@ otherwise.
-- Variables declared in either branch are added to the enclosing scope.
ifelse :: E Bool -> Stmt () -> Stmt () -> Stmt ()
ifelse cond onTrue onFalse = do
  (path, items0, acc) <- get
  -- Each branch is built from an empty statement.  Declarations made by
  -- the true branch are threaded into the false branch.
  let (_, items1, thenStmt) = evalStmt path items0 onTrue
      (_, items2, elseStmt) = evalStmt path items1 onFalse
  put (path, items2, acc)
  statement (Branch cond thenStmt elseStmt)

-- | @if_ c t@ runs @t@ only when @c@ holds.
if_ :: E Bool -> Stmt () -> Stmt ()
if_ cond body = ifelse cond body (return ())
-- | Generate C code for a program.
--
-- Writes two files:
--
-- * @name.c@ defines every variable, with its initial value, in a single
--   struct named @nameVariables@, followed by the function @void name()@
--   holding the program's statements.
-- * @name.h@ holds the matching @extern@ declaration and the prototype.
compile :: Name -> Stmt () -> IO ()
compile name program = do
    writeFile (name ++ ".c") cSource
    writeFile (name ++ ".h") cHeader
  where
    banner = "// Generated by ImProve.\n\n"
    root = name ++ "Variables"
    (_, items, body) = evalStmt [root] [] program
    globals = Scope root items
    cSource = concat
      [ banner
      , "#include <assert.h>\n\n"
      , codeVariables True globals, "\n"
      , "void ", name, "() {\n"
      , indent (codeStmt body)
      , "}\n\n"
      ]
    cHeader = concat
      [ banner
      , codeVariables False globals, "\n"
      , "void ", name, "(void);\n\n"
      ]
-- C name of a variable: its path joined with dots, so it names a
-- (possibly nested) member of the generated variables struct.
varName :: V a -> String
varName v = intercalate "." (pathOf v)
  where
    pathOf (V names _) = names
    pathOf (VIn names) = names
-- Render a statement as C code.  Nested blocks are indented with 'indent'.
codeStmt :: Statement -> String
codeStmt s = case s of
  AssignBool v e -> assign (varName v) (codeExpr e)
  AssignInt v e -> assign (varName v) (codeExpr e)
  AssignFloat v e -> assign (varName v) (codeExpr e)
  Branch c t Null -> ifBlock c t
  Branch c t f -> ifBlock c t ++ "else {\n" ++ indent (codeStmt f) ++ "}\n"
  Sequence first second -> codeStmt first ++ codeStmt second
  Assert names c -> check "// assert " names c
  Assume names c -> check "// assume " names c
  Null -> ""
  where
    -- "lhs = rhs;" on its own line.
    assign lhs rhs = lhs ++ " = " ++ rhs ++ ";\n"
    -- An if-block without an else part.
    ifBlock c t = "if (" ++ codeExpr c ++ ") {\n" ++ indent (codeStmt t) ++ "}\n"
    -- A comment naming the check, followed by a C assert of the condition.
    check comment names c =
      comment ++ intercalate "." names ++ "\nassert(" ++ codeExpr c ++ ");\n"
-- Render an expression as a fully parenthesized C expression.
codeExpr :: E a -> String
codeExpr e = case e of
  Ref v -> varName v
  Const c -> showConst c
  Add x y -> binary x "+" y
  Sub x y -> binary x "-" y
  Mul x k -> group [codeExpr x, "*", showConst k]
  Div x k -> group [codeExpr x, "/", showConst k]
  Mod x k -> group [codeExpr x, "%", showConst k]
  Not x -> group ["!", codeExpr x]
  And x y -> binary x "&&" y
  Or x y -> binary x "||" y
  Eq x y -> binary x "==" y
  Lt x y -> binary x "<" y
  Gt x y -> binary x ">" y
  Le x y -> binary x "<=" y
  Ge x y -> binary x ">=" y
  Mux c t f -> group [codeExpr c, "?", codeExpr t, ":", codeExpr f]
  where
    -- Space-separated tokens wrapped in a single pair of parentheses.
    group :: [String] -> String
    group tokens = "(" ++ unwords tokens ++ ")"
    -- Two same-typed sub-expressions joined by an infix C operator.
    binary :: E b -> String -> E b -> String
    binary x op y = group [codeExpr x, op, codeExpr y]
-- Prefix every line of a block with one space.  Every output line, including
-- the last, is terminated by a newline; an empty block stays empty.
indent :: String -> String
indent block = concat [" " ++ line ++ "\n" | line <- lines block]
-- | Declarations collected while building a program.  Each one is either a
-- named nested scope (rendered as a nested C struct) or a variable with its
-- C type and its initial-value literal.
data Scope
  = Scope Name [Scope]
  | Variable Name String String
  deriving Eq

-- | Orders variables before nested scopes, and each kind by name; this is
-- the member order used in the generated structs.  Entries with the same
-- name are then ordered by their remaining fields.  The previous instance
-- returned EQ for such entries, although the derived 'Eq' considered them
-- unequal, which violated the Eq/Ord consistency law.
instance Ord Scope where
  compare a b = case (a, b) of
    (Scope n1 xs, Scope n2 ys) -> compare (n1, xs) (n2, ys)
    (Variable n1 t1 i1, Variable n2 t2 i2) -> compare (n1, t1, i1) (n2, t2, i2)
    (Variable {}, Scope {}) -> LT
    (Scope {}, Variable {}) -> GT
-- Render a scope tree as a C struct.
--
-- With @define@ set, the output is the struct definition with an
-- initializer holding every variable's initial value.  Otherwise it is an
-- @extern@ declaration.  Members are emitted in 'sort' order.
-- NOTE(review): a scope with no members renders as an empty struct and an
-- empty initializer, neither of which ISO C (before C23) accepts.  Confirm
-- whether programs can hit this.
codeVariables :: Bool -> Scope -> String
codeVariables define top
  | define = declaration ++ " =\n" ++ initializer top ++ ";\n"
  | otherwise = "extern " ++ declaration ++ ";\n"
  where
    -- The top-level struct declaration with its trailing ";\n" removed.
    declaration = init (init (decl top))
    -- Member declarations: nested anonymous structs and typed variables.
    decl s = case s of
      Scope name members -> "struct { // " ++ name ++ "\n" ++ indent (concatMap decl (sort members)) ++ "} " ++ name ++ ";\n"
      Variable name typ _ -> typ ++ " " ++ name ++ ";\n"
    -- Brace-enclosed initializer, commented with member names.
    initializer s = case s of
      Scope name members -> "{ // " ++ name ++ "\n" ++ indent (intercalate ",\n" (map initializer (sort members))) ++ "\n}"
      Variable name _ value -> "/* " ++ name ++ " */ " ++ value