module Simplification (simplifyProg) where
import Prelude hiding ( or,fail,catch )
import MetaProgramming.FlatCurry
import MetaProgramming.FlatCurryGoodies hiding ( freeVars )
import qualified MetaProgramming.FlatCurryGoodies as FCG
import List ( sortBy, groupBy, partition )
-- Binary integers, mirroring the constructors Neg/Zero/Pos that 'int_'
-- emits as FlatCurry constructor calls.
data Int'
  = Neg Nat   -- negative number; the magnitude is a 'Nat'
  | Zero
  | Pos Nat   -- positive number; the magnitude is a 'Nat'

-- Positive binary naturals as built by 'intToNat': IHi is 1,
-- O n is 2*n and I n is 2*n+1 (least significant bit outermost).
data Nat
  = IHi
  | O Nat
  | I Nat
-- Simplify all expressions of a FlatCurry program, using only the
-- program's own functions as candidates for inlining.
simplifyProg :: Prog -> Prog
simplifyProg prog = simplified [] prog
-- Simplify all expressions of a program.  'preludeFuncs' supplies extra
-- function declarations whose rules may be inlined in addition to the
-- program's own functions.  Every expression is rewritten with 'opt'
-- (via 'evalFamilySimp') until no rewrite step applies any more.
simplified :: [FuncDecl] -> Prog -> Prog
simplified preludeFuncs prog =
  updProgExps (runSimp next rs . evalFamilySimp tExpr opt) prog
  where
    -- the rewrite steps are tried in this order; the first one that
    -- succeeds is used ('or' falls back to the next step on failure)
    opt = elimSimpleLet `or`
          elimIntLit `or`
          elimFailBranch `or`
          elimCase `or`
          propagate
    -- first fresh variable index: larger than every variable index
    -- occurring in the program
    next = 1 + maxlist (0:allVarsInProg prog)
    -- rules of all inlinable functions, keyed by function name
    rs = map rule (filter isInlined (preludeFuncs ++ progFuncs prog))
    rule func = (funcName func, funcRule func)
    -- inline only if_then_else, flat constants, and functions whose body
    -- is a single variable
    isInlined func =
      not (isExternal func) &&
      (funcName func == (preludeName,if_then_elseName) ||
       isConstant (funcBody func) ||
       isVar (funcBody func))
-- A flat constant: a literal, or a constructor applied to no arguments.
isConstant :: Expr -> Bool
isConstant exp
  | isLit exp = True
  | otherwise = isConsCall exp && null (combArgs exp)
