{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiWayIf #-}

-- | Simple Constraint solver
module Haskus.Utils.Solver
   ( -- * Oracle
     PredState (..)
   , PredOracle
   , makeOracle
   , oraclePredicates
   , emptyOracle
   , predIsSet
   , predIsUnset
   , predIsUndef
   , predIs
   , predState
   -- * Constraint
   , Constraint (..)
   , simplifyConstraint
   , constraintReduce
   -- * Rule
   , Rule (..)
   , orderedNonTerminal
   , mergeRules
   , evalsTo
   , MatchResult (..)
   -- * Predicated data
   , Predicated (..)
   , createPredicateTable
   , initP
   , applyP
   , resultP
   )
where

import Haskus.Utils.Maybe
import Haskus.Utils.Flow
import Haskus.Utils.List
import Haskus.Utils.Map.Strict (Map)
import qualified Haskus.Utils.Map.Strict as Map

import Data.Bits
import Control.Arrow (first,second)
import Prelude hiding (pred)

-------------------------------------------------------
-- Oracle
-------------------------------------------------------

-- | Predicate state
data PredState
   = SetPred   -- ^ Set predicate
   | UnsetPred -- ^ Unset predicate
   | UndefPred -- ^ Undefined predicate
   deriving (Show,Eq,Ord)

-- | Predicate oracle
--
-- Maps each known predicate to its state. Predicates missing from the map are
-- considered undefined (see 'predState').
type PredOracle p = Map p PredState

-- | Ask an oracle if a predicate is set
predIsSet :: Ord p => PredOracle p -> p -> Bool
predIsSet oracle p = predIs oracle p SetPred

-- | Ask an oracle if a predicate is unset
predIsUnset :: Ord p => PredOracle p -> p -> Bool
predIsUnset oracle p = predIs oracle p UnsetPred

-- | Ask an oracle if a predicate is undefined
--
-- True both for predicates explicitly marked 'UndefPred' and for predicates
-- absent from the oracle.
predIsUndef :: Ord p => PredOracle p -> p -> Bool
predIsUndef oracle p = predIs oracle p UndefPred

-- | Check the state of a predicate
predIs :: Ord p => PredOracle p -> p -> PredState -> Bool
predIs oracle p s = predState oracle p == s

-- | Get predicate state
--
-- Predicates that the oracle doesn't know about are reported as 'UndefPred'.
predState :: Ord p => PredOracle p -> p -> PredState
predState oracle p = case p `Map.lookup` oracle of
   Just s  -> s
   Nothing -> UndefPred
-- | Create an oracle from a list
makeOracle :: Ord p => [(p,PredState)] -> PredOracle p
makeOracle = Map.fromList

-- | Get a list of predicates from an oracle
--
-- Predicates explicitly marked as 'UndefPred' are filtered out.
oraclePredicates :: Ord p => PredOracle p -> [(p,PredState)]
oraclePredicates = filter (\(_,s) -> s /= UndefPred) . Map.toList

-- | Oracle that always answers Undef
emptyOracle :: PredOracle p
emptyOracle = Map.empty

-------------------------------------------------------
-- Constraint
-------------------------------------------------------

-- | Constraint over predicates of type @p@
--
-- 'Xor' holds when exactly one of its operands holds (see
-- 'simplifyConstraint': more than one True operand makes it False).
data Constraint e p
   = Predicate p
   | Not (Constraint e p)
   | And [Constraint e p]
   | Or [Constraint e p]
   | Xor [Constraint e p]
   | CBool Bool
   deriving (Show,Eq,Ord)

instance Functor (Constraint e) where
   fmap f (Predicate p) = Predicate (f p)
   fmap _ (CBool b)     = CBool b
   fmap f (Not c)       = Not (fmap f c)
   fmap f (And cs)      = And (fmap (fmap f) cs)
   fmap f (Or cs)       = Or (fmap (fmap f) cs)
   fmap f (Xor cs)      = Xor (fmap (fmap f) cs)

-- | Reduce a constraint
--
-- The constraint is simplified first, then predicates known by the oracle are
-- replaced with their boolean value (set -> True, unset -> False) while
-- undefined ones are kept. Boolean connectives are folded when their operands
-- allow it.
constraintReduce :: (Ord p, Eq p, Eq e) => PredOracle p -> Constraint e p -> Constraint e p
constraintReduce oracle c = case simplifyConstraint c of
   Predicate p -> case predState oracle p of
      UndefPred -> Predicate p
      SetPred   -> CBool True
      UnsetPred -> CBool False

   Not c' -> case constraintReduce oracle c' of
      CBool v -> CBool (not v)
      c''     -> Not c''

   And cs -> case fmap (constraintReduce oracle) cs of
      [] -> error "Empty And constraint"
      cs' | all (constraintIsBool True) cs'  -> CBool True
      cs' | any (constraintIsBool False) cs' -> CBool False
      cs' -> case filter (not . constraintIsBool True) cs' of
         -- True operands are neutral for a conjunction
         [c'] -> c'
         cs'' -> And cs''

   Or cs -> case fmap (constraintReduce oracle) cs of
      [] -> error "Empty Or constraint"
      cs' | all (constraintIsBool False) cs' -> CBool False
      cs' | any (constraintIsBool True) cs'  -> CBool True
      cs' -> case filter (not . constraintIsBool False) cs' of
         -- False operands are neutral for a disjunction
         [c'] -> c'
         cs'' -> Or cs''

   Xor cs -> case fmap (constraintReduce oracle) cs of
      []  -> error "Empty Xor constraint"
      cs' -> simplifyConstraint (Xor cs')

   c'@(CBool _) -> c'

-- | Check that a constraint is evaluated to a given boolean value
constraintIsBool :: Bool -> Constraint e p -> Bool
constraintIsBool v (CBool v') = v == v'
constraintIsBool _ _          = False

-- | Get predicates used in a constraint (may contain duplicates)
getConstraintPredicates :: Constraint e p -> [p]
getConstraintPredicates = \case
   Predicate p -> [p]
   Not c       -> getConstraintPredicates c
   And cs      -> concatMap getConstraintPredicates cs
   Or cs       -> concatMap getConstraintPredicates cs
   Xor cs      -> concatMap getConstraintPredicates cs
   CBool _     -> []

-- | Get constraint terminals
--
-- Returns the boolean values the constraint may evaluate to, depending on the
-- (unknown) state of its predicates.
getConstraintTerminals :: Constraint e p -> [Bool]
getConstraintTerminals = \case
   Predicate _ -> [True,False]
   CBool v     -> [v]
   Not c       -> fmap not (getConstraintTerminals c)
   And cs      ->
      let cs' = fmap getConstraintTerminals cs
      in if | null cs               -> []
            -- only an operand that is *always* False forces the conjunction
            -- (an operand that merely *may* be False doesn't)
            | any (sing False) cs'  -> [False]
            | all (sing True) cs'   -> [True]
            | otherwise             -> [True,False]
   Or cs       ->
      let cs' = fmap getConstraintTerminals cs
      in if | null cs               -> []
            -- only an operand that is *always* True forces the disjunction
            | any (sing True) cs'   -> [True]
            | all (sing False) cs'  -> [False]
            | otherwise             -> [True,False]
   Xor cs      ->
      let cs' = fmap getConstraintTerminals cs
      in if | null cs   -> []
            | otherwise -> xo False cs'
   where
      -- xo t rest: t records whether an always-True operand was already seen;
      -- a second one makes the Xor False
      xo t []               = [t]
      xo False ([True]:xs)  = xo True xs
      xo True  ([True]:_)   = [False]
      xo False ([False]:xs) = xo False xs
      xo True  ([False]:xs) = xo True xs
      xo _     ([]:_)       = []
      xo _     _            = [True,False]

      -- does the operand always evaluate to the given value?
      sing v [v'] = v == v'
      sing _ _    = False

-------------------------------------------------------
-- Rule
-------------------------------------------------------

-- | Rule returning a terminal of type @a@, possibly guarded by constraints
data Rule e p a
   = Terminal a
   | NonTerminal [(Constraint e p, Rule e p a)]
   | Fail e
   deriving (Show,Eq,Ord)

instance Functor (Rule e p) where
   fmap f (Terminal a)     = Terminal (f a)
   fmap f (NonTerminal xs) = NonTerminal (fmap (second (fmap f)) xs)
   fmap _ (Fail e)         = Fail e
-- | NonTerminal whose constraints are evaluated in order
--
-- Earlier constraints must be proven false for the next ones to be considered:
-- each constraint is conjoined with the negation of all the previous ones.
orderedNonTerminal :: [(Constraint e p, Rule e p a)] -> Rule e p a
orderedNonTerminal = NonTerminal . go []
   where
      go _  []         = []
      go [] ((c,r):xs) = (simplifyConstraint c,r) : go [c] xs
      go cs ((c,r):xs) = (simplifyConstraint (And (c:fmap Not cs)),r) : go (c:cs) xs

-- | Simplify a constraint
--
-- Pushes negations inwards (De Morgan), removes double negations, unwraps
-- single-operand connectives and folds connectives whose result is decided by
-- boolean operands.
simplifyConstraint :: Constraint e p -> Constraint e p
simplifyConstraint x = case x of
   Predicate _       -> x
   CBool _           -> x
   Not (Predicate _) -> x
   Not (CBool v)     -> CBool (not v)
   Not (Not c)       -> simplifyConstraint c
   Not (Or cs)       -> simplifyConstraint (And (fmap Not cs))
   Not (And cs)      -> simplifyConstraint (Or (fmap Not cs))
   Not (Xor cs)      -> case simplifyConstraint (Xor cs) of
      Xor cs' -> Not (Xor cs')
      r       -> simplifyConstraint (Not r)
   And [c]           -> simplifyConstraint c
   Or [c]            -> simplifyConstraint c
   Xor [c]           -> simplifyConstraint c
   And cs            ->
      let cs' = fmap simplifyConstraint cs
      in if | any (constraintIsBool False) cs' -> CBool False
            | all (constraintIsBool True) cs'  -> CBool True
            | otherwise                        -> And cs'
   Or cs             ->
      let cs' = fmap simplifyConstraint cs
      in if | any (constraintIsBool True) cs'  -> CBool True
            | all (constraintIsBool False) cs' -> CBool False
            | otherwise                        -> Or cs'
   Xor cs            ->
      let cs'        = fmap simplifyConstraint cs
          countTrue  = length (filter (constraintIsBool True) cs')
          countFalse = length (filter (constraintIsBool False) cs')
          countAll   = length cs'
      in if | countTrue > 1                                         -> CBool False
            | countTrue == 1 && countTrue + countFalse == countAll -> CBool True
            | countAll == countFalse                                -> CBool False
            | otherwise                                             -> Xor cs'

-- | Merge two rules together
--
-- A 'Fail' on either side is propagated (the left one wins when both fail).
-- Two terminals are paired; otherwise the other rule is merged into every
-- branch of the non-terminal one.
mergeRules :: Rule e p a -> Rule e p b -> Rule e p (a,b)
mergeRules = go
   where
      go (Fail e)         _                = Fail e
      go _                (Fail e)         = Fail e
      go (Terminal a)     (Terminal b)     = Terminal (a,b)
      go (Terminal a)     (NonTerminal bs) = NonTerminal (fl (Terminal a) bs)
      -- right-hand side is either a Terminal or a NonTerminal here
      go (NonTerminal as) b                = NonTerminal (fr b as)

      fl x = fmap (second (x `mergeRules`))
      fr x = fmap (second (`mergeRules` x))

-- | Reduce a rule
--
-- Constraints are reduced with the oracle and rules whose constraint is False
-- are dropped. Then, among the rules whose constraint is True:
--
--    * any 'Fail' makes the result 'MatchFail'
--
--    * several distinct terminals make the result 'MatchDiverge'
--
--    * nested non-terminals are flattened and reduced again
--
-- A single matching terminal with no undecided rule left gives 'Match';
-- no remaining rule gives 'NoMatch'; otherwise 'DontMatch' with the reduced
-- rule.
ruleReduce :: forall e p a.
   ( Ord p, Eq e, Eq p, Eq a)
   => PredOracle p -> Rule e p a -> MatchResult e (Rule e p a) a
ruleReduce oracle r = case r of
   Terminal a     -> Match a
   Fail e         -> MatchFail [e]
   NonTerminal rs ->
      let
         rs' :: [(Constraint e p, Rule e p a)]
         rs' = rs
               -- reduce constraints
               |> fmap (first (constraintReduce oracle))
               -- filter non matching rules
               |> filter (not . constraintIsBool False . fst)

         (matchingRules,mayMatchRules) = partition (constraintIsBool True . fst) rs'
         matchingResults               = nub $ fmap snd $ matchingRules

         (failingResults,terminalResults,nonTerminalResults) = go [] [] [] matchingResults

         -- split matching results by kind (accumulators are built reversed)
         go fr tr ntr = \case
            []                 -> (fr,tr,ntr)
            (Fail x:xs)        -> go (x:fr) tr ntr xs
            (Terminal x:xs)    -> go fr (x:tr) ntr xs
            (NonTerminal x:xs) -> go fr tr (x:ntr) xs

         divergence = case terminalResults of
            -- results are already "nub"ed.
            -- More than 1 results => divergence
            (_:_:_) -> True
            _       -> False

      in case rs' of
         [] -> NoMatch
         _ | not (null failingResults)     -> MatchFail failingResults
           | divergence                    -> MatchDiverge (fmap Terminal terminalResults)
           | not (null nonTerminalResults) ->
               -- fold matching nested NonTerminals
               ruleReduce oracle <| NonTerminal <| (fmap (\x -> (CBool True, Terminal x)) terminalResults
                                                      ++ mayMatchRules
                                                      ++ concat nonTerminalResults)
           | otherwise -> case (matchingResults,mayMatchRules) of
               ([Terminal a], []) -> Match a
               _                  -> DontMatch (NonTerminal rs')

-- | Get possible resulting terminals
getRuleTerminals :: Rule e p a -> [a]
getRuleTerminals (Fail _)         = []
getRuleTerminals (Terminal a)     = [a]
getRuleTerminals (NonTerminal xs) = concatMap (getRuleTerminals . snd) xs

-- | Get predicates used in a rule (without duplicates)
getRulePredicates :: Eq p => Rule e p a -> [p]
getRulePredicates (Fail _)         = []
getRulePredicates (Terminal _)     = []
getRulePredicates (NonTerminal xs) = nub $ concatMap (\(x,y) -> getConstraintPredicates x ++ getRulePredicates y) xs

-- | Constraint checking that a predicated value evaluates to some terminal
--
-- The result is a disjunction of the (full) oracles under which @s@ reduces
-- to @a@, each oracle being expressed as a conjunction of predicates. If no
-- oracle makes @s@ reduce to @a@, the result is @CBool False@.
evalsTo :: (Ord (Pred a), Eq a, Eq (PredTerm a), Eq (Pred a), Predicated a) => a -> PredTerm a -> Constraint e (Pred a)
evalsTo s a = case createPredicateTable s (const True) True of
   Left x   -> CBool (x == a)
   Right xs -> orConstraints
               <| fmap andPredicates
               <| fmap oraclePredicates
               <| fmap fst
               <| filter ((== a) . snd)
               <| xs
   where
      andPredicates []  = CBool True
      andPredicates [x] = makePred x
      andPredicates ys  = And (fmap makePred ys)

      -- an empty disjunction is False: no oracle leads to the terminal
      orConstraints []  = CBool False
      orConstraints [x] = x
      orConstraints ys  = Or ys

      makePred (p, UnsetPred) = Not (Predicate p)
      makePred (p, SetPred)   = Predicate p
      makePred (_, UndefPred) =
         -- shouldn't be possible: predicates come from 'oraclePredicates',
         -- which filters out undefined ones
         error "evalsTo: unexpected undefined predicate"
-------------------------------------------------------
-- Predicated data
-------------------------------------------------------

-- | Predicated data
--
-- @
-- data T
-- data NT
--
-- type family RuleT e p a s :: * where
--    RuleT e p a T  = a
--    RuleT e p a NT = Rule e p a
--
-- data PD t = PD
--    { p1 :: RuleT () Bool Int t
--    , p2 :: RuleT () Bool String t
--    }
--
-- deriving instance Eq (PD T)
-- deriving instance Show (PD T)
-- deriving instance Ord (PD T)
-- deriving instance Eq (PD NT)
-- deriving instance Show (PD NT)
-- deriving instance Ord (PD NT)
--
--
-- instance Predicated (PD NT) where
--    type PredErr (PD NT)  = ()
--    type Pred (PD NT)     = Bool
--    type PredTerm (PD NT) = PD T
--
--    liftTerminal (PD a b) = PD (liftTerminal a) (liftTerminal b)
--
--    reducePredicates oracle (PD a b) =
--       initP PD PD
--          |> (`applyP` reducePredicates oracle a)
--          |> (`applyP` reducePredicates oracle b)
--          |> resultP
--
--    getTerminals (PD as bs) = [ PD a b | a <- getTerminals as
--                                       , b <- getTerminals bs
--                                       ]
--
--    getPredicates (PD a b) = concat
--       [ getPredicates a
--       , getPredicates b
--       ]
-- @
class Predicated a where
   -- | Error type
   type PredErr a :: *

   -- | Predicate type
   type Pred a :: *

   -- | Terminal type
   type PredTerm a :: *

   -- | Build a non terminal from a terminal
   liftTerminal :: PredTerm a -> a

   -- | Reduce predicates
   reducePredicates :: PredOracle (Pred a) -> a -> MatchResult (PredErr a) a (PredTerm a)

   -- | Get possible resulting terminals
   getTerminals :: a -> [PredTerm a]

   -- | Get used predicates
   getPredicates :: a -> [Pred a]

instance (Ord p, Eq e, Eq a, Eq p) => Predicated (Rule e p a) where
   type PredErr (Rule e p a)  = e
   type Pred (Rule e p a)     = p
   type PredTerm (Rule e p a) = a

   reducePredicates = ruleReduce
   liftTerminal     = Terminal
   getTerminals     = getRuleTerminals
   getPredicates    = getRulePredicates

instance (Ord p, Eq e, Eq p) => Predicated (Constraint e p) where
   type PredErr (Constraint e p)  = e
   type Pred (Constraint e p)     = p
   type PredTerm (Constraint e p) = Bool

   reducePredicates oracle c = case constraintReduce oracle c of
      CBool v -> Match v
      c'      -> DontMatch c'

   liftTerminal  = CBool
   getTerminals  = getConstraintTerminals
   getPredicates = getConstraintPredicates

-- | Reduction result
data MatchResult e nt t
   = NoMatch            -- ^ No rule can match
   | Match t            -- ^ Reduced to a terminal
   | DontMatch nt       -- ^ Not fully reduced yet (partially reduced value)
   | MatchFail [e]      -- ^ Matching rules failed with these errors
   | MatchDiverge [nt]  -- ^ Several different results match
   deriving (Show,Eq,Ord)

instance Functor (MatchResult e nt) where
   fmap f x = case x of
      NoMatch         -> NoMatch
      MatchDiverge xs -> MatchDiverge xs
      MatchFail es    -> MatchFail es
      Match a         -> Match (f a)
      DontMatch a     -> DontMatch a

-- | Compose reduction results
--
-- We reuse the MatchResult data type:
--
--    * a "terminal" on the left can be used to build either a terminal or a
--      non terminal
--
--    * a "non terminal" on the left can only be used to build a non terminal
--
-- 'NoMatch' on either side wins, then errors of both sides are accumulated.
applyP ::
   ( Predicated ntb
   ) => MatchResult e (ntb -> nt) (ntb -> nt, PredTerm ntb -> t)
     -> MatchResult e ntb (PredTerm ntb)
     -> MatchResult e nt (nt,t)
applyP NoMatch                _                 = NoMatch
applyP _                      NoMatch           = NoMatch
applyP (MatchFail xs)         (MatchFail ys)    = MatchFail (xs++ys)
applyP (MatchFail xs)         _                 = MatchFail xs
applyP _                      (MatchFail ys)    = MatchFail ys
applyP (MatchDiverge fs)      (MatchDiverge ys) = MatchDiverge [f y | f <- fs, y <- ys]
applyP (MatchDiverge fs)      (Match b)         = MatchDiverge [f (liftTerminal b) | f <- fs]
applyP (MatchDiverge fs)      (DontMatch b)     = MatchDiverge [f b | f <- fs]
applyP (DontMatch f)          (MatchDiverge ys) = MatchDiverge [f y | y <- ys]
applyP (DontMatch f)          (DontMatch b)     = DontMatch (f b)
applyP (DontMatch f)          (Match b)         = DontMatch (f (liftTerminal b))
applyP (Match (fnt,_))        (MatchDiverge ys) = MatchDiverge [fnt y | y <- ys]
applyP (Match (fnt,_))        (DontMatch b)     = DontMatch (fnt b)
applyP (Match (fnt,ft))       (Match b)         = Match (fnt (liftTerminal b), ft b)

-- | Initialise a reduction result (typically with two functions/constructors)
initP :: nt -> t -> MatchResult e nt (nt,t)
initP nt t = Match (nt,t)

-- | Fixup result (see initP and applyP)
resultP :: MatchResult e nt (nt,t) -> MatchResult e nt t
resultP = fmap snd

-- | Create a table of predicates that return a terminal
--
-- If the predicated value already reduces to a terminal with the empty oracle,
-- this terminal is returned in 'Left'. Otherwise every candidate oracle
-- accepted by @oracleChecker@ is tried and those under which the value reduces
-- to a terminal are returned in 'Right', together with the terminal.
--
-- When @fullTable@ is True, candidate oracles assign a state to every
-- predicate (2^n oracles for n distinct predicates); otherwise every non-empty
-- partial assignment is tried.
createPredicateTable ::
   ( Ord (Pred a)
   , Eq (Pred a)
   , Eq a
   , Predicated a
   ) => a -> (PredOracle (Pred a) -> Bool) -> Bool -> Either (PredTerm a) [(PredOracle (Pred a),PredTerm a)]
createPredicateTable s oracleChecker fullTable =
   -- we first check if the predicated value reduces to a terminal without any
   -- additional oracle
   case reducePredicates emptyOracle s of
      Match x -> Left x
      _       -> Right (mapMaybe matching oracles)
   where
      matching oracle = case reducePredicates oracle s of
         Match x -> Just (oracle,x)
         _       -> Nothing

      oracles = filter oracleChecker (fmap makeOracle predSets)

      -- the same predicate may be reported several times (e.g. a predicate used
      -- twice in a constraint): keep it once, otherwise the table contains
      -- duplicate oracles and grows needlessly
      preds = nub (sort (getPredicates s))

      predSets
         | fullTable = makeFullSets preds
         | otherwise = makeSets preds []

      -- bit i of n gives the state of the i-th predicate
      makeFullSets ps  = fmap (makeFullSet ps) ([0..2^(length ps)-1] :: [Word])
      makeFullSet ps n = fmap (setB n) (ps `zip` [0..])
      setB n (p,i)     = if testBit n i
                           then (p,SetPred)
                           else (p,UnsetPred)

      makeSets [] os     = os
      makeSets (p:ps) os =
         let ns = [(p,SetPred),(p,UnsetPred)]
         in makeSets ps $ concat
               [ [ [n] | n <- ns ]
               , [ (n:o) | o <- os, n <- ns ]
               , os
               ]