module Data.Presburger.Omega.SetRel where
import Data.Presburger.Omega.LowLevel
import Data.Presburger.Omega.Expr
-- | Build a resolver from low-level Omega variable handles to
-- expression-level variables.
--
-- The i-th handle in the argument list is paired with the i-th variable
-- produced by 'takeFreeVariables'.  Resolving a handle that is not in the
-- list is a fatal 'error'.  The pairing table is built once per call of
-- 'makeLookupFunction' and shared by every use of the returned function.
makeLookupFunction :: [VarHandle] -> (VarHandle -> Var)
makeLookupFunction lowLevelVars = resolve
  where
    table = zip lowLevelVars (takeFreeVariables (length lowLevelVars))

    resolve handle =
      maybe (error "Cannot find Omega variable") id (lookup handle table)
-- | Translate one Omega constraint into a boolean expression.
--
-- The coefficient terms and the constant are combined with
-- 'sumOfProductsExpr'.  That sum is tested with 'IsZero' when
-- @isEquality@ is 'True' and with 'IsGEZ' otherwise.  Variable handles
-- in the terms are resolved against @boundVars@ via 'makeLookupFunction'.
constraintToExpr :: Bool           -- ^ 'True' for an equality constraint
                 -> [VarHandle]    -- ^ handles the terms may mention
                 -> [Coefficient]  -- ^ coefficient/variable terms
                 -> Int            -- ^ constant term
                 -> BoolExpr
constraintToExpr isEquality boundVars terms constant =
  testExpr comparison linearSum
  where
    comparison
      | isEquality = IsZero
      | otherwise  = IsGEZ

    resolve = makeLookupFunction boundVars

    -- Each term becomes a (coefficient, [variable]) product.
    linearSum =
      sumOfProductsExpr constant
        (map (\(Coefficient v n) -> (n, [resolve v])) terms)
-- | Read an Omega set back as an expression.
--
-- The set is queried in disjunctive normal form.  The result pairs the
-- number of set variables with the simplified disjunction of all
-- conjuncts.  Constraint variables are resolved in the order
-- existential variables first, then set variables.
setToExpression :: OmegaSet -> IO (Int, BoolExp)
setToExpression s =
  queryDNFSet addEq [] addGeq [] addConjunct [] s >>= finish
  where
    finish (setVars, conjuncts) =
      return (length setVars, wrapSimplifiedExpr (disjExpr conjuncts))

    -- Shared body of the equality / inequality callbacks: convert the
    -- constraint and push it onto the accumulator.
    addConstraint isEq setVars exVars terms constant =
      (constraintToExpr isEq (exVars ++ setVars) terms constant :)

    addEq  = addConstraint True
    addGeq = addConstraint False

    addConjunct _ = wrapExistentialVars
-- | Read an Omega relation back as an expression.
--
-- The relation is queried in disjunctive normal form.  The result holds
-- the number of input variables, the number of output variables, and the
-- simplified disjunction of all conjuncts.  Constraint variables are
-- resolved in the order existential, input, then output variables.
relToExpression :: OmegaRel -> IO (Int, Int, BoolExp)
relToExpression s =
  queryDNFRelation addEq [] addGeq [] addConjunct [] s >>= finish
  where
    finish (inVars, outVars, cs) =
      return (length inVars, length outVars, wrapSimplifiedExpr (disjExpr cs))

    -- Shared body of the equality / inequality callbacks: convert the
    -- constraint and push it onto the accumulator.
    addConstraint isEq inVars outVars exVars terms constant =
      (constraintToExpr isEq (exVars ++ inVars ++ outVars) terms constant :)

    addEq  = addConstraint True
    addGeq = addConstraint False

    addConjunct _ _ = wrapExistentialVars
-- | Placeholder that fails when forced.
--
-- NOTE(review): nothing in the visible code references this binding; it
-- looks like a leftover from an earlier version of 'relToExpression'.
-- Confirm there are no other users before removing it.
hasExistentialVars :: a
hasExistentialVars = error message
  where
    message = "relToExpression: cannot create expression"
-- | Prepend one DNF conjunct to an accumulated list of conjuncts.
--
-- The conjunct is 'conjExpr' over the inequality constraints followed by
-- the equality constraints.  It is wrapped in 'existsExpr' once per
-- element of @exVars@; only the length of @exVars@ is used.
wrapExistentialVars exVars eqs geqs = (:) quantified
  where
    body       = conjExpr (geqs ++ eqs)
    quantified = iterateN existsExpr (length exVars) body
-- | Apply a function a given number of times.
--
-- @iterateN f k x@ applies @f@ to @x@ exactly @k@ times, so
-- @iterateN f 0 x == x@.  A negative count is treated like zero rather
-- than recursing forever.
iterateN :: (Ord n, Num n) => (a -> a) -> n -> a -> a
iterateN f = go
  where
    go k acc
      | k <= 0    = acc
      -- The original decremented an unbound name (@n1@); the intent is k - 1.
      | otherwise = go (k - 1) (f acc)