module Language.Haskell.FreeTheorems.Variations.PolySeq.ConstraintSolver where
import Language.Haskell.FreeTheorems.Variations.PolySeq.PrettyPrint(prettyConstraint)
import Language.Haskell.FreeTheorems.Variations.PolySeq.Syntax
import Language.Haskell.FreeTheorems.Variations.PolySeq.AlgCommon
( collectOne
, substLabel
, getUsedLabels
, getUsedExtraLabels
, removeTrue
)
import qualified Data.List as List
import qualified Data.Map as Map
import Data.Generics(everywhere, mkQ, mkT)
import Language.Haskell.FreeTheorems.Variations.PolySeq.Debug
import Control.Monad(mplus)
import Data.Function
-- Debug tracing hooks used throughout this module. All three forward to
-- 'trace_ignore' from the Debug module; rebinding one of them to a
-- different tracer changes the verbosity of the call sites that use it.
trace2 msg = trace_ignore msg

trace msg = trace_ignore msg

trace1 msg = trace_ignore msg
-- | Simplify a term, its type and the accompanying constraint.
--
-- Runs 'simpConstraint' on the triple, then strips trivially true parts from
-- the resulting constraint with 'removeTrue' and finally replaces the
-- constraint by @Fls@ when 'checkContradiction' detects a contradiction.
-- The term and type are returned exactly as 'simpConstraint' produced them.
simplifyConstraint :: (Term,Typ,Constraint) -> (Term,Typ,Constraint)
simplifyConstraint input@(_, _, _) = (term, typ, finalConstraint)
  where
    (term, typ, simplified) = simpConstraint input
    finalConstraint = checkContradiction (removeTrue simplified)