-- Elimination of let bindings.  A binding is inlined into the rest of the
-- let when its right-hand side is a variable, or when the bound variable
-- does not occur in any right-hand side of the let and occurs at most once
-- in the let expression (as counted by 'freeVars').  Fails when there is
-- nothing to eliminate.
elimSimpleLet :: Expr -> Simp Expr
elimSimpleLet exp
  -- rewrite when some binding is inlinable, or when the let has no
  -- bindings at all (null keptBs with null simpBs means bs is empty);
  -- 'let_' drops the let when no bindings remain
  | isLet exp && (null keptBs || not (null simpBs))
  = ret (let_ keptBs (replace simpBs e))
  | otherwise = fail
  where
    Let bs e = exp
    (simpBs,keptBs') = partition isSimpleBind bs
    -- the remaining right-hand sides may mention inlined variables too
    keptBs = map (\ (v,e) -> (v,replace simpBs e)) keptBs'
    freeVarsInBinds = concatMap (freeVars . snd) bs
    -- note: '&&' binds tighter than '||'
    isSimpleBind (x,e) =
      isVar e || not (x `elem` freeVarsInBinds) && x `isUniqueIn` exp
-- True if 'freeVars' lists the variable at most once for the expression.
isUniqueIn :: VarIndex -> Expr -> Bool
x `isUniqueIn` exp =
  case filter (== x) (freeVars exp) of
    []  -> True
    [_] -> True
    _   -> False
-- Elimination of integer literals and integer patterns.  An integer
-- literal becomes the constructor term built by 'int_'; a case with
-- integer patterns is compiled into nested constructor cases by
-- 'flatCase'.  Fails for every other expression.
elimIntLit :: Expr -> Simp Expr
elimIntLit exp =
  case exp of
    Lit lit
      | isIntLit lit -> ret (intLitToCons lit)
    Case ct scrutinee branches
      | any (isIntPattern . branchPattern) branches ->
          flatCase ct [scrutinee] (map nestedBranch branches) fail
    _ -> fail
-- Is the literal an integer literal?
isIntLit :: Literal -> Bool
isIntLit (Intc _) = True
isIntLit _        = False
-- Convert an integer literal into its constructor representation.
-- Only defined for 'Intc'; the visible caller checks 'isIntLit' first.
intLitToCons :: Literal -> Expr
intLitToCons lit =
  case lit of
    Intc n -> int_ (intToInt' n)
-- Is the pattern a literal pattern for an integer?
isIntPattern :: Pattern -> Bool
isIntPattern (LPattern lit) = isIntLit lit
isIntPattern _              = False
-- Turn a case branch into the (pattern terms, right-hand side) form used
-- by 'flatCase'; an integer literal pattern becomes a constructor term.
nestedBranch :: BranchExpr -> ([Expr],Expr)
nestedBranch (Branch pat rhs) = ([toTerm (patExpr pat)], rhs)
  where
    toTerm (Lit (Intc n)) = int_ (intToInt' n)
    toTerm term           = term
-- flattens a case expression.
-- the branches are given as pairs of possibly nested constructor terms
-- and arbitrary right hand sides.
-- multiple arguments of patterns are matched from left to right!
--
-- Arguments: the case type, one scrutinee expression per pattern column,
-- the branches, and the computation 'err' used when no branch is left.
-- NOTE(review): a column whose terms are neither variables nor constructor
-- calls (e.g. non-integer literals) makes the 'otherwise' alternative call
-- itself with the same arguments forever; presumably this cannot happen
-- because 'elimIntLit' only calls it for integer patterns.
flatCase :: CaseType -> [Expr] -> [([Expr],Expr)] -> Simp Expr -> Simp Expr
flatCase _ [] [] err = err
-- all columns consumed: remaining branches are combined with Prelude '?'
flatCase _ [] bs@(_:_) _ = ret (foldr1 (?~) (map snd bs))
flatCase ct (e:es) bs err
  -- first column holds only variables: bind each one to the scrutinee
  | all isVar pats
  = flatCase ct es (map replaceVar bs) err
  -- first column holds only constructor terms: one case branch per
  -- constructor name
  | not (null bs) && all isConsCall pats
  = liftSimp (Case ct e) (mapSimp branch groupedBs)
  -- mixed column: split into runs of the same kind; each run is compiled
  -- with the compilation of the following runs as its 'err'
  | otherwise
  = foldr (flatCase ct (e:es)) err (groupBy (lift2 sameKind (head . fst)) bs)
  where
    pats = map (head . fst) bs
    groupedBs = reorderBy (lift2 cmpQName (combName . head . fst)) bs
    sameKind p1 p2 = all isVar [p1,p2] || all isConsCall [p1,p2]
    replaceVar (Var x:ps,rhs) = (ps,Let [(x,e)] rhs)
    -- fresh variables for the constructor arguments become new columns
    branch gbs@((Comb _ name args : _, _) : _) =
      nextVars (length args) .>>= \xs ->
      liftSimp (Branch (Pattern name xs))
               (flatCase ct (map Var xs ++ es) (map extend gbs) err)
    extend (Comb _ _ args : ps, rhs) = (args ++ ps, rhs)
-- Elimination of failing branches: every case branch whose right-hand side
-- is a call of Prelude.failed is dropped, and a case left without branches
-- becomes 'failed_' (see 'replaceBranches').  Fails when nothing changes.
elimFailBranch :: Expr -> Simp Expr
elimFailBranch exp
  | not (isCase exp) = fail
  | null branches || any isFailBranch branches
  = ret (replaceBranches exp (filter (not . isFailBranch) branches))
  | otherwise = fail
  where
    branches = caseBranches exp
-- Is the branch's right-hand side a call of Prelude.failed?
isFailBranch :: BranchExpr -> Bool
isFailBranch (Branch _ rhs) = isFailed rhs

-- Is the expression a full call of Prelude.failed?
isFailed :: Expr -> Bool
isFailed exp =
  case exp of
    Comb FuncCall name _ -> name == (preludeName,failedName)
    _                    -> False
-- Give a case expression new branches; without any branch the whole case
-- is replaced by a call of Prelude.failed.
replaceBranches :: Expr -> [BranchExpr] -> Expr
replaceBranches (Case ct e _) newBranches =
  if null newBranches then failed_ else Case ct e newBranches
-- Elimination of a case whose scrutinee is a constructor call: the
-- matching branch is selected statically by 'match'.  Fails otherwise.
elimCase :: Expr -> Simp Expr
elimCase exp =
  case exp of
    Case _ scr branches
      | isConsCall scr -> match scr branches
    _ -> fail
-- Select the first branch whose pattern constructor equals the scrutinee's
-- constructor, and bind the scrutinee's arguments to the pattern variables
-- through fresh variables (the branch body is renamed accordingly).
-- Without a matching branch the result is a call of Prelude.failed.
match :: Expr -> [BranchExpr] -> Simp Expr
match (Comb _ name args) bs =
  case filter ((name==) . patCons . branchPattern) bs of
    [] -> ret failed_
    Branch pat rhs : _ ->
      let ys = patArgs pat
      in nextVars (length ys) .>>= \ zs ->
           ret (Let (zip zs args) (replace (zip ys (map Var zs)) rhs))
-- Inlining of full calls to functions whose rule is available through
-- 'fetchRule'; fails for every other expression.
propagate :: Expr -> Simp Expr
propagate exp =
  case exp of
    Comb FuncCall name _ -> liftSimp (inline exp) (fetchRule name)
    _                    -> fail
-- Replace a call by the rule body, binding the rule parameters to the
-- call arguments with a let.
inline :: Expr -> Rule -> Expr
inline (Comb _ _ callArgs) (Rule params body) = Let bindings body
  where
    bindings = zip params callArgs
-- traversables
-- Traversal of 'Int'': the magnitude of a non-zero number is its only child.
tInt :: Traversable Int' Nat
tInt i =
  case i of
    Zero  -> noChildren Zero
    Pos n -> ([n], \ [m] -> Pos m)
    Neg n -> ([n], \ [m] -> Neg m)
-- Traversal of 'Nat': the remaining bits are the only child of O and I.
tNat :: Traversable Nat Nat
tNat k =
  case k of
    IHi -> noChildren IHi
    O n -> ([n], \ [m] -> O m)
    I n -> ([n], \ [m] -> I m)
-- Traversal of FlatCurry expressions.  The children are: the arguments of
-- a call; the body followed by the binding right-hand sides of a let; the
-- body of a free declaration; both alternatives of an or; the scrutinee
-- followed by the branch right-hand sides of a case.  Bound variables and
-- patterns are not children and are kept when rebuilding.  The remaining
-- expressions (e.g. variables and literals) have no children.
tExpr :: Traversable Expr Expr
tExpr exp =
  case exp of
    Comb ct name args -> (args, Comb ct name)
    Let bs e -> let (xs,es) = unzip bs in (e:es, \ (e:es) -> Let (zip xs es) e)
    Free xs e -> ([e], \ [e] -> Free xs e)
    Or e1 e2 -> ([e1,e2], \ [e1,e2] -> Or e1 e2)
    Case ct e bs -> let (ps,es) = unzip (map branch bs)
                    in (e:es, \ (e:es) -> Case ct e (zipWith Branch ps es))
    _ -> noChildren exp
  where
    branch (Branch p e) = (p,e)
-- Traversal of a case branch: the right-hand side is the only child and
-- the pattern is kept when rebuilding.
tBranchExpr :: Traversable BranchExpr Expr
tBranchExpr (Branch pat rhs) = ([rhs], rebuild)
  where
    rebuild [rhs'] = Branch pat rhs'
-- Traversal of type expressions: the domain and range of a function type
-- and the arguments of a type constructor are children; other type
-- expressions have none.
tTypeExpr :: Traversable TypeExpr TypeExpr
tTypeExpr (FuncType dom ran) = ([dom, ran], \ [d, r] -> FuncType d r)
tTypeExpr (TCons name args)  = (args, TCons name)
tTypeExpr other              = noChildren other
-- comparison
-- A comparison function, of the kind passed to 'sortBy' in 'reorderBy'.
type Ord' a = a -> a -> Ordering
-- Sort a list with the given comparison, then group neighbouring elements
-- that the comparison considers equal.
reorderBy :: Ord' a -> [a] -> [[a]]
reorderBy cmp xs = groupBy sameClass (sortBy cmp xs)
  where
    sameClass x y = cmp x y == EQ
-- Compare qualified names by module name first, then by local name.
cmpQName :: Ord' QName
cmpQName q1 q2 = cmpPair cmpString cmpString q1 q2
-- Lexicographic comparison of pairs: the second components are only
-- compared when the first ones are equal.
cmpPair :: Ord' a -> Ord' b -> Ord' (a,b)
cmpPair cmpa cmpb (a1,b1) (a2,b2)
  | byFirst /= EQ = byFirst
  | otherwise     = cmpb b1 b2
  where
    byFirst = cmpa a1 a2
-- creating FlatCurry expressions
-- Build a let expression, leaving out the let when there are no bindings.
let_ :: [(VarIndex,Expr)] -> Expr -> Expr
let_ [] e = e
let_ bs e = Let bs e
-- Names of the Prelude module and of Prelude functions the simplifier
-- refers to.
preludeName, if_then_elseName, failedName :: String
preludeName = "Prelude"
if_then_elseName = "if_then_else"
failedName = "failed"
-- A call of Prelude.failed.
failed_ :: Expr
failed_ = Comb FuncCall (preludeName,failedName) []

-- Constructor terms of the binary integer representation ('Int'', 'Nat').
zero_, iHi_ :: Expr
zero_ = Comb ConsCall (preludeName, "Zero") []
iHi_ = Comb ConsCall (preludeName, "IHi") []

pos_, neg_, o_, i_ :: Expr -> Expr
pos_ n = Comb ConsCall (preludeName, "Pos") [n]
neg_ n = Comb ConsCall (preludeName, "Neg") [n]
o_ n = Comb ConsCall (preludeName, "O") [n]
i_ n = Comb ConsCall (preludeName, "I") [n]

-- A call of the Prelude choice operator '?'.
(?~) :: Expr -> Expr -> Expr
x ?~ y = Comb FuncCall (preludeName, "?") [x,y]
-- Convert an 'Int'' into the FlatCurry constructor term that represents
-- it, by folding over the number and the bits of its magnitude.
int_ :: Int' -> Expr
int_ = foldChildren tInt tNat fromInt fromNat
  where
    fromInt i ns =
      case (i, ns) of
        (Zero, _)    -> zero_
        (Pos _, [n]) -> pos_ n
        (Neg _, [n]) -> neg_ n
    fromNat k ms =
      case (k, ms) of
        (IHi, _)   -> iHi_
        (O _, [m]) -> o_ m
        (I _, [m]) -> i_ m
-- auxiliary functions
-- Lift a binary operation on 'a' to 'b' by projecting both arguments
-- with 'f' first.
lift2 :: (a -> a -> c) -> (b -> a) -> (b -> b -> c)
lift2 op f = \ x y -> f x `op` f y
-- Remove the given suffix from a string if it is present; otherwise the
-- string is returned unchanged.
stripSuffix :: String -> String -> String
stripSuffix suf str =
  if suf `isSuffixOf` str
    then take (length str - length suf) str
    else str
-- Prefix and suffix tests on lists (local definitions; only sortBy,
-- groupBy and partition are imported from List).
isSuffixOf, isPrefixOf :: Eq a => [a] -> [a] -> Bool
-- 'suf' is a suffix of 'l' iff its reverse is a prefix of the reverse of 'l'
suf `isSuffixOf` l = reverse suf `isPrefixOf` reverse l
[] `isPrefixOf` _ = True
-- a non-empty list is never a prefix of the empty list; without this
-- equation a longer list tested against a shorter one hit a
-- pattern-match failure
(_:_) `isPrefixOf` [] = False
(x:xs) `isPrefixOf` (y:ys) = x==y && xs `isPrefixOf` ys
-- Variables of an expression, computed by 'outOfScopeVars' with an empty
-- scope.  NOTE(review): 'outOfScopeVars' only ever removes variables from
-- its scope, so starting from the empty scope this lists every variable
-- occurrence (bound ones included, with repetitions); 'isUniqueIn' counts
-- occurrences in this list.
freeVars :: Expr -> [VarIndex]
freeVars exp = outOfScopeVars [] exp
-- Variables of an expression that are not in 'scope', listed once per
-- occurrence.  Implemented as a fold over 'tExpr' in which each folded
-- child is a function from the current scope to a variable list.
--
-- NOTE(review): at each binder (let, free, constructor pattern) the bound
-- variables are *removed* from the scope handed to the children; they are
-- never added.  Bound variables are therefore reported as well, and with
-- an empty initial scope (as in 'freeVars') every occurrence is returned.
-- Callers such as 'isUniqueIn' appear to depend on this; confirm before
-- changing it.
outOfScopeVars :: [VarIndex] -> Expr -> [VarIndex]
outOfScopeVars scope exp = fold tExpr vars exp scope
  where
    vars exp cs scope =
      case (exp,cs) of
        (Var n,_) -> if n `elem` scope then [] else [n]
        (Let bs _,_) ->
          concatMap ( $ filter (not . (`elem` map fst bs)) scope) cs
        (Free vs _,[e]) -> e (filter (not . (`elem` vs)) scope)
        -- the first child of a case is the scrutinee, the others are the
        -- branch right-hand sides (see 'tExpr')
        (Case _ _ bs,e:es) ->
          e scope ++ concat (zipWith (scopeBranch scope) bs es)
        _ -> concatMap ( $ scope) cs
    -- only constructor patterns have argument variables
    scopeBranch scope (Branch pat _) e
      | isConsPattern pat = e (filter (not . (`elem` patArgs pat)) scope)
      | otherwise = e scope
-- replace free variables in expression according to environment
-- A substitution: variable indices paired with the expressions that
-- 'replace' puts in their place.
type Env = [(VarIndex,Expr)]
-- Apply the substitution 'env' to an expression.  Variables bound inside
-- the expression (by a let, a free declaration, or a branch pattern; see
-- 'replaceBranch') are removed from 'env' before descending.  Substituted
-- expressions are rewritten with the same environment (see 'fromEnv').
replace :: Env -> Expr -> Expr
replace env exp
  | isVar exp = fromEnv [] (varNr exp) env
  | isLet exp = mapChildren tExpr (replace (removeLetBinds exp env)) exp
  -- presumably FCG.freeVars yields the variables declared by this Free
  -- node (FlatCurryGoodies) -- confirm
  | isFree exp = mapChildren tExpr (replace (remove (FCG.freeVars exp) env)) exp
  | isCase exp = let Case ct e bs = exp
                 in Case ct (replace env e) (map (replaceBranch env) bs)
  | otherwise = mapChildren tExpr (replace env) exp
-- Look up variable 'i' in the substitution 'env'.
--  * Unbound variables are kept.
--  * A variable mapped to another variable is followed along the chain;
--    'is' records the variables already reached, so a cyclic chain of
--    variable bindings yields a call of Prelude.failed instead of looping.
--  * Any other substituted expression is rewritten with 'replace'.
-- Uses the shared 'failed_' instead of a second hard-coded copy of it.
fromEnv :: [VarIndex] -> VarIndex -> Env -> Expr
fromEnv is i env =
  case lookup i env of
    Nothing -> Var i
    Just (Var j)
      | j `elem` is -> failed_
      | otherwise   -> fromEnv (j:is) j env
    Just e -> replace env e
-- Drop the entries for the given variables from a substitution.
remove :: [VarIndex] -> Env -> Env
remove xs = filter (\ (v,_) -> v `notElem` xs)
-- Drop the variables bound by a let expression from a substitution.
removeLetBinds :: Expr -> Env -> Env
removeLetBinds exp env = remove [ v | (v,_) <- letBinds exp ] env
-- Apply a substitution to the right-hand side of a case branch.  The
-- argument variables of a constructor pattern shadow entries of 'env'.
-- A literal pattern binds nothing, so 'env' is used unchanged for it.
-- 'patArgs' is presumably only defined for constructor patterns, which is
-- why 'outOfScopeVars' guards it with 'isConsPattern' in the same way;
-- applying it to a literal pattern would fail as soon as 'env' is consulted.
replaceBranch :: Env -> BranchExpr -> BranchExpr
replaceBranch env b = mapChildren tBranchExpr (replace branchEnv) b
  where
    pat = branchPattern b
    branchEnv
      | isConsPattern pat = remove (patArgs pat) env
      | otherwise         = env
-- Maximum of a non-empty list of integers.
maxlist :: [Int] -> Int
maxlist = foldr1 max
--- A datatype is Traversable if it defines a function
--- that can decompose a value into a list of children (of type b)
--- and recombine new children to a new value of the original type.
--- The traversals in this module rebuild only from a list of the same
--- length as the children they returned.
---
type Traversable a b = a -> ([b], [b] -> a)
--- Traversal for values without children: the child list is empty and
--- the value is rebuilt unchanged.
---
noChildren :: Traversable a b
noChildren x = ([], \ _ -> x)
--- The children of a value according to the traversal.
---
children :: Traversable a b -> a -> [b]
children tr x = fst (tr x)
--- Rebuilds a value from a new list of children.
---
replaceChildren :: Traversable a b -> a -> [b] -> a
replaceChildren tr x = snd (tr x)
--- Rebuilds a value after applying a function to each of its children.
---
mapChildren :: Traversable a b -> (b -> b) -> a -> a
mapChildren tr f x = rebuild (map f kids)
  where
    (kids, rebuild) = tr x
--- Lists a value followed by all its descendants, each value appearing
--- before its own descendants.
---
family :: Traversable a a -> a -> [a]
family tr root = familyFL tr root []
--- Lists the families of all children of a value (without the value
--- itself). The value and its children may have different types.
---
childFamilies :: Traversable a b -> Traversable b b -> a -> [b]
childFamilies tra trb root = childFamiliesFL tra trb root []
-- Implementation of 'family' with functional lists for efficiency: a
-- 'FunList' prepends its elements to the list it is applied to, so
-- appending becomes function composition.
type FunList a = [a] -> [a]

familyFL :: Traversable a a -> a -> FunList a
familyFL tr x rest = x : childFamiliesFL tr tr x rest

childFamiliesFL :: Traversable a b -> Traversable b b -> a -> FunList b
childFamiliesFL tra trb x =
  concatFL [ familyFL trb child | child <- children tra x ]
--- Concatenates functional lists by composing them from left to right.
---
concatFL :: [FunList a] -> FunList a
concatFL fls ys = foldr ($) ys fls
--- Applies a function to every member of the family of a value, children
--- before their parents (bottom-up).
---
mapFamily :: Traversable a a -> (a -> a) -> a -> a
mapFamily tr f x = f (mapChildFamilies tr tr f x)
--- Applies a function bottom-up to every member of the families of the
--- children of a value; the value itself is not passed to the function.
--- The value and its children may have different types.
---
mapChildFamilies :: Traversable a b -> Traversable b b -> (b -> b) -> a -> a
mapChildFamilies tra trb f = mapChildren tra (mapFamily trb f)
--- Applies the given function to each member of the family of a value
--- as long as possible. On each member of the family of the result the given
--- function will yield Nothing.
--- Proceeds bottom-up.
---
evalFamily :: Traversable a a -> (a -> Maybe a) -> a -> a
evalFamily tr f = mapFamily tr step
  where
    -- whenever 'f' rewrites a member, the result's family is normalised again
    step x =
      case f x of
        Nothing -> x
        Just x' -> mapFamily tr step x'
--- Like 'evalFamily', but only applied to the families of the children of
--- a value; the value itself is not rewritten.
---
evalChildFamilies :: Traversable a b -> Traversable b b
                  -> (b -> Maybe b) -> a -> a
evalChildFamilies tra trb f = mapChildren tra (evalFamily trb f)
--- Folds a value bottom-up: the function receives each value together with
--- the already folded results of its children.
---
fold :: Traversable a a -> (a -> [r] -> r) -> a -> r
fold tr f x = foldChildren tr tr f f x
--- Folds every child with 'g' and combines the results with the value
--- itself using 'f'. The value and its children may have different types.
---
foldChildren :: Traversable a b -> Traversable b b
             -> (a -> [rb] -> ra) -> (b -> [rb] -> rb) -> a -> ra
foldChildren tra trb f g a = f a [ fold trb g child | child <- children tra a ]
infixl 1 .>>=, .>>

-- Rules available for inlining, keyed by qualified function name.
type Rules = [(QName,Rule)]

-- The simplification monad: a computation takes the next fresh variable
-- index and the rules, and either fails (Nothing) or returns its result
-- together with the next fresh variable index after it ran.
type Simp a = VarIndex -> Rules -> Maybe (a,VarIndex)
-- Run a simplification with first fresh variable index 'n' and rules
-- 'rs'; it is an error if the computation fails.
runSimp :: Int -> Rules -> Simp a -> a
runSimp n rs o =
  case o n rs of
    Just result -> fst result
    Nothing     -> error "Simplification.runSimp: simplification fails"
-- A computation that returns a value without using fresh variables.
ret :: a -> Simp a
ret x = \ n _ -> Just (x,n)

-- Sequencing: runs 'oa', then 'f' on its result with the updated fresh
-- variable index; fails as soon as either part fails.
(.>>=) :: Simp a -> (a -> Simp b) -> Simp b
(oa .>>= f) n rs = maybe Nothing (\ (a,n') -> f a n' rs) (oa n rs)

-- Sequencing that ignores the first computation's result.
(.>>) :: Simp b -> Simp a -> Simp a
o .>> oa = o .>>= \ _ -> oa

-- Apply a pure function to a computation's result.
liftSimp :: (a -> b) -> Simp a -> Simp b
liftSimp f oa = oa .>>= \ a -> ret (f a)

-- The computation that always fails.
fail :: Simp a
fail = \ _ _ -> Nothing
-- Run the first computation; only if it fails, run the second one.
catch :: Simp a -> Simp a -> Simp a
catch o1 o2 n rs =
  case o1 n rs of
    Nothing -> o2 n rs
    result  -> result

-- Left-biased choice between two rewrite functions.
or :: (a -> Simp b) -> (a -> Simp b) -> a -> Simp b
(f `or` g) a = catch (f a) (g a)
-- Produce a fresh variable index and advance the counter.
nextVar :: Simp VarIndex
nextVar = \ fresh _ -> Just (fresh, fresh + 1)

-- Produce the given number of fresh variable indices.
nextVars :: Int -> Simp [VarIndex]
nextVars count = sequenceSimp [ nextVar | _ <- [1 .. count] ]
-- Look up the rule of a function and rename its parameters to fresh
-- variables (renaming them in the body as well).  Fails if no rule is
-- known or the function is external.
fetchRule :: QName -> Simp Rule
fetchRule name n rs =
  case lookup name rs of
    Just (Rule params body) ->
      let arity = length params
          fresh = take arity [n ..]
          renamed = replace (zip params (map Var fresh)) body
      in Just (Rule fresh renamed, n + arity)
    _ -> Nothing
-- Run computations from left to right and collect their results.
sequenceSimp :: [Simp a] -> Simp [a]
sequenceSimp = foldr step (ret [])
  where
    step ox oxs = ox .>>= \ x -> oxs .>>= \ xs -> ret (x:xs)

-- Apply a monadic function to every list element, left to right.
mapSimp :: (a -> Simp b) -> [a] -> Simp [b]
mapSimp f xs = sequenceSimp (map f xs)
-- Rebuild a value from children computed in the Simp monad.
replaceChildrenSimp :: Traversable a b -> a -> Simp [b] -> Simp a
replaceChildrenSimp tr a = liftSimp (replaceChildren tr a)

-- Rebuild a value after applying a monadic function to each child.
mapChildrenSimp :: Traversable a b -> (b -> Simp b) -> a -> Simp a
mapChildrenSimp tr f a =
  liftSimp (replaceChildren tr a) (mapSimp f (children tr a))

-- Monadic bottom-up map over the family of a value.
mapFamilySimp :: Traversable a a -> (a -> Simp a) -> a -> Simp a
mapFamilySimp tr f a = mapChildrenSimp tr (mapFamilySimp tr f) a .>>= f

-- Monadic bottom-up map over the families of the children of a value.
mapChildFamiliesSimp :: Traversable a b -> Traversable b b
                     -> (b -> Simp b) -> a -> Simp a
mapChildFamiliesSimp tra trb f = mapChildrenSimp tra (mapFamilySimp trb f)
-- Monadic version of 'evalFamily': members are rewritten bottom-up as long
-- as 'f' succeeds; a failure of 'f' leaves the member unchanged.
evalFamilySimp :: Traversable a a -> (a -> Simp a) -> a -> Simp a
evalFamilySimp tr f = mapFamilySimp tr step
  where
    step a = catch (f a .>>= mapFamilySimp tr step) (ret a)

-- 'evalFamilySimp' applied to the families of the children only.
evalChildFamiliesSimp :: Traversable a b -> Traversable b b
                      -> (b -> Simp b) -> a -> Simp a
evalChildFamiliesSimp tra trb f = mapChildrenSimp tra (evalFamilySimp trb f)
-- Standard lexicographic comparison of strings.
cmpString :: String -> String -> Ordering
cmpString s t = compare s t
-- Convert an integral number into its binary representation 'Int''.
intToInt' :: Prelude.Integral a => a -> Int'
intToInt' n
  | n Prelude.< 0 = Neg (intToNat (Prelude.abs n))
  | n Prelude.> 0 = Pos (intToNat (Prelude.abs n))
  | otherwise     = Zero
-- Convert a positive integral number into its binary representation
-- 'Nat', least significant bit outermost: IHi is 1, O k is 2*k and
-- I k is 2*k+1.  Non-positive arguments are rejected with an error; before,
-- they produced an infinite term (0 gave O (O ...), negatives I (I ...)).
intToNat :: Prelude.Integral a => a -> Nat
intToNat n
  | n Prelude.<= 0 = error "Simplification.intToNat: argument must be positive"
  | m Prelude.== 0 = IHi                      -- n == 1
  | Prelude.mod n 2 Prelude.== 0 = O (intToNat m)
  | otherwise = I (intToNat m)
  where
    m = Prelude.div n 2