{-# LANGUAGE FlexibleContexts #-}
-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-----------------------------------------------------------------------------

module Util.Expr
   ( substitute, eqComAssoc, isLiteral, isSubExprOf
   , wildcard, wildcardsToVars, restoreWildcards
   ) where

import Control.Applicative
import Control.Monad.State
import Data.Bool
import Domain.Math.Expr
import Ideas.Common.Rewriting.Term
import Ideas.Common.View
import Ideas.Utils.Uniplate
import Util.Monad
import qualified Data.Map as M

-- | Replace every variable that has a definition in the map by its
-- (recursively substituted) definition. Variables without a definition are
-- left untouched.
--
-- A definition that mentions its own variable directly (e.g. @x := x + 1@)
-- is not expanded. Indirect cycles (e.g. @x := y@, @y := x + 1@) are cut
-- off: while a variable's definition is being expanded, that variable is
-- removed from the map, so every expansion path uses each definition at most
-- once. Hence substitution terminates for finite maps and expressions
-- (previously such cycles looped forever).
substitute :: M.Map String Expr -> Expr -> Expr
substitute = go
  where
    go m (Var s) =
      case M.lookup s m of
        -- direct self-reference: keep the variable as is
        Just e | s `notElem` vars e -> go (M.delete s m) e
        _ -> Var s
    go m e = descend (go m) e

-- | The symbol used to mark wildcards inside expressions.
wildcardSymbol :: Symbol
wildcardSymbol = newSymbol "wildcard"

-- | Is this the wildcard marker symbol?
isWildcardSymbol :: Symbol -> Bool
isWildcardSymbol = (== wildcardSymbol)

-- a wildcard matches any expression.
-- Several wildcards can occur in a single expression; they are told apart by
-- their identifier, which is stored as a variable argument of the wildcard
-- symbol.
wildcard :: String -> Expr
wildcard s = Sym wildcardSymbol [Var s]

-- | Prefix every ordinary variable with the given character. Wildcards are
-- returned as is, so the variables identifying them are not prefixed.
prefixVariables :: Char -> Expr -> Expr
prefixVariables s (Var v) = Var (s : v)
prefixVariables _ a@(Sym s _) | isWildcardSymbol s = a
prefixVariables s e = descend (prefixVariables s) e

-- | Replace every wildcard with identifier @v@ by a variable named @v@
-- prefixed with the character w.
instantiateWildcards :: Expr -> Expr
instantiateWildcards (Sym s [Var v]) | isWildcardSymbol s = Var ('w':v)
instantiateWildcards e = descend instantiateWildcards e

-- | Encode wildcards as ordinary variables: existing variables get the
-- prefix character v and wildcards become variables with prefix character w,
-- so the two kinds cannot clash. 'restoreWildcards' undoes this encoding.
wildcardsToVars :: Expr -> Expr
wildcardsToVars = instantiateWildcards . prefixVariables 'v'

-- | Inverse of 'wildcardsToVars': variables starting with v lose that prefix,
-- variables starting with w become wildcards again, and all other variables
-- are left unchanged.
restoreWildcards :: Expr -> Expr
restoreWildcards (Var ('v':v)) = Var v
restoreWildcards (Var ('w':v)) = wildcard v
restoreWildcards e = descend restoreWildcards e

-- | Is the expression a numeric literal: a 'Nat', a 'Number', or the
-- negation of a literal?
isLiteral :: Expr -> Bool
isLiteral (Nat _)    = True
isLiteral (Number _) = True
isLiteral (Negate n) = isLiteral n
isLiteral _          = False

-- | Match two expressions (see eq'). The state holds the wildcard bindings
-- (first map) and, per wildcard identifier, an optional predicate the bound
-- value must satisfy (second map). When the match fails the state is
-- restored, so a failed attempt leaves no bindings behind.
eq :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => Expr -> Expr -> m Bool
eq x y = do
  s <- get
  b <- eq' x y
  unless b (put s)
  return b

-- | Worker for 'eq'. A wildcard (on one side only) matches any expression:
-- an unbound wildcard is bound to the other expression if its predicate (if
-- any) accepts it, and a bound wildcard must match its current value.
-- Binary addition and multiplication are matched modulo commutativity; all
-- other function applications must have the same symbol, the same number of
-- arguments, and matching arguments.
--
-- On failure the state may still contain bindings from partial matches; the
-- caller is responsible for restoring it (as 'eq' does).
eq' :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => Expr -> Expr -> m Bool
eq' (Sym s1 _) (Sym s2 _)
  | isWildcardSymbol s1 && isWildcardSymbol s2 = return False
eq' a b@(Sym s _)
  | isWildcardSymbol s = eq' b a
eq' (Sym s [Var v]) b
  | isWildcardSymbol s = do
      (me, pe) <- get
      case (M.lookup v me, M.lookup v pe) of
        (Just a, _) -> eq' a b
        (_, Just p) | not (p b) -> return False
        _ -> put (M.insert v b me, pe) >> return True
eq' (Nat a) (Nat b) = return (a == b)
eq' (Number a) (Number b) = return (a == b)
eq' (Var a) (Var b) = return (a == b)
eq' a b
  | Just (sa, [a1, a2]) <- getFunction a
  , Just (sb, [b1, b2]) <- getFunction b
  , sa == sb
  , sa == plusSymbol || sa == timesSymbol -- commutative operators
  = do
      -- Try the orderings one after the other and reset the state in
      -- between: bindings made while trying one ordering must not leak into
      -- the other (e.g. x + ?w against x + 3).
      saved <- get
      straight <- pairEq a1 b1 a2 b2
      if straight
        then return True
        else put saved >> pairEq a1 b2 a2 b1
  where
    -- both pairs must match; stop at the first mismatch
    pairEq x1 y1 x2 y2 = do
      ok <- eq' x1 y1
      if ok then eq' x2 y2 else return False
eq' a b
  | Just (sa, as) <- getFunction a
  , Just (sb, bs) <- getFunction b
  , sa == sb
  , length as == length bs -- zipWithM would silently truncate
  = and <$> zipWithM eq' as bs
eq' _ _ = return False

-- | Runs 'extractSum' with the operands of @e@ extracted from @se@. Returns
-- True (keeping the resulting bindings) when the extraction succeeds, and
-- False with the state unchanged otherwise.
--
-- NOTE(review): this checks that the operands of @e@ are contained in @se@;
-- confirm that this argument order matches how callers read the name.
isSubExprOf :: (Expr -> Expr) -> Expr -> Expr -> State (M.Map String Expr, M.Map String (Expr -> Bool)) Bool
isSubExprOf norm se e = state $ \s ->
  case runStateT (extractSum norm e se) s of
    Just (_, s') -> (True, s')
    Nothing      -> (False, s)

-- | Equality modulo commutativity and associativity of sums: every operand of
-- @e@ is extracted from @se@ and nothing of @se@ may be left over. On
-- success the bindings are kept; otherwise the state is left unchanged (like
-- 'eq'; previously a leftover extraction kept its bindings despite returning
-- False).
eqComAssoc :: (Expr -> Expr) -> Expr -> Expr -> State (M.Map String Expr, M.Map String (Expr -> Bool)) Bool
eqComAssoc norm e se = state $ \s ->
  case runStateT (extractSum norm e se) s of
    Just (rest, s') | null rest -> (True, s')
    _                           -> (False, s)

-- | @extractSum norm e s@ tries to remove each sum operand of @e@ from the
-- sum operands of @s@ and returns the operands of @s@ that are left over.
-- As a shortcut, when @e@ and @s@ match as a whole ('eqOrNorm') nothing is
-- left over. Fails when some operand of @e@ cannot be removed.
extractSum :: (Expr -> Expr) -> Expr -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractSum norm e s =
      mIf (eqOrNorm norm e s) (return []) empty
  <|> foldM (extractSumMember norm) (from sumView s) (from sumView e)

-- | @extractSumMember norm terms t@ tries to remove the single sum operand
-- @t@ from the list of sum operands @terms@ and returns the operands left.
-- When @t@ is a wildcard, we first try to match the wildcard to the entire
-- remaining sum. Note that the wildcard could also match a subset of the
-- sum, but we disregard this possibility. If @t@ is not a wildcard, or that
-- match fails, we look for an operand from which @t@ can be extracted as a
-- product ('extractProduct') without leaving any factors behind.
extractSumMember :: (Expr -> Expr) -> [Expr] -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractSumMember norm remaining term =
  case term of
    Sym s [Var _] | isWildcardSymbol s ->
      -- A wildcard first tries to absorb all remaining operands at once.
      mIf (eq term (to sumView remaining)) (return []) removeOne
    _ -> removeOne
  where
    -- Remove an operand from which 'term' can be extracted as a product
    -- without leftover factors; fail when no operand was removed.
    removeOne = do
      xs <- deleteByM (\x y -> null <$> extractProduct norm x y) term remaining
      guard (length xs < length remaining)
      return xs

-- | @extractProduct norm e s@ removes each factor of @e@ from the factors of
-- @s@ (both obtained via 'productView') and returns the factors of @s@ that
-- are left over. The factor-wise extraction requires the Boolean components
-- that 'productView' yields for @e@ and @s@ to be equal. Independently of
-- that, when @e@ and @s@ match as a whole ('eqOrNorm') nothing is left over.
-- The parentheses matter: without them @>>@ (infixl 1) binds looser than
-- @<|>@ (infixl 3), and the sign check would also block the whole match.
extractProduct :: (Expr -> Expr) -> Expr -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractProduct norm e s =
      (guard (eSign == sSign) >> foldM (extractProductMember norm True) sFactors eFactors)
  <|> mIf (eqOrNorm norm e s) (return []) empty
  where
    (eSign, eFactors) = from productView e
    (sSign, sFactors) = from productView s

-- | @extractProductMember norm flag factors r@ removes a factor matching @r@
-- ('eqOrNorm') from @factors@ via 'deleteByM' and returns the remaining
-- factors; fails when the list did not shrink. The Bool argument is
-- currently ignored.
-- NOTE(review): assumes deleteByM removes at most one element, like
-- Data.List.deleteBy — confirm in Util.Monad.
extractProductMember :: (Expr -> Expr) -> Bool -> [Expr] -> Expr -> StateT (M.Map String Expr, M.Map String (Expr -> Bool)) Maybe [Expr]
extractProductMember norm _ factors r = do
  xs <- deleteByM (eqOrNorm norm) r factors
  guard (length xs < length factors)
  return xs

-- | Match syntactically first; if that fails, match the normalised forms
-- (@norm x@ against @norm y@). Because both attempts go through 'eq', a
-- failed attempt leaves the state unchanged.
eqOrNorm :: MonadState (M.Map String Expr, M.Map String (Expr -> Bool)) m => (Expr -> Expr) -> Expr -> Expr -> m Bool
eqOrNorm norm x y = do
  b <- eq x y
  if b then return True else eq (norm x) (norm y)

-- | Monadic conditional: run the condition, then run only the chosen branch.
-- This needs 'Monad': the former Applicative version
-- @(flip . fmap flip) (liftA3 bool)@ ran both branches' effects (so an
-- 'empty' else-branch made every call fail) and returned the else-value
-- when the condition was True.
mIf :: Monad m => m Bool -> m a -> m a -> m a
mIf cond thenBranch elseBranch = cond >>= bool elseBranch thenBranch