-- | Repeatedly eliminate labels whose value is forced by the constraint,
-- applying each substitution to the term, the type and the constraint.
--
-- Per round (after 'removeUseless'):
--
--   1. If the constraint already contains a contradiction that
--      'checkContradiction' recognises, stop and return the triple with the
--      contradiction still in place, so the caller can report it.
--   2. An equality found by 'findEq' is used to replace one side by the
--      other. A concrete value (@LVal@) is never the side being replaced.
--   3. A constraint @Leq (LVal Epsilon) l@ forces @l@ to @Epsilon@.
--   4. A constraint @Leq l (LVal Nbr)@ forces @l@ to @Nbr@.
--
-- Steps 3 and 4 are skipped when @l@ is itself a concrete value: substituting
-- for a value would rewrite every occurrence of that value in the term and
-- type, and would turn a contradictory constraint into a trivially true one.
simpConstraint :: (Term,Typ,Constraint) -> (Term,Typ,Constraint)
simpConstraint (t, tau, c')
  | checkContradiction c == Fls = (t, tau, c)
  | otherwise =
      case findEq c of
        Just (Eq l1 l2) -> trace "found Eq by findEq" (eliminateEq l1 l2)
        _ ->
          case findGtEpsilon c of
            Just (Leq _ l)
              | not (isValue l) ->
                  trace "found GtEpsilon" (replace (LVal Epsilon) l)
            _ ->
              case findLtNbr c of
                Just (Leq l _)
                  | not (isValue l) ->
                      trace "found LtNbr" (replace (LVal Nbr) l)
                _ -> (t, tau, c)
  where
    c = removeUseless c'

    -- Substitute @new@ for @old@ everywhere and continue simplifying.
    replace new old =
      simpConstraint
        ( substLabel new old t
        , substLabel new old tau
        , substLabel new old c
        )

    isValue (LVal _) = True
    isValue _ = False

    -- Always keep the concrete value when one side of the equality has one.
    eliminateEq l1 l2 =
      case (l1, l2) of
        -- Two values: equal ones were removed by 'removeUseless', so this
        -- is unsatisfiable; leave it untouched instead of rewriting values.
        (LVal _, LVal _) -> (t, tau, c)
        (LVal _, _) -> replace l1 l2
        _ -> replace l2 l1
-- | Apply 'rmUseLess' at every constraint node of the given constraint,
-- using the generic 'everywhere' traversal.
removeUseless :: Constraint -> Constraint
removeUseless constraint = everywhere (mkT rmUseLess) constraint
-- | One local simplification step on a constraint node.
--
-- * In a conjunction, a side that 'useLess' accepts is dropped.
-- * An implication whose premise is @Epsilon = Epsilon@ collapses to its
--   conclusion; one whose premise is @Nbr = Epsilon@ collapses to @Tru@.
-- * Any other node that 'useLess' accepts becomes @Tru@.
rmUseLess :: Constraint -> Constraint
rmUseLess (Conj lhs rhs)
  | useLess lhs = rhs
  | useLess rhs = lhs
  | otherwise = Conj lhs rhs
rmUseLess (Impl (Eq (LVal Epsilon) (LVal Epsilon)) conclusion) = conclusion
rmUseLess (Impl (Eq (LVal Nbr) (LVal Epsilon)) _) = Tru
rmUseLess other
  | useLess other = Tru
  | otherwise = other
-- | Recognise atomic constraints that hold trivially: an equality between
-- identical labels, any @Leq@ with @Epsilon@ on the right or @Nbr@ on the
-- left, and a @Leq@ between one and the same label variable.
useLess :: Constraint -> Bool
useLess (Eq lhs rhs) = lhs == rhs
useLess (Leq _ (LVal Epsilon)) = True
useLess (Leq (LVal Nbr) _) = True
useLess (Leq (LVar (LabVar i)) (LVar (LabVar j))) = i == j
useLess _ = False
-- | Return the first equality reachable through nested conjunctions only,
-- searching the left conjunct before the right one. Equalities below any
-- other constructor are not considered.
findEq :: Constraint -> Maybe Constraint
findEq (Conj lhs rhs) = findEq lhs `mplus` findEq rhs
findEq equality@(Eq _ _) = Just equality
findEq _ = Nothing
-- Search a constraint for a node accepted by 'fndGtEpsilon', delegating the
-- traversal to 'collectOne' (defined in AlgCommon).
findGtEpsilon = collectOne (mkQ Nothing fndGtEpsilon)

-- | Accept exactly the constraints of the form @Leq (LVal Epsilon) _@.
fndGtEpsilon :: Constraint -> Maybe Constraint
fndGtEpsilon hit@(Leq (LVal Epsilon) _) = Just hit
fndGtEpsilon _ = Nothing
-- Search a constraint for a node accepted by 'fndLtNbr', delegating the
-- traversal to 'collectOne' (defined in AlgCommon).
findLtNbr = collectOne (mkQ Nothing fndLtNbr)

-- | Accept exactly the constraints of the form @Leq _ (LVal Nbr)@.
fndLtNbr :: Constraint -> Maybe Constraint
fndLtNbr hit@(Leq _ (LVal Nbr)) = Just hit
fndLtNbr _ = Nothing
-- | Replace the constraint by @Fls@ when 'fndContr' finds a contradiction in
-- it; otherwise return it unchanged.
checkContradiction :: Constraint -> Constraint
checkContradiction c = maybe c (const Fls) (fndContr c)
-- | Look for an atomic contradiction reachable through conjunctions only:
-- @Nbr = Epsilon@, @Epsilon = Nbr@ or @Leq Epsilon Nbr@. The left conjunct
-- is searched first; the offending atom is returned.
fndContr :: Constraint -> Maybe Constraint
fndContr bad@(Eq (LVal Nbr) (LVal Epsilon)) = Just bad
fndContr bad@(Eq (LVal Epsilon) (LVal Nbr)) = Just bad
fndContr bad@(Leq (LVal Epsilon) (LVal Nbr)) = Just bad
fndContr (Conj lhs rhs) = fndContr lhs `mplus` fndContr rhs
fndContr _ = Nothing
-- | Enumerate the assignments of label values that 'slvConstr' finds for
-- the label variables used in a constraint.
--
-- Returns the variables in ascending order together with one row of values
-- per solution. Each row lists the values of a solution map in ascending
-- key order, so rows line up with the variable list as long as every
-- solution assigns exactly the variables reported by 'getUsedLabels'.
-- NOTE(review): this presumes 'getUsedLabels' yields each variable once;
-- confirm against AlgCommon.
solveConstraint :: Constraint -> ([Int],[[LabVal]])
solveConstraint c =
  trace1
    ( "solveConstraint produces\n"
        ++ show vars
        ++ "with values\n"
        ++ show vals
        ++ "\n" -- was "n": the newline escape had lost its backslash
    )
    (vars, vals)
  where
    labConstr = getUsedLabels c
    vars = List.sort labConstr
    vals = map Map.elems (slvConstr labConstr Map.empty c)
-- | Branch over the remaining label variables, trying @Nbr@ first and then
-- @Epsilon@ for the variable at the head of the list.
--
-- After each choice the constraint is reduced with 'reduceConstraint'. A
-- branch that reduces to a contradiction contributes no solutions. All
-- assignments reported by the reduction are merged into the accumulated map
-- (existing entries win) and their variables are removed from the list of
-- variables still to be branched on.
slvConstr :: [Int] -> Map.Map Int LabVal -> Constraint -> [Map.Map Int LabVal]
slvConstr [] assigned _ = [assigned]
slvConstr (i : pending) assigned c = branch Nbr ++ branch Epsilon
  where
    branch lab =
      let chosen = substLabel (LVal lab) (LVar (LabVar i)) c
       in trace1
            ("starting solve with " ++ show (prettyConstraint chosen))
            (continue (reduceConstraint [(i, lab)] chosen))

    continue Nothing = []
    continue (Just (as, reduced)) =
      let settled = fst (unzip as)
          stillPending = filter (`notElem` settled) pending
          assigned' = Map.union assigned (Map.fromList as)
       in slvConstr stillPending assigned' reduced
-- | Propagate forced label assignments through a constraint.
--
-- The first argument accumulates the assignments made so far. The result
-- is @Nothing@ when the constraint (after 'removeUseless' and 'removeTrue')
-- contains a contradiction recognised by 'checkContradiction'. Otherwise it
-- is the extended assignment list together with the residual constraint in
-- which no further assignment is forced.
--
-- Only equalities between a label variable and a concrete value are turned
-- into assignments here. An equality between two variables is left in the
-- constraint: it has no value to record yet, and it becomes a
-- variable/value equality as soon as either side is given a value.
-- (Previously such an equality recorded an undefined value through an
-- irrefutable pattern match.)
--
-- Calls 'error' when 'findGtEpsilon' or 'findLtNbr' returns a constraint
-- whose constrained side is not a label variable.
reduceConstraint :: [(Int,LabVal)] -> Constraint -> Maybe ([(Int,LabVal)],Constraint)
reduceConstraint as c'
  | checkContradiction c == Fls = Nothing
  | otherwise =
      case findAssignment c of
        Just (i, w) -> trace "found Eq by findEq" (assign i w)
        Nothing ->
          case findGtEpsilon c of
            Just (Leq _ (LVar (LabVar i))) ->
              trace "found GtEpsilon" (assign i Epsilon)
            Just strange ->
              error ("strange Constraint: " ++ show strange ++ "\n")
            Nothing ->
              case findLtNbr c of
                Just (Leq (LVar (LabVar i)) _) ->
                  trace "found LtNbr" (assign i Nbr)
                Just strange ->
                  error ("strange Constraint: " ++ show strange ++ "\n")
                -- c was already checked for contradictions above and has
                -- not changed since, so it can be returned directly.
                Nothing -> Just (as, c)
  where
    c = removeTrue . removeUseless $ c'

    -- Record the assignment, substitute it, and keep reducing.
    assign i w =
      reduceConstraint ((i, w) : as) (substLabel (LVal w) (LVar (LabVar i)) c)

    -- First variable/value equality reachable through conjunctions only.
    findAssignment (Conj lhs rhs) = findAssignment lhs `mplus` findAssignment rhs
    findAssignment (Eq (LVar (LabVar i)) (LVal w)) = Just (i, w)
    findAssignment (Eq (LVal w) (LVar (LabVar i))) = Just (i, w)
    findAssignment _ = Nothing
-- | Restrict a solution table to the label variables selected by
-- 'getUsedExtraLabels' from the term's labels and the type, and drop
-- duplicate rows.
-- NOTE(review): presumably the selected set is "labels of the term plus
-- labels of the type"; confirm against 'getUsedExtraLabels' in AlgCommon.
filterTermAndTyp :: Term -> Typ -> ([Int],[[LabVal]]) -> ([Int],[[LabVal]])
filterTermAndTyp t tau res = (keptVars, List.nub keptRows)
  where
    termLabels = getUsedLabels t
    wanted = getUsedExtraLabels termLabels tau
    (keptVars, keptRows) = filterLabVars wanted res
-- | Restrict a solution table to the label variables reported by
-- 'getUsedLabels' for the type, and drop duplicate rows.
filterTyp :: Typ -> ([Int],[[LabVal]]) -> ([Int],[[LabVal]])
filterTyp tau res = (keptVars, List.nub keptRows)
  where
    (keptVars, keptRows) = filterLabVars (getUsedLabels tau) res
-- Project a solution table onto the variables listed in @varList@.
--
-- @vars@ names the columns and every row of @vals@ is read positionally
-- against it. The result keeps, in their original order, the variables that
-- occur in @varList@ and the matching entry of each row.
--
-- Rows are walked with total pattern matches instead of 'head'/'tail': a row
-- shorter than @vars@ contributes only the columns it actually has, and
-- values beyond the last variable are passed through unchanged, as before.
filterLabVars varList (vars, vals) =
  (filter (`elem` varList) vars, map (keepRow vars) vals)
  where
    keepRow (x : xs) (v : vs)
      | x `elem` varList = v : keepRow xs vs
      | otherwise = keepRow xs vs
    keepRow [] surplus = surplus
    keepRow _ [] = []
-- | Instantiate a type once per solution row: each row is paired with the
-- variable list and the resulting assignments are substituted into the type
-- one after another, from left to right.
makeTypes :: Typ -> ([Int],[[LabVal]]) -> [Typ]
makeTypes tau (vars, res) = map instantiate res
  where
    instantiate row = foldl bind tau (zip vars row)
    bind ty (i, val) = substLabel (LVal val) (LVar (LabVar i)) ty
-- | Instantiate a type with only those solution rows that are not dominated
-- by another row, judged by the per-variable preferences from 'getOptimal'.
--
-- A preference of @Non@ means "no preferred value". Variables of the type
-- that are not part of the solution table are instantiated afterwards: with
-- their preferred value, or with both @Epsilon@ and @Nbr@ when the
-- preference is @Non@.
--
-- The preference list is built by looking up each solved variable, so it
-- always lines up with the columns of the solution rows. (Building it from
-- the filtered 'getOptimal' list alone shifts the preferences whenever a
-- solved variable has no entry there.) Solved variables without an entry
-- get the preference @Non@.
makeMinimalTypes :: Typ -> ([Int],[[LabVal]]) -> [Typ]
makeMinimalTypes tau (vars, valss) =
  concatMap (getOptTypes unsetVarOpt) types
  where
    (setVarOpt, unsetVarOpt) =
      List.partition (\x -> fst x `elem` vars) (getOptimal tau)

    tracedSetVarOpt =
      trace2 ("the setOptVars are: " ++ show setVarOpt ++ "\n\n") setVarOpt

    -- One preference per column of the solution table.
    columnPrefs = map (\v -> maybe Non id (lookup v tracedSetVarOpt)) vars

    types = makeTypes tau (vars, getMinimal columnPrefs valss)

    -- True when, at some column, @winner@ has the preferred value and
    -- @loser@ does not. A @Non@ preference never matches a value.
    beats prefs winner loser =
      or (zipWith3 (\o w l -> LVal w == o && LVal l /= o) prefs winner loser)

    -- Keep a row unless a later row is at least as good everywhere; once a
    -- row is kept, drop the later rows that do not beat it anywhere.
    getMinimal prefs rows =
      case rows of
        [] -> []
        x : xs
          | any (not . beats prefs x) xs -> getMinimal prefs xs
          | otherwise ->
              x : getMinimal prefs (filter (\y -> beats prefs y x) xs)

    -- Instantiate the variables that the solution table does not cover.
    getOptTypes pending typ =
      case pending of
        [] -> [typ]
        (i, pref) : rest ->
          let pick lab = getOptTypes rest (substLabel lab (LVar (LabVar i)) typ)
           in case pref of
                Non -> pick (LVal Epsilon) ++ pick (LVal Nbr)
                fixed -> pick fixed
-- | Preferred value for every label variable that 'getOpt' reports for the
-- type, one entry per variable in ascending variable order. A variable
-- reported with differing preferences gets @Non@.
getOptimal :: Typ -> [(Int,Label)]
getOptimal tau =
  [ (i, if all ((== l) . snd) others then l else Non)
  | (i, l) : others <- List.groupBy ((==) `on` fst) byVar
  ]
  where
    byVar = List.sortBy (compare `on` fst) (getOpt True tau)
-- | Collect a preferred value for each label variable occurring on an arrow
-- or a quantifier of the type.
--
-- The flag is inverted when descending into the argument side of an arrow
-- and kept everywhere else. With the flag set, arrow labels prefer
-- @Epsilon@ and quantifier labels prefer @Nbr@; with the flag cleared the
-- two preferences are swapped. Labels that are not variables contribute
-- nothing.
getOpt :: Bool -> Typ -> [(Int,Label)]
getOpt minimise ty =
  case ty of
    TVar _ -> []
    TInt -> []
    TBool -> []
    TList inner -> getOpt minimise inner
    TArrow lab arg res ->
      entry arrowPref lab ++ getOpt (not minimise) arg ++ getOpt minimise res
    TAll lab _ body ->
      entry forallPref lab ++ getOpt minimise body
  where
    (arrowPref, forallPref)
      | minimise = (LVal Epsilon, LVal Nbr)
      | otherwise = (LVal Nbr, LVal Epsilon)

    entry preferred (LVar (LabVar i)) = [(i, preferred)]
    entry _ _ = []