{-# LANGUAGE FlexibleInstances, LambdaCase #-}
module Futhark.Analysis.AlgSimplify
  ( ScalExp
  , Error
  , simplify
  , mkSuffConds
  , RangesRep
  , ppRangesRep
  , linFormScalE
  , pickSymToElim
  )
  where

import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.List (sort, sortBy, genericReplicate)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State

import Futhark.Representation.AST hiding (SDiv, SMod, SQuot, SRem, SSignum)
import Futhark.Analysis.ScalExp
import qualified Futhark.Representation.Primitive as P

-- | Ranges are inclusive.
type RangesRep = M.Map VName (Int, Maybe ScalExp, Maybe ScalExp)

-- | Prettyprint a 'RangesRep'.  Do not rely on the format of this
-- string.  Does not include the loop nesting depth information.
ppRangesRep :: RangesRep -> String
ppRangesRep = unlines . sort . map ppRange . M.toList
  where
    -- One line per variable; the 'Int' component is not printed.
    ppRange (name, (_, lower, upper)) =
      pretty name ++ ": " ++ ppInterval lower upper

    -- Identical bounds are shown as an equality, otherwise as an interval.
    ppInterval lower upper
      | lower == upper = "== " ++ ppBound lower
      | otherwise      = "[" ++ ppBound lower ++ ", " ++ ppBound upper ++ "]"

    -- A missing bound is printed as a question mark.
    ppBound bound = maybe "?" pretty bound

-- | Read-only environment of the simplifier.
data AlgSimplifyEnv = AlgSimplifyEnv
  { inSolveLTH0 :: Bool
    -- ^ Set (see 'markGaussLTH0') while an @e < 0@ problem is being solved.
  , ranges :: RangesRep
    -- ^ The variable-to-range bindings.
  , maxSteps :: Int
    -- ^ The number of simplifications to do before bailing out, to
    -- avoid spending too much time.
  }

data Error = StepsExceeded | Error String

-- | The simplifier monad: a step counter ('StateT'), the environment
-- ('ReaderT') and failure ('Either').
type AlgSimplifyM = StateT Int (ReaderT AlgSimplifyEnv (Either Error))

-- | Run a simplifier action with a fresh step counter.  The 'Bool' is
-- the initial value of 'inSolveLTH0'.
runAlgSimplifier :: Bool -> AlgSimplifyM a -> RangesRep -> Either Error a
runAlgSimplifier s x r =
  evalStateT x 0 `runReaderT`
    AlgSimplifyEnv { inSolveLTH0 = s
                   , ranges      = r
                   , maxSteps    = 100 -- heuristically chosen
                   }

-- | Count one simplification step; fail with 'StepsExceeded' once the
-- counter is strictly greater than 'maxSteps'.
step :: AlgSimplifyM ()
step = do
  modify (1+)
  taken <- get
  limit <- asks maxSteps
  when (taken > limit) stepsExceeded

stepsExceeded :: AlgSimplifyM a
stepsExceeded = lift . lift $ Left StepsExceeded

-- | Abort the simplification with the given message.
badAlgSimplifyM :: String -> AlgSimplifyM a
badAlgSimplifyM msg = lift . lift . Left $ Error msg

-- | Switch on the 'inSolveLTH0' flag of an environment.
markInSolve :: AlgSimplifyEnv -> AlgSimplifyEnv
markInSolve env = env { inSolveLTH0 = True }

-- | Run an action with 'inSolveLTH0' set.
markGaussLTH0 :: AlgSimplifyM a -> AlgSimplifyM a
markGaussLTH0 action = local markInSolve action

-----------------------------------------------------------
-- A Scalar Expression, i.e., ScalExp, is simplified to: --
--   1. if numeric: to a normalized sum-of-products form,--
--      in which on the outermost level there are N-ary  --
--      Min/Max nodes, and the next two levels are a sum --
--      of products.                                     --
--   2. if boolean: to disjunctive normal form           --
--                                                       --
-- Corresponding Helper Representations are:             --
--   1. NNumExp, i.e., NSum of NProd of ScalExp          --
--   2. DNF                                              --
-----------------------------------------------------------

data NNumExp = NSum   [NNumExp]  PrimType
             | NProd  [ScalExp]  PrimType
               deriving (Eq, Ord, Show)

data BTerm = NRelExp RelOp0 NNumExp
           | LogCt  !Bool
           | PosId   VName
           | NegId   VName
             deriving (Eq, Ord, Show)

type NAnd = [BTerm]
type DNF  = [NAnd ]
--type NOr  = [BTerm]
--type CNF  = [NOr  ]

-- | Applies Simplification at Expression level.  Returns the input
-- unchanged when the step budget runs out, and calls 'error' on any
-- other simplifier failure.
simplify :: ScalExp -> RangesRep -> ScalExp
simplify e rangesrep =
  case runAlgSimplifier False (simplifyScal e) rangesrep of
    Right e'           -> e'
    Left StepsExceeded -> e
    Left (Error err)   ->
      error $ "Error during algebraic simplification of: " ++
              pretty e ++ "\n" ++ err

-- | Given a symbol i and a scalar expression e, it decomposes
--   e = a*i + b and returns (a,b) if possible, otherwise Nothing.
linFormScalE :: VName -> ScalExp -> RangesRep -> Either Error (Maybe (ScalExp,ScalExp))
linFormScalE i e rangesrep =
  runAlgSimplifier False (linearFormScalExp i e) rangesrep

-- | Extracts sufficient conditions for a LTH0 relation to hold
mkSuffConds :: ScalExp -> RangesRep -> Either Error [[ScalExp]]
mkSuffConds e rangesrep =
  runAlgSimplifier True (gaussElimRel e) rangesrep

{-
 -- | Test if Simplification engine can handle this kind of expression
canSimplify :: Int -> Either Error ScalExp --[[ScalExp]]
canSimplify i = do
    let (h,_,e2) = mkRelExp i
    case e2 of
        (RelExp LTH0 _) -> do
              -- let e1' = trace (pretty e1) e1
              simplify e2 noLoc h
--            runAlgSimplifier False (gaussAllLTH0 False S.empty =<< toNumSofP =<< simplifyScal e) noLoc h
        _ -> simplify e2 noLoc h-- badAlgSimplifyM "canSimplify: unimplemented!"
-}

-------------------------------------------------------
--- Assumes the relational expression is simplified  --
--- All uses of gaussian elimination from simplify   --
--- must use simplifyNRel, which calls markGaussLTH0 --
--- to set the inSolveLTH0 environment var, so that  --
--- we do not enter an infinite recursion!           --
--- Returns True or False or the input relation,     --
--- i.e., static-only simplification!                --
-------------------------------------------------------
simplifyNRel :: BTerm -> AlgSimplifyM BTerm
simplifyNRel inp_term@(NRelExp LTH0 inp_sofp) = do
    cheap   <- cheapSimplify inp_term
    solving <- asks inSolveLTH0
    if solving || isConstRel cheap || not integral
      then return cheap
      else decide cheap <$>
             markGaussLTH0 (gaussAllLTH0 True S.empty inp_sofp)
  where
    -- The expensive solver is only attempted on integer expressions.
    integral = typeOfNAlg inp_sofp `elem` map IntType allIntTypes

    -- A relation over a single constant factor.
    isConstRel (NRelExp _ (NProd [Val _] _)) = True
    isConstRel _                             = False

    -- Keep the solver's verdict only if it is a boolean constant.
    decide _        (Val (BoolValue c)) = LogCt c
    decide fallback _                   = fallback

    -- Evaluate a relation over a single constant directly.
    cheapSimplify :: BTerm -> AlgSimplifyM BTerm
    cheapSimplify (NRelExp rel (NProd [Val v] _)) = LogCt <$> valLTHEQ0 rel v
    cheapSimplify other                           = return other
--- TODO: handle more cases.
simplifyNRel other = return other

--gaussEliminateNRel :: BTerm -> AlgSimplifyM DNF
--gaussEliminateNRel _ =
--    badAlgSimplifyM "gaussElimNRel: unimplemented!"

-- | Sufficient conditions, as a disjunction of conjunctions, for an
-- @e < 0@ relation to hold.  Fails on any other kind of expression.
gaussElimRel :: ScalExp -> AlgSimplifyM [[ScalExp]]
gaussElimRel (RelExp LTH0 e) = do
    e_sofp <- toNumSofP =<< simplifyScal e
    e_scal <- simplifyScal =<< gaussAllLTH0 False S.empty e_sofp
    e_dnf  <- toDNF e_scal
    mapM (mapM (\case
                  LogCt c        -> return $ Val (BoolValue c)
                  PosId i        -> return $ Id i $ scalExpType e
                  -- A negated identifier must stay negated (as in
                  -- 'bterm2ScalExp'); returning the bare identifier
                  -- would report the opposite condition.
                  NegId i        -> return $ SNot $ Id i $ scalExpType e
                  NRelExp rel ee -> RelExp rel <$> fromNumSofP ee
         )) e_dnf
gaussElimRel _ =
    badAlgSimplifyM "gaussElimRel: only LTH0 Int relations please!"

--ppSyms :: S.Set VName -> String
--ppSyms ss = foldl (\s x -> s ++ " " ++ (baseString x)) "ElimSyms: " (S.toList ss)

-- | True only for an integer constant that is strictly negative.
primScalExpLTH0 :: ScalExp -> Bool
primScalExpLTH0 (Val (IntValue v)) = P.intToInt64 v < 0
primScalExpLTH0 _ = False

-----------------------------------------------------------
-----------------------------------------------------------
-----------------------------------------------------------
--- `gaussAllLTH0'                                      ---
---   `static_only': whether only a True/False answer   ---
---                  is required, or an actual          ---
---                  sufficient condition               ---
---   `el_syms':     the set of already eliminated      ---
---                  symbols, initially empty           ---
---   `sofp':        the expression e in sum-of-product ---
---                  form that is compared to 0,        ---
---                  i.e., e < 0. sofp assumed          ---
---                  simplified                         ---
---   Result:        a ScalExp which is actually a      ---
---                  predicate in DNF form, that is a   ---
---                  sufficient condition for e < 0!    ---
---                                                     ---
--- gaussAllLTH0 is implementing the tracking of        ---
---   Min/Max terms, and uses `gaussOneDefaultLTH0' to  ---
---   implement gaussian-like elimination to solve the  ---
---   a*i + b < 0 problem.                              ---
---                                                     ---
--- IMPORTANT: before calling gaussAllLTH0 from outside ---
---   make sure to set the inSolveLTH0 env member to    ---
---   True, via markGaussLTH0; otherwise infinite       ---
---   recursion might happen w.r.t. `simplifyScal'      ---
-----------------------------------------------------------
-----------------------------------------------------------
-----------------------------------------------------------
type Prod = [ScalExp]

gaussAllLTH0 :: Bool -> S.Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 static_only el_syms sofp = do
    -- Every invocation consumes one unit of the step budget.
    step
    let tp  = typeOfNAlg sofp
    rangesrep <- asks ranges
    e_scal <- fromNumSofP sofp
    let mi  = pickSymToElim rangesrep el_syms e_scal

    case mi of
      -- No symbol left to eliminate: decide constants, otherwise
      -- return the relation itself.
      Nothing -> return $ if primScalExpLTH0 e_scal
                          then Val (BoolValue True)
                          else RelExp LTH0 e_scal
      Just i  -> do
        (jmm, fs0, terms) <- findMinMaxTerm i sofp
        -- i.e., sofp == fs0 * jmm + terms, where
        --       i appears in jmm and jmm = MinMax ...

        fs <- if not (null fs0)
                then return fs0
                else do one <- getPos1 tp; return [Val one]

        case jmm of
          ------------------------------------------------------------------------
          -- A MinMax expression which uses to-be-eliminated symbol i was found --
          ------------------------------------------------------------------------
          Just (MaxMin _     []  ) ->
            badAlgSimplifyM "gaussAllLTH0: Empty MinMax Node!"
          Just (MaxMin ismin mmts) -> do
            mone <- getNeg1 tp

            -- fs_lth0 => fs < 0
            -- fs_lth0 <- if null fs then return $ Val (BoolValue False)
            --            else gaussAllLTH0 static_only el_syms (NProd fs tp)
            fsm1 <- toNumSofP =<< simplifyScal =<< fromNumSofP
                      ( NSum [NProd fs tp, NProd [Val mone] tp] tp )
            fs_leq0 <- gaussAllLTH0 static_only el_syms fsm1  -- fs <= 0
            -- mfsm1 = - fs - 1, fs_geq0 => (fs >= 0),
            --            i.e., fs_geq0 => (-fs - 1 < 0)
            mfsm1 <- toNumSofP =<< simplifyScal =<< fromNumSofP
                       ( NSum [NProd (Val mone:fs) tp,NProd [Val mone] tp] tp )
            fs_geq0 <- gaussAllLTH0 static_only el_syms mfsm1

            -- mm_terms are the simplified terms of the MinMax obtained
            -- after inlining everything inside the MinMax, i.e., initially
            -- terms + fs * MinMax ismin [t1,..,tn] -> [fs*t1+terms, ..., fs*tn+terms]
            mm_terms <- mapM (\t -> toNumSofP =<< simplifyScal =<< fromNumSofP
                                      (NSum ( NProd (t:fs) tp:terms ) tp) ) mmts

            -- for every (simplified) `term_i' of the inline MinMax exp,
            --  get the sufficient conditions for `term_i < 0'
            mms <- mapM (gaussAllLTH0 static_only el_syms) mm_terms

            if static_only
              --------------------------------------------------------------------
              -- returns either Val (BoolValue True) or the original ScalExp relat
              --------------------------------------------------------------------
              then if ( fs_geq0 == Val (BoolValue True) &&     ismin) ||
                      ( fs_leq0 == Val (BoolValue True) && not ismin)
                     -- at least one term should be < 0!
                     then do let  is_one_true   = Val (BoolValue True ) `elem` mms
                             let  are_all_false = all (== Val (BoolValue False)) mms
                             return $ if       is_one_true   then Val (BoolValue True)
                                      else if  are_all_false then Val (BoolValue False)
                                      else     RelExp LTH0 e_scal
                     -- otherwise all terms should be all true!
                     else do let  are_all_true = all (== Val (BoolValue True )) mms
                             let  is_one_false = Val (BoolValue False) `elem` mms
                             return $ if      are_all_true then Val (BoolValue True )
                                      else if is_one_false then Val (BoolValue False)
                                      else    RelExp LTH0 e_scal
              --------------------------------------------------------------------
              -- returns sufficient conditions for the ScalExp relation to hold --
              --------------------------------------------------------------------
              else do
                let mm_fsgeq0 = foldl (if ismin then SLogOr else SLogAnd)
                                      (Val (BoolValue (not ismin))) mms
                let mm_fslth0 = foldl (if ismin then SLogAnd else SLogOr)
                                      (Val (BoolValue      ismin )) mms
                -- the sufficient condition for the original expression, e.g.,
                -- terms + fs * Min [t1,..,tn] < 0 is
                -- (fs >= 0 && (fs*t_1+terms < 0 || ... || fs*t_n+terms < 0) ) ||
                -- (fs <  0 && (fs*t_1+terms < 0 && ... && fs*t_n+terms < 0) )
                return $ SLogOr (SLogAnd fs_geq0 mm_fsgeq0) (SLogAnd fs_leq0 mm_fslth0)

          Just _ -> badAlgSimplifyM "gaussOneLTH0: (Just MinMax) invariant violated!"
          ------------------------------------------------------------------------
          -- A MinMax expression which uses (to-be-elim) symbol i was NOT found --
          ------------------------------------------------------------------------
          Nothing-> do
            m_sofp <- gaussOneDefaultLTH0 static_only i el_syms sofp
            case m_sofp of
              -- Elimination of i failed: mark i as eliminated and retry.
              Nothing       -> gaussAllLTH0 static_only (S.insert i el_syms) sofp
              Just res_eofp -> return res_eofp
    where
      -- Splits a sum-of-products into (MinMax factor mentioning the
      -- symbol, the other factors of that product, the remaining terms).
      -- A sum is only searched when the symbol has both bounds known.
      findMinMaxTerm :: VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, Prod, [NNumExp])
      findMinMaxTerm _  (NSum  [] _) = return (Nothing, [], [])
      findMinMaxTerm _  (NSum  [NProd [MaxMin ismin e] _] _) =
        return (Just (MaxMin ismin e), [], [])
      findMinMaxTerm _  (NProd [MaxMin ismin e] _) =
        return (Just (MaxMin ismin e), [], [])
      findMinMaxTerm ii t@NProd{} = do
        (mm, fs) <- findMinMaxFact ii t
        return (mm, fs, [])
      findMinMaxTerm ii (NSum (t:ts) tp) = do
        rangesrep <- asks ranges
        case M.lookup ii rangesrep of
          Just (_, Just _, Just _) -> do
            f <- findMinMaxFact ii t
            case f of
              (Just mm, fs) -> return (Just mm, fs, ts)
              (Nothing, _ ) -> do (mm, fs', ts') <- findMinMaxTerm ii (NSum ts tp)
                                  return (mm, fs', t:ts')
          _ -> return (Nothing, [], t:ts)

      -- Finds, inside one product, the first MinMax factor whose
      -- arguments mention the symbol; returns it with the other factors.
      findMinMaxFact :: VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, Prod)
      findMinMaxFact _  (NProd []     _ ) = return (Nothing, [])
      findMinMaxFact ii (NProd (f:fs) tp) =
        case f of
          MaxMin ismin ts -> do
            let id_set = mconcat $ map freeIn ts
            if ii `nameIn` id_set
              then return (Just (MaxMin ismin ts), fs)
              else do (mm, fs') <- findMinMaxFact ii (NProd fs tp)
                      return (mm, f:fs')

          _ -> do (mm, fs') <- findMinMaxFact ii (NProd fs tp)
                  return (mm, f:fs')
      findMinMaxFact ii (NSum [f] _) = findMinMaxFact ii f
      findMinMaxFact _  (NSum _ _) =
        badAlgSimplifyM "findMinMaxFact: NSum argument illegal!"

-- | Tries to decide (or give a sufficient condition for) @e < 0@ by
-- bringing @e@ to the form @a*i + b@ and replacing the symbol @i@ by
-- its known bound(s).  Returns 'Nothing' when @e@ is not linear in @i@
-- or when the available bounds do not settle the question; fails when
-- @i@ has no range entry or neither bound is known.
gaussOneDefaultLTH0 :: Bool -> VName -> S.Set VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp)
gaussOneDefaultLTH0  static_only i elsyms e = do
    aipb <- linearForm i e
    case aipb of
      Nothing     -> return Nothing
      Just (a, b) -> do
        rangesrep <- asks ranges
        one    <- getPos1 (typeOfNAlg e)
        ascal  <- fromNumSofP a
        -- mam1 = -(a+1), am1 = a-1, ma = -a, mbm1 = -(b+1)
        mam1   <- toNumSofP =<< simplifyScal (SNeg (SPlus ascal (Val one)))
        am1    <- toNumSofP =<< simplifyScal (SMinus ascal (Val one))
        ma     <- toNumSofP =<< simplifyScal (SNeg ascal)

        b_scal <- fromNumSofP b
        mbm1   <- toNumSofP =<< simplifyScal (SNeg (SPlus b_scal (Val one)))

        -- aleq0 answers `a - 1 < 0' and ageq0 answers `-a - 1 < 0'.
        aleq0  <- simplifyScal =<< gaussAllLTH0 static_only elsyms am1
        ageq0  <- simplifyScal =<< gaussAllLTH0 static_only elsyms mam1

        case M.lookup i rangesrep of
          Nothing ->
            badAlgSimplifyM "gaussOneDefaultLTH0: sym not in ranges!"
          Just (_, Nothing, Nothing) ->
            badAlgSimplifyM "gaussOneDefaultLTH0: both bounds are undefined!"

          -- only the lower-bound of i is known!
          Just (_, Just lb, Nothing) -> do
            -- first try `a*lb + b < 0' together with aleq0 ...
            alpblth0 <- gaussElimHalf static_only elsyms lb a b
            and_half <- simplifyScal alpblth0
            case (and_half, aleq0) of
              (Val (BoolValue True), Val (BoolValue True)) ->
                return $ Just and_half
              _ -> do
                -- ... otherwise try to refute: `-a*lb - b - 1 < 0'
                -- together with ageq0 yields False.
                malmbm1lth0 <- gaussElimHalf static_only elsyms lb ma mbm1
                other_half  <- simplifyScal malmbm1lth0
                case (other_half, ageq0) of
                  (Val (BoolValue True), Val (BoolValue True)) ->
                    return $ Just (Val (BoolValue False))
                  _ -> return Nothing

          -- only the upper-bound of i is known: mirror image of the
          -- case above, with the roles of aleq0 and ageq0 swapped.
          Just (_, Nothing, Just ub) -> do
            aupblth0 <- gaussElimHalf static_only elsyms ub a b
            and_half <- simplifyScal aupblth0
            case (and_half, ageq0) of
              (Val (BoolValue True), Val (BoolValue True)) ->
                return $ Just and_half
              _ -> do
                maumbm1    <- gaussElimHalf static_only elsyms ub ma mbm1
                other_half <- simplifyScal maumbm1
                case (other_half, aleq0) of
                  (Val (BoolValue True), Val (BoolValue True)) ->
                    return $ Just (Val (BoolValue False))
                  _ -> return Nothing

          -- both bounds of i are known
          Just (_, Just lb, Just ub) ->
            if static_only
              -- static mode: only a constant True/False (or Nothing) is produced
              then if aleq0 == Val (BoolValue True)
                     then do alpblth0 <- simplifyScal =<< gaussElimHalf static_only elsyms lb a b
                             if alpblth0 == Val (BoolValue True)
                               then return $ Just (Val (BoolValue True))
                               else do maubmbm1 <- simplifyScal =<< gaussElimHalf static_only elsyms ub ma mbm1
                                       return $ if maubmbm1 == Val (BoolValue True)
                                                then Just (Val (BoolValue False))
                                                else Nothing
                     else if ageq0 == Val (BoolValue True)
                            then do aupblth0 <- simplifyScal =<< gaussElimHalf static_only elsyms ub a b
                                    if aupblth0 == Val (BoolValue True)
                                      then return $ Just (Val (BoolValue True))
                                      else do malbmbm1 <- simplifyScal =<< gaussElimHalf static_only elsyms lb ma mbm1
                                              return $ if malbmbm1 == Val (BoolValue True)
                                                       then Just (Val (BoolValue False))
                                                       else Nothing
                            else return Nothing
              -- otherwise build the condition
              -- (aleq0 && a*lb+b < 0) || (ageq0 && a*ub+b < 0)
              else do
                alpblth0 <- gaussElimHalf static_only elsyms lb a b
                aupblth0 <- gaussElimHalf static_only elsyms ub a b
                res <- simplifyScal $ SLogOr (SLogAnd aleq0 alpblth0) (SLogAnd ageq0 aupblth0)
                return $ Just res

    where
      -- Solves `a*q + b < 0', where q is the bound substituted for the
      -- eliminated symbol.
      gaussElimHalf :: Bool -> S.Set VName -> ScalExp -> NNumExp -> NNumExp -> AlgSimplifyM ScalExp
      gaussElimHalf only_static elsyms0 q a b = do
        a_scal <- fromNumSofP a
        b_scal <- fromNumSofP b
        e_num_scal <- simplifyScal (SPlus (STimes a_scal q) b_scal)
        e_num <- toNumSofP e_num_scal
        gaussAllLTH0 only_static elsyms0 e_num

--    pos <- asks pos
--    badAlgSimplifyM "gaussOneDefaultLTH0: unimplemented!"

----------------------------------------------------------
--- Pick a Symbol to Eliminate & Bring To Linear Form  ---
----------------------------------------------------------

-- | Chooses the next symbol to eliminate: among the free variables of
-- the expression that are not yet eliminated and that have a range
-- entry, the one whose 'Int' range component is largest (presumably
-- the loop nesting depth -- TODO confirm).  'Nothing' if no candidate.
pickSymToElim :: RangesRep -> S.Set VName -> ScalExp -> Maybe VName
pickSymToElim rangesrep elsyms0 e_scal =
--    ranges <- asks ranges
--    e_scal <- fromNumSofP e0
    let ids0= namesToList $ freeIn e_scal
        ids1= filter (\s -> not (S.member s elsyms0)) ids0
        ids2= filter (\s -> case M.lookup s rangesrep of
                              Nothing -> False
                              Just _  -> True
                     ) ids1
        -- descending order of the Int component
        ids = sortBy (\n1 n2 -> let n1p = M.lookup n1 rangesrep
                                    n2p = M.lookup n2 rangesrep
                                in case (n1p, n2p) of
                                     (Just (p1,_,_), Just (p2,_,_)) -> compare (-p1) (-p2)
                                     (_            , _            ) -> compare (1::Int) (1::Int)
                     ) ids2
    in  case ids of
          []  -> Nothing
          v:_ -> Just v


-- | 'linearForm' on 'ScalExp's: simplifies the input, decomposes it,
-- and simplifies both resulting coefficients.
linearFormScalExp :: VName -> ScalExp -> AlgSimplifyM (Maybe (ScalExp, ScalExp))
linearFormScalExp sym scl_exp = do
    sofp <- toNumSofP =<< simplifyScal scl_exp
    ab   <- linearForm sym sofp
    case ab of
      Just (a_sofp, b_sofp) -> do
        a <- fromNumSofP a_sofp
        b <- fromNumSofP b_sofp
        a'<- simplifyScal a
        b'<- simplifyScal b
        return $ Just (a', b')
      Nothing ->
        return Nothing

-- | Tries to write a sum-of-products as @a*idd + b@.  Each term is
-- divided by @idd@; terms whose simplified quotient is no longer a
-- division contribute to @a@, all other terms form @b@.  Yields
-- 'Nothing' if no term contributes to @a@ or if @idd@ still occurs
-- in @b@.
linearForm :: VName -> NNumExp -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
linearForm _ (NProd [] _) =
    badAlgSimplifyM "linearForm: empty Prod!"
linearForm idd ee@NProd{} = linearForm idd (NSum [ee] (typeOfNAlg ee))
linearForm _ (NSum [] _) =
    badAlgSimplifyM "linearForm: empty Sum!"
linearForm idd (NSum terms tp) = do
    terms_d_idd <- mapM  (\t -> do t0 <- case t of
                                            NProd (_:_) _ -> return t
                                            _ -> badAlgSimplifyM "linearForm: ILLEGAL111!!!!"
                                   t_scal <- fromNumSofP t0
                                   simplifyScal $ SDiv t_scal (Id idd (scalExpType t_scal))
                         ) terms
    -- number the terms so that a-terms and b-terms can be told apart
    let myiota  = [1..(length terms)]
    let ia_terms= filter (\(_,t)-> case t of
                                     SDiv _ _ -> False
                                     _        -> True
                         ) (zip myiota terms_d_idd)
    let (a_inds, a_terms) = unzip ia_terms

    let (_, b_terms) = unzip $ filter (\(iii,_) -> iii `notElem` a_inds)
                                      (zip myiota terms)
    -- check that b_terms do not contain idd
    b_succ <- foldM (\acc x ->
                        case x of
                           NProd fs _ -> do let fs_scal = case fs of
                                                            [] -> Val $ IntValue $ Int32Value 1
                                                            f:fs' -> foldl STimes f fs'
                                            let b_ids = freeIn fs_scal
                                            return $ acc && not (idd `nameIn` b_ids)
                           _ -> badAlgSimplifyM "linearForm: ILLEGAL222!!!!"
                    ) True b_terms

    case a_terms of
      t:ts | b_succ -> do
        let a_scal = foldl SPlus t ts
        a_terms_sofp <- toNumSofP =<< simplifyScal a_scal
        b_terms_sofp <- if null b_terms
                          then do zero <- getZero tp; return $ NProd [Val zero] tp
                          else return $ NSum b_terms tp
        return $ Just (a_terms_sofp, b_terms_sofp)
      _ -> return Nothing

------------------------------------------------
------------------------------------------------
-- Main Helper Function: takes a scalar exp,  --
-- normalizes and simplifies it               --
------------------------------------------------
------------------------------------------------

simplifyScal :: ScalExp -> AlgSimplifyM ScalExp

simplifyScal (Val v) = return $ Val v
simplifyScal (Id  x t) = return $ Id  x t

-- boolean expressions are simplified via their disjunctive normal form
simplifyScal e@SNot{} = fromDNF =<< simplifyDNF =<< toDNF e
simplifyScal e@SLogAnd{} = fromDNF =<< simplifyDNF =<< toDNF e
simplifyScal e@SLogOr{} = fromDNF =<< simplifyDNF =<< toDNF e
simplifyScal e@RelExp{} = fromDNF =<< simplifyDNF =<< toDNF e

--------------------------------------
--- MaxMin related simplifications ---
--------------------------------------
simplifyScal (MaxMin _ []) =
    badAlgSimplifyM "Scalar MaxMin expression with empty arglist."
simplifyScal (MaxMin _ [e]) = simplifyScal e
simplifyScal (MaxMin ismin es) = do -- helperMinMax ismin  es pos
    -- pos <- asks pos
    es0 <- mapM simplifyScal es
    -- all constant arguments are folded into a single constant
    let evals = filter isValue es0
        es'   = filter (not . isValue) es0
        mvv = case evals of
                []   -> Nothing
                v:vs -> let myop = if ismin then min else max
                            myval= getValue v
                            oneval = (foldl myop myval . map getValue) vs
                        in  Just $ Val oneval
    -- flatten the result and remove duplicates,
    -- e.g., Max(Max(e1,e2), e1) -> Max(e1,e2)
    case (es', mvv) of
      ([], Just vv) -> return vv
      (_,  Just vv) -> return $ MaxMin ismin $ remDups $ foldl flatop [] $ vv:es'
      (_,  Nothing) -> return $ MaxMin ismin $ remDups $ foldl flatop [] es'
    -- ToDo: This can prove very expensive as compile time
    --       but, IF e2-e1 <= 0 simplifies to True THEN
    --       Min(e1,e2) = e2.   Code example:
    -- e1me2 <- if isMin
    --          then simplifyScal $ AlgOp MINUS e1 e2 pos
    --          else simplifyScal $ AlgOp MINUS e2 e1 pos
    -- e1me2leq0 <- simplifyNRel $ NRelExp LEQ0 e1me2 pos
    -- case e1me2leq0 of
    --    NAnd [LogCt True  _] _ -> simplifyAlgN e1
    --    NAnd [LogCt False _] _ -> simplifyAlgN e2
    where
        isValue :: ScalExp -> Bool
        isValue e = case e of
                      Val _ -> True
                      _     -> False
        -- only applied to elements that passed isValue; the fallback
        -- value is a placeholder
        getValue :: ScalExp -> PrimValue
        getValue se = case se of
                        Val v -> v
                        _     -> value (0::Int32)
        -- splices in the arguments of a nested MaxMin of the same kind
        flatop :: [ScalExp] -> ScalExp -> [ScalExp]
        flatop a e@(MaxMin ismin' ses) =
            a ++ if ismin == ismin' then ses else [e]
        flatop a e = a++[e]
        -- deduplicates through a set, so the result is also sorted
        remDups :: [ScalExp] -> [ScalExp]
        remDups l = S.toList (S.fromList l)

---------------------------------------------------
--- Plus/Minus related simplifications          ---
--- BUG: the MinMax pattern matching should     ---
---      be performed on the simplified subexps ---
---------------------------------------------------
simplifyScal (SPlus e1o e2o) = do
    e1' <- simplifyScal e1o
    e2' <- simplifyScal e2o
    if isMaxMin e1' || isMaxMin e2'
      then helperPlusMinMax $ SPlus e1' e2'
      else normalPlus e1' e2'

    where
      normalPlus :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
      normalPlus e1 e2 = do
        e1' <- toNumSofP e1
        e2' <- toNumSofP e2
        let tp = scalExpType e1
        let terms = getTerms e1' ++ getTerms e2'
        splittedTerms <- mapM splitTerm terms
        let sortedTerms = sortBy (\(n1,_) (n2,_) -> compare n1 n2) splittedTerms
        -- foldM discriminate: adds together identical terms, and
        -- we reverse the list, to keep it in a ascending order.
        merged <- reverse <$> foldM discriminate [] sortedTerms
        -- terms that cancelled out are dropped
        let filtered = filter (\(_,v) -> not $ zeroIsh v ) merged
        if null filtered
          then do zero <- getZero tp
                  fromNumSofP $ NProd [Val zero] tp
          else do terms' <- mapM joinTerm filtered
                  fromNumSofP $ NSum terms' tp

-- e1 - e2 is rewritten as e1 + (-1)*e2
simplifyScal (SMinus e1 e2) = do
  let tp = scalExpType e1
  if e1 == e2
    then Val <$> getZero tp
    else do min_1 <- getNeg1 $ scalExpType e1
            simplifyScal $ SPlus e1 $ STimes (Val min_1) e2

-- -e is rewritten as (-1)*e
simplifyScal (SNeg e) = do
    negOne <- getNeg1 $ scalExpType e
    simplifyScal $ STimes (Val negOne) e

-- abs and signum are left untouched
simplifyScal (SAbs e) =
  return $ SAbs e

simplifyScal (SSignum e) =
  return $ SSignum e

---------------------------------------------------
--- Times        related simplifications        ---
--- BUG: the MinMax pattern matching should     ---
---      be performed on the simplified subexps ---
---------------------------------------------------
simplifyScal (STimes e1o e2o) = do
    e1'' <- simplifyScal e1o
    e2'' <- simplifyScal e2o
    if isMaxMin e1'' || isMaxMin e2''
      then helperMultMinMax $ STimes e1'' e2''
      else normalTimes e1'' e2''

    where
      -- distributes the product over the terms of both operands
      normalTimes :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
      normalTimes e1 e2 = do
        let tp = scalExpType e1
        e1' <- toNumSofP e1
        e2' <- toNumSofP e2
        case (e1', e2') of
            (NProd xs _, y@NProd{}) -> fromNumSofP =<< makeProds xs y
            (NProd xs _, y) -> do
                    prods <- mapM (makeProds xs) $ getTerms y
                    fromNumSofP $ NSum (sort prods) tp
            (x, NProd ys _) -> do
                    prods <- mapM (makeProds ys) $ getTerms x
                    fromNumSofP $ NSum (sort prods) tp
            (NSum xs _, NSum ys _) -> do
                    xsMultChildren <- mapM getMultChildren xs
                    prods <- mapM (\x -> mapM (makeProds x) ys) xsMultChildren
                    fromNumSofP $ NSum (sort $ concat prods) tp

      -- multiplies a factor list into a product; a constant factor, if
      -- any, is kept in front and the remaining factors are sorted
      makeProds :: [ScalExp] -> NNumExp -> AlgSimplifyM NNumExp
      makeProds [] _ =
        badAlgSimplifyM " In simplifyAlgN, makeProds: 1st arg is the empty list! "
      makeProds _ (NProd [] _) =
        badAlgSimplifyM " In simplifyAlgN, makeProds: 2nd arg is the empty list! "
      makeProds _ NSum{} =
        badAlgSimplifyM " In simplifyAlgN, makeProds: e1 * e2: e2 is a sum of sums! "
      makeProds (Val v1:exs) (NProd (Val v2:ys) tp1) = do
        v <- mulVals v1 v2
        return $ NProd (Val v : sort (ys++exs) ) tp1
      makeProds (Val v:exs) (NProd ys tp1) =
        return $ NProd (Val v : sort (ys++exs) ) tp1
      makeProds exs (NProd (Val v : ys) tp1) =
        return $ NProd (Val v : sort (ys++exs) ) tp1
      makeProds exs (NProd ys tp1) =
        return $ NProd (sort (ys++exs)) tp1

---------------------------------------------------
---------------------------------------------------
--- DIvide        related simplifications       ---
---------------------------------------------------
---------------------------------------------------

simplifyScal (SDiv e1o e2o) = do
    e1' <- simplifyScal e1o
    e2' <- simplifyScal e2o

    if isMaxMin e1' || isMaxMin e2'
      then helperMultMinMax $ SDiv e1' e2'
      else normalFloatDiv e1' e2'

    where
      normalFloatDiv :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
      normalFloatDiv e1 e2
        | e1 == e2                  = do one <- getPos1 $ scalExpType e1
                                         return $ Val one
--      | e1 == (negateSimplified e2) = do mone<- getNeg1 $ scalExpType e1
--                                         return $ Val mone
        | otherwise = do
            e1' <- toNumSofP e1
            e2' <- toNumSofP e2
            case e2' of
              -- the divisor is a single product: try to cancel its factors
              NProd fs tp -> do
                e1Split <- mapM splitTerm (getTerms e1')
                case e1Split of
                  []  -> Val <$> getZero tp
                  _   -> do (fs', e1Split')  <- trySimplifyDivRec fs [] e1Split
                            if length fs' == length fs
                              then turnBackAndDiv e1' e2' -- insuccess
                              else do terms_e1' <- mapM joinTerm e1Split'
                                      e1'' <- fromNumSofP $ NSum terms_e1' tp
                                      case fs' of
                                        [] -> return e1''
                                        _  -> do e2'' <- fromNumSofP $ NProd fs' tp
                                                 return $ SDiv e1'' e2''

              _ -> turnBackAndDiv e1' e2'

      -- gives up: rebuilds the plain division
      turnBackAndDiv :: NNumExp -> NNumExp -> AlgSimplifyM ScalExp
      turnBackAndDiv ee1 ee2 = do
        ee1' <- fromNumSofP ee1
        ee2' <- fromNumSofP ee2
        return $ SDiv ee1' ee2'

-- mod, quot and rem only have their operands simplified
simplifyScal (SMod e1o e2o) =
    SMod <$> simplifyScal e1o <*> simplifyScal e2o

simplifyScal (SQuot e1o e2o) =
    SQuot <$> simplifyScal e1o <*> simplifyScal e2o

simplifyScal (SRem e1o e2o) =
    SRem <$> simplifyScal e1o <*> simplifyScal e2o

---------------------------------------------------
---------------------------------------------------
--- Power        related simplifications        ---
---------------------------------------------------
---------------------------------------------------

-- cannot handle 0^a, because if a < 0 then it's an error.
-- Could be extended to handle negative exponents better, if needed
simplifyScal (SPow e1 e2) = do
    let tp = scalExpType e1
    e1' <- simplifyScal e1
    e2' <- simplifyScal e2

    if isCt1 e1' || isCt0 e2'
      then Val <$> getPos1 tp
      else if isCt1 e2'
        then return e1'
        else case (e1', e2') of
               (Val v1, Val v2)
                 | Just v <- powVals v1 v2 -> return $ Val v
               -- a constant exponent n >= 1 becomes an n-fold product
               (_, Val (IntValue n)) ->
                   if P.intToInt64 n >= 1
                     then -- simplifyScal =<< fromNumSofP $ NProd (replicate n e1') tp
                          do new_e <- fromNumSofP $ NProd (genericReplicate (P.intToInt64 n) e1') tp
                             simplifyScal new_e
                     else return $ SPow e1' e2'
               (_, _) -> return $ SPow e1' e2'

    where
        powVals :: PrimValue -> PrimValue -> Maybe PrimValue
        powVals (IntValue v1) (IntValue v2) = IntValue <$> P.doPow v1 v2
        powVals _ _ = Nothing

-----------------------------------------------------
--- Helpers for simplifyScal: MinMax related, etc ---
-----------------------------------------------------

isMaxMin :: ScalExp -> Bool
isMaxMin MaxMin{} = True
isMaxMin _        = False

-- | Pushes an addition inside a MaxMin operand, i.e.
-- @MaxMin [e1,..,en] + e@ becomes @MaxMin [e1+e,..,en+e]@, and
-- simplifies the result.
helperPlusMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperPlusMinMax (SPlus (MaxMin ismin es) e) =
    simplifyScal $ MaxMin ismin $ map (`SPlus` e) es
helperPlusMinMax (SPlus e (MaxMin ismin es)) =
    simplifyScal $ MaxMin ismin $ map (SPlus e) es
helperPlusMinMax _ = badAlgSimplifyM "helperPlusMinMax: Reached unreachable case!"

{-
helperMinusMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperMinusMinMax (SMinus (MaxMin ismin es) e) =
    simplifyScal $ MaxMin ismin $ map (\x -> SMinus x e) es
helperMinusMinMax (SMinus e (MaxMin ismin es)) =
    simplifyScal $ MaxMin ismin $ map (\x -> SMinus e x) es
helperMinusMinMax _ = do
    pos <- asks pos
    badAlgSimplifyM "helperMinusMinMax: Reached unreachable case!"
-}

-- | Dispatches a product or quotient with a MaxMin operand to
-- 'helperTimesDivMinMax'.  The first flag says "multiplication", the
-- second says "the MaxMin is the right-hand operand".
helperMultMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperMultMinMax = \case
    STimes e em@MaxMin{} -> helperTimesDivMinMax True  True  em e
    STimes em@MaxMin{} e -> helperTimesDivMinMax True  False em e
    SDiv   e em@MaxMin{} -> helperTimesDivMinMax False True  em e
    SDiv   em@MaxMin{} e -> helperTimesDivMinMax False False em e
    _ -> badAlgSimplifyM "helperMultMinMax: Reached unreachable case!"

-- | Pushes the other operand inside the MaxMin when its sign can be
-- decided statically, flipping Min/Max where required.
helperTimesDivMinMax :: Bool -> Bool -> ScalExp -> ScalExp -> AlgSimplifyM ScalExp
helperTimesDivMinMax isTimes isRev emo@MaxMin{} e =
    simplifyScal emo >>= \case
      em@(MaxMin ismin es) -> do
        e'     <- simplifyScal e
        e'_sop <- toNumSofP e'
        simplifyNRel (NRelExp LTH0 e'_sop) >>= \case
          -- the sign of e' is known: distribute over the arguments
          LogCt isNeg ->
            simplifyScal $ MaxMin (orient isNeg ismin) $ map (`combine` e') es
          -- sign unknown
          _ | not isTimes -> return $ combine em e'
            | otherwise   ->  -- e' * MaxMin{...}
                case e'_sop of
                  NProd _ _  -> return $ combine em e'
                  NSum ts tp -> do
                    new_ts <- mapM (\case
                                      NProd fs _ -> return $ NProd (em:fs) tp
                                      _          -> badAlgSimplifyM
                                                      "helperTimesDivMinMax: SofP invariant violated!"
                                   ) ts
                    simplifyScal =<< fromNumSofP ( NSum new_ts tp )
      -- the MaxMin simplified to something that is no longer a MaxMin
      em -> simplifyScal $ combine em e
  where
    -- The Min/Max kind is kept exactly when "e' is negative" coincides
    -- with "dividing by the MaxMin"; otherwise it is flipped.
    orient :: Bool -> Bool -> Bool
    orient isNeg ismin
      | isNeg == (not isTimes && isRev) = ismin
      | otherwise                       = not ismin

    -- Rebuilds the operation with the operands in their original order.
    combine :: ScalExp -> ScalExp -> ScalExp
    combine e1 e2
      | isTimes   = if isRev then STimes e2 e1 else STimes e1 e2
      | isRev     = SDiv e2 e1
      | otherwise = SDiv e1 e2

helperTimesDivMinMax _ _ _ _ =
    badAlgSimplifyM "helperTimesDivMinMax: Reached unreachable case!"


---------------------------------------------------
---------------------------------------------------
--- translating to and simplifying the          ---
--- disjunctive normal form: toDNF, simplifyDNF ---
---------------------------------------------------
---------------------------------------------------
--isTrueDNF :: DNF -> Bool
--isTrueDNF [[LogCt True]] = True
--isTrueDNF _              = False
--
--getValueDNF :: DNF -> Maybe Bool
--getValueDNF [[LogCt True]]  = Just True
--getValueDNF [[LogCt False]] = Just True
--getValueDNF _               = Nothing

-- | Logical negation of a single DNF atom.
negateBTerm :: BTerm -> AlgSimplifyM BTerm
negateBTerm = \case
    LogCt v       -> return $ LogCt (not v)
    PosId i       -> return $ NegId i
    NegId i       -> return $ PosId i
    NRelExp rel e ->
      case (typeOfNAlg e, rel) of
        -- integers: not (e < 0) is expressed as -(e+1) < 0
        (IntType it, LTH0) -> do
          se <- fromNumSofP e
          NRelExp LTH0 <$>
            (toNumSofP =<< simplifyScal
               (SNeg $ SPlus se (Val (value (P.intValue it (1::Int))))))
        -- otherwise swap the relation and negate the expression
        _ -> do
          negated <- toNumSofP =<< negateSimplified =<< fromNumSofP e
          return $ NRelExp (if rel == LEQ0 then LTH0 else LEQ0) negated

-- | Converts a DNF atom back to a 'ScalExp'.
bterm2ScalExp :: BTerm -> AlgSimplifyM ScalExp
bterm2ScalExp = \case
    LogCt v       -> return $ Val $ BoolValue v
    PosId i       -> return $ Id i int32
    NegId i       -> return $ SNot $ Id i int32
    NRelExp rel e -> RelExp rel <$> fromNumSofP e

-- translates from DNF to ScalExp
fromDNF :: DNF -> AlgSimplifyM ScalExp
fromDNF []             = badAlgSimplifyM "fromDNF: empty DNF!"
fromDNF (conj : conjs) = do
    first <- conjToScal conj
    foldM (\acc c -> (`SLogOr` acc) <$> conjToScal c) first conjs
  where
    -- one conjunction of atoms
    conjToScal []             = badAlgSimplifyM "fromDNF, translFact empty DNF factor!"
    conjToScal (atom : atoms) = do
      first <- bterm2ScalExp atom
      foldM (\acc a -> (`SLogAnd` acc) <$> bterm2ScalExp a) first atoms

-- translates (and simplifies numeric expressions?) to DNF form.
toDNF :: ScalExp -> AlgSimplifyM DNF
toDNF = \case
    Val (BoolValue v) -> return [[LogCt v]]
    Id idd _          -> return [[PosId idd]]
    RelExp rel e      -> relation rel e
    SLogOr  e1 e2     -> disjunction e1 e2
    SLogAnd e1 e2     -> conjunction e1 e2
    SNot inner        ->
      case inner of
        SNot e            -> toDNF e
        Val (BoolValue v) -> return [[LogCt $ not v]]
        Id idd _          -> return [[NegId idd]]
        RelExp rel e      -> do
          let not_rel = if rel == LEQ0 then LTH0 else LEQ0
          neg_e <- simplifyScal (SNeg e)
          toDNF $ RelExp not_rel neg_e
        -- De Morgan
        SLogAnd e1 e2     -> disjunction (SNot e1) (SNot e2)
        SLogOr  e1 e2     -> conjunction (SNot e1) (SNot e2)
        _                 -> notBoolean
    _ -> notBoolean
  where
    notBoolean = badAlgSimplifyM "toDNF: not a boolean expression!"

    singleton t = [[t]]

    relation rel e =
      case scalExpType e of
        -- integers: e <= 0 is normalised to e - 1 < 0
        IntType it -> do
          e' <- if rel == LEQ0
                  then SPlus e . Val <$> getNeg1 (IntType it)
                  else return e
          ne <- toNumSofP =<< simplifyScal e'
          singleton <$> simplifyNRel (NRelExp LTH0 ne)
        _ -> do
          ne <- toNumSofP =<< simplifyScal e
          singleton <$> markGaussLTH0 (simplifyNRel $ NRelExp rel ne)

    -- the union of the two clause lists
    disjunction a b = do
      as <- toDNF a
      bs <- toDNF b
      return $ sort $ as ++ bs

    -- every clause of the first combined with every clause of the second
    conjunction a b = do
      as <- toDNF a
      bs <- toDNF b
      return $ sort [ x ++ y | y <- bs, x <- as ]

------------------------------------------------------
--- Simplifying Boolean Expressions:               ---
---  0. p     AND p == p;       p     OR p == p    ---
---  1. False AND p == False;   True  OR p == True ---
---  2. True  AND p == p;       False OR p == p    ---
---  3.(not p)AND p == FALSE;  (not p)OR p == True ---
---  4.
---     ToDo: p1 AND p2 == p1 if p1 => p2          ---
---           p1 AND p2 == False if p1 => not p2 or---
---                                 p2 => not p1   ---
---     Also: p1 OR p2 == p2 if p1 => p2           ---
---           p1 OR p2 == True if not p1 => p2 or  ---
---                               not p2 => p1     ---
---     This boils down to relations:              ---
---      e1 < 0 => e2 < 0    if e2 <= e1           ---
------------------------------------------------------

-- | Simplify a formula in disjunctive normal form.
--
-- Each conjunct is simplified with @'simplifyAndOr' True@.  A
-- conjunct that collapsed to @True@ makes the whole formula @True@;
-- conjuncts that collapsed to @False@ are dropped and duplicates are
-- removed.  If every remaining conjunct is a single term, the terms
-- are additionally simplified against each other as a disjunction.
simplifyDNF :: DNF -> AlgSimplifyM DNF
simplifyDNF terms0 = do
    terms1 <- mapM (simplifyAndOr True) terms0
    let terms' = if [LogCt True] `elem` terms1
                   then [[LogCt True]]
                   else S.toList $ S.fromList $
                          filter (/= [LogCt False]) terms1
    if null terms'
      then return [[LogCt False]]
      else do
        let len1terms = all ((1==) . length) terms'
        if not len1terms
          then return terms'
          else do
            let terms_flat = concat terms'
            terms'' <- simplifyAndOr False terms_flat
            return $ map (:[]) terms''

-- | Big helper for 'simplifyDNF': simplify a list of terms that is
-- read as a conjunction (first argument @True@) or as a disjunction
-- (first argument @False@).  An empty list is an error.
simplifyAndOr :: Bool -> [BTerm] -> AlgSimplifyM [BTerm]
simplifyAndOr _ [] = badAlgSimplifyM "simplifyAndOr: not a boolean expression!"
simplifyAndOr is_and fs
    -- The absorbing constant decides the result:
    -- False AND p == False, True OR p == True.
  | LogCt (not is_and) `elem` fs = return [LogCt $ not is_and]
  | otherwise = do
      -- (i) p AND p == p, (ii) True AND p == p (dually for OR)
      let fs' = S.toList . S.fromList . filter (/=LogCt is_and) $ fs
      if null fs'
        then return [LogCt is_and]
        else do
          -- IF p1 => p2 THEN p1 AND p2 --> p1
          fs'' <- foldM (\l x -> do (addx, l') <- trimImplies is_and x l
                                    return $ if addx then x:l' else l')
                        [] fs'
          -- IF p1 => not p2 THEN p1 AND p2 == False
          isF <- foldM (\b x -> if b
                                  then return b
                                  else do notx <- negateBTerm x
                                          impliesAny is_and x notx fs'')
                       False fs''
          return $ if not isF then fs'' else [LogCt (not is_and)]
  where
    -- e1 => e2 ?
    impliesRel :: BTerm -> BTerm -> AlgSimplifyM Bool
    impliesRel (LogCt False) _ = return True
    impliesRel _ (LogCt True)  = return True
    impliesRel (LogCt True) e  = return $ e == LogCt True
    impliesRel e (LogCt False) = do
        e' <- negateBTerm e -- simplifyNRel =<< negateBTerm e
        return $ e' == LogCt True
    impliesRel (NRelExp rel1 e1) (NRelExp rel2 e2) = do
        -- ToDo: implement implies relation!
        --not_aggr <- asks cheap
        let btp = typeOfNAlg e1
        if btp /= typeOfNAlg e2
          then return False
          else do
            one <- getPos1 btp
            e1' <- fromNumSofP e1
            e2' <- fromNumSofP e2
            case (rel1, rel2, btp) of
              (LTH0, LTH0, IntType _) -> do
                  -- e1 < 0 => e2 < 0  holds if  e2 - (e1 + 1) < 0
                  e2me1m1 <- toNumSofP =<< simplifyScal (SMinus e2' $ SPlus e1' $ Val one)
                  diffrel <- simplifyNRel $ NRelExp LTH0 e2me1m1
                  return $ diffrel == LogCt True
              (_, _, IntType _) -> badAlgSimplifyM "impliesRel: LEQ0 for Int!"
              (_, _, _)         -> badAlgSimplifyM "impliesRel: exp of illegal type!"
    -- identifiers (and mixed cases): only syntactic equality is recognised
    impliesRel p1 p2 = return $ p1 == p2

    -- trimImplies(true, x, l) performs: p1 AND p2 == p1 if p1 => p2,
    --   i.e., removes any p in l such that:
    --    (i)  x => p orelse if
    --    (ii) p => x then indicates that x should not be added to the result
    -- trimImplies(false, x, l) performs: p1 OR p2 == p2 if p1 => p2,
    --   i.e., removes any p from l such that:
    --    (i)  p => x orelse if
    --    (ii) x => p then indicates that x should not be added to the result
    -- The returned flag says whether x is to be added to the returned list.
    trimImplies :: Bool -> BTerm -> [BTerm] -> AlgSimplifyM (Bool, [BTerm])
    trimImplies _        _ []     = return (True, [])
    trimImplies and_case x (p:ps) = do
        succc <- impliesRel x p
        if succc
          then if and_case then trimImplies and_case x ps
                           else return (False, p:ps)
          else do
            suc <- impliesRel p x
            if suc
              then if and_case then return (False, p:ps)
                               else trimImplies and_case x ps
              else do (addx, newps) <- trimImplies and_case x ps
                      return (addx, p:newps)

    -- impliesAny(true, x, notx, l) performs:
    --   x AND p == False if p => not x, where p in l,
    -- impliesAny(false, x, notx, l) performs:
    --   x OR p == True if not x => p
    -- BUT only when p != x, i.e., avoids comparing x with notx
    impliesAny :: Bool -> BTerm -> BTerm -> [BTerm] -> AlgSimplifyM Bool
    impliesAny _        _ _    []     = return False
    impliesAny and_case x notx (p:ps)
      | x == p    = impliesAny and_case x notx ps
      | otherwise = do
          succ' <- if and_case then impliesRel p notx
                               else impliesRel notx p
          if succ' then return True
                   else impliesAny and_case x notx ps

------------------------------------------------
--- Syntax-Directed (Brainless) Translators  ---
---    scalExp <-> NNumExp                   ---
--- and negating a scalar expression         ---
------------------------------------------------

-- | Negate an already simplified scalar expression; presumably more
-- efficient than negating and then simplifying it.
--
-- Numeric expressions are negated arithmetically, boolean ones
-- logically.  Opaque numeric nodes (identifiers, powers, absolute
-- values, @SMod@ and @SRem@) are negated by multiplying them with the
-- constant -1.  Absolute values, @SMod@ and @SRem@ used to be returned
-- unchanged, which made the "negation" equal to its argument.
--
-- NOTE(review): for 'SDiv' and 'SQuot' the sign is pushed into one
-- operand, i.e. this assumes @-(a/b) == (-a)/b@; verify this against
-- the rounding mode of each operator.
negateSimplified :: ScalExp -> AlgSimplifyM ScalExp
negateSimplified (SNeg e) = return e
negateSimplified (SNot e) = return e
negateSimplified e@SAbs{} = do
    m1 <- getNeg1 $ scalExpType e
    return $ STimes (Val m1) e
negateSimplified (SSignum e) =
    SSignum <$> negateSimplified e
negateSimplified e@(Val v) = do
    m1 <- getNeg1 $ scalExpType e
    v' <- mulVals m1 v; return $ Val v'
negateSimplified e@Id{} = do
    m1 <- getNeg1 $ scalExpType e
    return $ STimes (Val m1) e
negateSimplified (SMinus e1 e2) = do -- return $ SMinus e2 e1
    e1' <- negateSimplified e1
    return $ SPlus e1' e2
negateSimplified (SPlus e1 e2) = do
    e1' <- negateSimplified e1
    e2' <- negateSimplified e2
    return $ SPlus e1' e2'
negateSimplified e@(SPow _ _) = do
    m1 <- getNeg1 $ scalExpType e
    return $ STimes (Val m1) e
negateSimplified (STimes e1 e2) = do
    (e1', e2') <- helperNegateMult e1 e2; return $ STimes e1' e2'
negateSimplified (SDiv e1 e2) = do
    (e1', e2') <- helperNegateMult e1 e2; return $ SDiv e1' e2'
negateSimplified e@SMod{} = do
    m1 <- getNeg1 $ scalExpType e
    return $ STimes (Val m1) e
negateSimplified (SQuot e1 e2) = do
    (e1', e2') <- helperNegateMult e1 e2; return $ SQuot e1' e2'
negateSimplified e@SRem{} = do
    m1 <- getNeg1 $ scalExpType e
    return $ STimes (Val m1) e
negateSimplified (MaxMin ismin ts) =
    -- -(min ts) == max (map negate ts), and vice versa
    MaxMin (not ismin) <$> mapM negateSimplified ts
negateSimplified (RelExp LEQ0 e) =
    RelExp LTH0 <$> negateSimplified e
negateSimplified (RelExp LTH0 e) =
    RelExp LEQ0 <$> negateSimplified e
negateSimplified SLogAnd{} = badAlgSimplifyM "negateSimplified: SLogAnd unimplemented!"
negateSimplified SLogOr{}  = badAlgSimplifyM "negateSimplified: SLogOr unimplemented!"

-- | Negate a two-operand product-like node by negating exactly one
-- operand.  A literal constant (bare, or leading an 'STimes') is
-- preferred, first in the left and then in the right operand;
-- otherwise the left operand is negated recursively.
helperNegateMult :: ScalExp -> ScalExp -> AlgSimplifyM (ScalExp, ScalExp)
helperNegateMult e1 e2 =
    case (e1, e2) of
        (Val _,              _) -> do e1'<- negateSimplified e1;       return (e1', e2)
        (STimes (Val v) e1r, _) -> do ev <- negateSimplified (Val v);  return (STimes ev e1r, e2)
        (_,              Val _) -> do e2'<- negateSimplified e2;       return (e1, e2')
        (_, STimes (Val v) e2r) -> do ev <- negateSimplified (Val v);  return (e1, STimes ev e2r)
        (_,                  _) -> do e1'<- negateSimplified e1;       return (e1', e2)

-- | Purely syntactic translation of a 'ScalExp' that is already in
-- sum-of-products shape into an 'NNumExp'.  'SMinus' and 'SNeg' are
-- rejected; an 'STimes' whose right operand is a sum is rejected too.
-- Every other node becomes a single-factor product.
toNumSofP :: ScalExp -> AlgSimplifyM NNumExp
toNumSofP e@(Val  _) = return $ NProd [e] $ scalExpType e
toNumSofP e@(Id _ _) = return $ NProd [e] $ scalExpType e
toNumSofP e@SDiv{}   = return $ NProd [e] $ scalExpType e
toNumSofP e@SPow{}   = return $ NProd [e] $ scalExpType e
toNumSofP (SMinus _ _) = badAlgSimplifyM "toNumSofP: SMinus is not in SofP form!"
toNumSofP (SNeg _)     = badAlgSimplifyM "toNumSofP: SNeg is not in SofP form!"
toNumSofP (STimes e1 e2) = do
    e2' <- toNumSofP e2
    case e2' of
        NProd es2 t -> return $ NProd (e1:es2) t
        _ -> badAlgSimplifyM "toNumSofP: STimes nor in SofP form!"
toNumSofP (SPlus e1 e2) = do
    let t = scalExpType e1
    e1' <- toNumSofP e1
    e2' <- toNumSofP e2
    case (e1', e2') of
        (NSum es1 _, NSum es2 _) -> return $ NSum (es1++es2) t
        (NSum es1 _, NProd{})    -> return $ NSum (es1++[e2']) t
        (NProd{},    NSum es2 _) -> return $ NSum (e1':es2) t
        (NProd{},    NProd{}   ) -> return $ NSum [e1', e2'] t
toNumSofP me@MaxMin{} =
    return $ NProd [me] $ scalExpType me
toNumSofP s_e = return $ NProd [s_e] $ scalExpType s_e

-- | Inverse of 'toNumSofP': rebuild right-nested 'SPlus' / 'STimes'
-- chains.  An empty sum is the zero constant of its type; an empty
-- product is an error.
fromNumSofP :: NNumExp -> AlgSimplifyM ScalExp
fromNumSofP (NSum [ ] t) = Val <$> getZero t
fromNumSofP (NSum [f] _) = fromNumSofP f
fromNumSofP (NSum (f:fs) t) = do
    fs_e <- fromNumSofP $ NSum fs t
    f_e  <- fromNumSofP f
    return $ SPlus f_e fs_e
fromNumSofP (NProd [] _) =
    badAlgSimplifyM " In fromNumSofP, empty NProd expression! "
fromNumSofP (NProd [f] _)    = return f
fromNumSofP (NProd (f:fs) t) = do
    fs_e <- fromNumSofP $ NProd fs t
    return $ STimes f fs_e
--fromNumSofP _ = do
--    pos <- asks pos
--    badAlgSimplifyM "fromNumSofP: unimplemented!"

------------------------------------------------------------
--- Helpers for simplifyScal: getTerms, getMultChildren, ---
---   splitTerm, joinTerm, discriminate
------------------------------------------------------------

-- | Get the list of terms of an expression.
-- BUG for NMinMax -> should convert it back to a ScalExp
getTerms :: NNumExp -> [NNumExp]
getTerms (NSum es _) = es
getTerms e@NProd{}   = [e]

-- | Get the factors of a term; a sum is rejected.
getMultChildren :: NNumExp -> AlgSimplifyM [ScalExp]
getMultChildren (NSum _ _) = badAlgSimplifyM "getMultChildren, NaryPlus should not be nested 2 levels deep "
getMultChildren (NProd xs _) = return xs

-- | Split a term into a (multiplicative) value and the rest of the
-- factors.  Only a literal in leading position is split off; without
-- one the value is 1.  A term consisting of just a literal is split
-- into the factor list @[1]@ and that literal.
splitTerm :: NNumExp -> AlgSimplifyM (NNumExp, PrimValue)
splitTerm (NProd [ ] _) = badAlgSimplifyM "splitTerm: Empty n-ary list of factors."
splitTerm (NProd [f] tp) = do
    one <- getPos1 tp
    case f of
        (Val v) -> return (NProd [Val one] tp, v  )
        e       -> return (NProd [e]       tp, one)
splitTerm ne@(NProd (f:fs) tp) =
    case f of
        (Val v) -> return (NProd fs tp, v)
        _       -> do one <- getPos1 tp
                      return (ne, one)
splitTerm e = do
    one <- getPos1 (typeOfNAlg e)
    return (e, one)

-- | Join a value with a list of factors into a term.  A leading
-- literal in the factors is multiplied into the value; a value of 1
-- leaves a literal-free term untouched.  Otherwise the value is put
-- in front of the sorted remaining factors.
joinTerm :: (NNumExp, PrimValue) -> AlgSimplifyM NNumExp
joinTerm ( NSum _ _, _) = badAlgSimplifyM "joinTerm: NaryPlus two levels deep."
joinTerm ( NProd [] _, _) = badAlgSimplifyM "joinTerm: Empty NaryProd."
joinTerm ( NProd (Val l:fs) tp, v) = do
    v' <- mulVals v l
    let v'Lit = Val v'
    return $ NProd (v'Lit:sort fs) tp
joinTerm ( e@(NProd fs tp), v)
  | P.oneIsh v = return e
  | otherwise  = let vExp = Val v
                 in return $ NProd (vExp:sort fs) tp

-- | Adds up the values corresponding to identical factors!
-- Only the head of the accumulated list is compared with the new
-- entry; if the keys differ the new entry is pushed in front.
discriminate :: [(NNumExp, PrimValue)] -> (NNumExp, PrimValue) -> AlgSimplifyM [(NNumExp, PrimValue)]
discriminate []            e        = return [e]
discriminate e@((k,v):t) (k', v') =
    if k == k'
      then do v'' <- addVals v v'
              return ( (k, v'') : t )
      else return ( (k', v') : e )

------------------------------------------------------
--- Trivial Utility Functions                      ---
------------------------------------------------------

-- | The constant 0 of an integer type; other types are an error.
getZero :: PrimType -> AlgSimplifyM PrimValue
getZero (IntType t) = return $ value $ intValue t (0::Int)
getZero tp          = badAlgSimplifyM ("getZero for type: "++pretty tp)

-- | The constant 1 of an integer type; other types are an error.
getPos1 :: PrimType -> AlgSimplifyM PrimValue
getPos1 (IntType t) = return $ value $ intValue t (1::Int)
getPos1 tp          = badAlgSimplifyM ("getOne for type: "++pretty tp)

-- | The constant -1 of an integer type; other types are an error.
getNeg1 :: PrimType -> AlgSimplifyM PrimValue
getNeg1 (IntType t) = return $ value $ intValue t (-1::Int)
-- The message used to say "getOne", pointing at the wrong helper.
getNeg1 tp          = badAlgSimplifyM ("getNeg1 for type: "++pretty tp)

-- | Decide @v <= 0@ or @v < 0@ for an integer constant.
valLTHEQ0 :: RelOp0 -> PrimValue -> AlgSimplifyM Bool
valLTHEQ0 LEQ0 (IntValue iv) = return $ P.intToInt64 iv <= 0
valLTHEQ0 LTH0 (IntValue iv) = return $ P.intToInt64 iv < 0
valLTHEQ0 _ _ = badAlgSimplifyM "valLTHEQ0 for non-numeric type!"

-- | Is the expression a literal for which 'P.oneIsh' holds?
isCt1 :: ScalExp -> Bool
isCt1 (Val v) = P.oneIsh v
isCt1 _       = False

-- | Is the expression a literal for which 'P.zeroIsh' holds?
isCt0 :: ScalExp -> Bool
isCt0 (Val v) = P.zeroIsh v
isCt0 _       = False

-- | Add two integer constants with 'P.doAdd'.
addVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
addVals (IntValue v1) (IntValue v2) =
  return $ IntValue $ P.doAdd v1 v2
addVals _ _ =
  badAlgSimplifyM "addVals: operands not of (the same) numeral type! "

-- | Multiply two integer constants with 'P.doMul'.
mulVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
mulVals (IntValue v1) (IntValue v2) =
  return $ IntValue $ P.doMul v1 v2
mulVals v1 v2 =
  badAlgSimplifyM $ "mulVals: operands not of (the same) numeral type! "++
  pretty v1++" "++pretty v2

-- | Divide two integer constants with 'P.doSDiv'; a 'Nothing' result
-- is reported as division by zero.
divVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
divVals (IntValue v1) (IntValue v2) =
  case P.doSDiv v1 v2 of
    Just v  -> return $ IntValue v
    Nothing -> badAlgSimplifyM "Division by zero"
divVals _ _ =
  badAlgSimplifyM "divVals: operands not of (the same) numeral type! "

-- | Does the second integer constant divide the first without
-- remainder (per 'P.doSMod')?  An undefined remainder counts as "no".
canDivValsEvenly :: PrimValue -> PrimValue -> AlgSimplifyM Bool
canDivValsEvenly (IntValue v1) (IntValue v2) =
  case P.doSMod v1 v2 of
    Just v  -> return $ P.zeroIsh $ IntValue v
    Nothing -> return False
canDivValsEvenly _ _ =
  badAlgSimplifyM "canDivValsEvenly: operands not of (the same) numeral type!"
-------------------------------------------------------------
-------------------------------------------------------------
---- Helpers for the ScalExp and NNumRelLogExp Datatypes ----
-------------------------------------------------------------
-------------------------------------------------------------

-- | The element type recorded in a sum or product node.
typeOfNAlg :: NNumExp -> PrimType
typeOfNAlg = \case
  NSum  _ t -> t
  NProd _ t -> t

----------------------------------------
---- Helpers for Division and Power ----
----------------------------------------

-- | Try to divide every term of a split sum by each candidate factor
-- in turn.  A factor is cancelled only if it divides /all/ terms;
-- otherwise it is appended to the list of factors that remain in the
-- denominator and the terms are left as they were.
--
-- Arguments: factors still to try, factors that could not be
-- cancelled so far, and the current @(term, constant)@ pairs.
trySimplifyDivRec :: [ScalExp] -> [ScalExp] -> [(NNumExp, PrimValue)] ->
                     AlgSimplifyM ([ScalExp], [(NNumExp, PrimValue)])
trySimplifyDivRec [] kept terms = return (kept, terms)
trySimplifyDivRec (factor:factors) kept terms = do
  (oks, terms') <- unzip <$> mapM (tryDivProdByOneFact factor) terms
  if and oks
    then trySimplifyDivRec factors kept terms'
    else trySimplifyDivRec factors (kept ++ [factor]) terms

-- | Try to divide one @(product, constant)@ pair by a single factor.
-- The flag says whether the division succeeded; on failure the pair
-- is returned unchanged.
tryDivProdByOneFact :: ScalExp -> (NNumExp, PrimValue) -> AlgSimplifyM (Bool, (NNumExp, PrimValue))
-- A literal divisor is divided into the constant part, if it divides evenly.
tryDivProdByOneFact (Val f) term@(e, v) = do
  divisible <- canDivValsEvenly v f
  if divisible
    then do q <- divVals v f
            return (True, (e, q))
    else return (False, term)
tryDivProdByOneFact _ term@(NProd [] _, _) = return (False, term)
tryDivProdByOneFact f term@(NProd (t:ts) tp, v) = do
  (divided, t') <- tryDivTriv t f
  one <- getPos1 tp
  if divided
    then case t' of
      -- The quotient is a literal: fold it into the constant part.
      Val c -> do
        v' <- mulVals c v
        let rest = if null ts then [Val one] else ts
        return (True, (NProd rest tp, v'))
      _ -> return (True, (NProd (t':ts) tp, v))
    else do
      -- The leading factor did not cancel; try the remaining ones.
      (ok, (restTerm, v')) <- tryDivProdByOneFact f (NProd ts tp, v)
      case (ok, restTerm) of
        (True, NProd (Val c : ts') _) -> do
          v'' <- mulVals v' c
          return (True, (NProd (t:ts') tp, v''))
        (True, NProd ts' _) -> return (True, (NProd (t:ts') tp, v'))
        _ -> return (False, term)
tryDivProdByOneFact _ (NSum _ _, _) =
  badAlgSimplifyM "tryDivProdByOneFact: unreachable case NSum reached!"
-- | @tryDivTriv t f@ tries to simplify the quotient @t / f@ by purely
-- syntactic means: identical operands cancel to 1, and powers of the
-- same base have their exponents subtracted when the difference can
-- be shown to be non-negative.  Returns @(True, quotient)@ on
-- success and @(False, SDiv t f)@ otherwise.
tryDivTriv :: ScalExp -> ScalExp -> AlgSimplifyM (Bool, ScalExp)
tryDivTriv num@(SPow a e1) den@(SPow d e2)
  | a /= d = return failed
  | e1 == e2 = do
      one <- getPos1 $ scalExpType a
      return (True, Val one)
  | otherwise = do
      let tp = scalExpType a
      one  <- getPos1 tp
      diff <- simplifyScal $ SMinus e1 e2
      case (tp, diff) of
        (IntType _, Val v)
          | P.zeroIsh v -> return (True, Val one)
          | P.oneIsh v  -> return (True, a)
        (IntType _, _) -> do
          -- a^e1 / a^e2 == a^(e1-e2) is only used when e2 - e1 < 0 is provable
          negDiff <- toNumSofP =<< negateSimplified diff
          verdict <- simplifyNRel $ NRelExp LTH0 negDiff
          return $ if verdict == LogCt True
                     then (True, SPow a diff)
                     else failed
        _ -> return failed
  where
    failed = (False, SDiv num den)
-- A power divided by its own base: treat the base as base^1.
tryDivTriv num@(SPow a _) b
  | a == b = do
      one <- getPos1 $ scalExpType a
      tryDivTriv num (SPow a (Val one))
  | otherwise = return (False, SDiv num b)
-- A base divided by one of its powers: likewise.
tryDivTriv b den@(SPow a _)
  | a == b = do
      one <- getPos1 $ scalExpType a
      tryDivTriv (SPow a (Val one)) den
  | otherwise = return (False, SDiv b den)
tryDivTriv t f
  | t == f = do
      one <- getPos1 $ scalExpType t
      return (True, Val one)
  | otherwise = return (False, SDiv t f)

-- Disabled example inputs, kept for reference only.
{-
mkRelExp :: Int -> (RangesRep, ScalExp, ScalExp)
mkRelExp 1 =
    let (i',j',n',p',m') = (tident "int i", tident "int j", tident "int n", tident "int p", tident "int m")
        (i,j,n,p,m) = (Id i', Id j', Id n', Id p', Id m')
        one = Val (IntVal 1)
        min_p_nm1 = MaxMin True [p, SMinus n one]
        hash = M.fromList $ [ (identName n', ( 1::Int, Just (Val (IntVal 1)), Nothing ) ),
                              (identName p', ( 1::Int, Just (Val (IntVal 0)), Nothing ) ),
                              (identName i', ( 5::Int, Just (Val (IntVal 0)), Just min_p_nm1 ) ) ,
                              (identName j', ( 9::Int, Just (Val (IntVal 0)), Just i ) )
                            ] -- M.Map VName (Int, Maybe ScalExp, Maybe ScalExp)
        ij_p_j_p_1_m_m = SMinus (SPlus (STimes i j) (SPlus j one)) m
        rel1 = RelExp LTH0 ij_p_j_p_1_m_m
        m_ij_m_j_m_2 = SNeg ( SPlus (STimes i j) (SPlus j (Val (IntVal 2))) )
        rel2 = RelExp LTH0 m_ij_m_j_m_2
    in (hash, rel1, rel2)
mkRelExp 2 =
    let (i',a',b',l',u') = (tident "int i", tident "int a", tident "int b", tident "int l", tident "int u")
        (i,a,b,l,u) = (Id i', Id a', Id b', Id l', Id u')
        hash = M.fromList $ [ (identName i', ( 5::Int, Just l, Just u ) ) ]
        ai_p_b = SPlus (STimes a i) b
        rel1 = RelExp LTH0 ai_p_b
    in (hash, rel1, rel1)
mkRelExp 3 =
    let (i',j',n',m') = (tident "int i", tident "int j", tident "int n", tident "int m")
        (i,j,n,m) = (Id i', Id j', Id n', Id m')
        one = Val (IntVal 1)
        two = Val (IntVal 2)
        min_j_nm1 = MaxMin True [MaxMin False [Val (IntVal 0), SMinus i (STimes two n)], SMinus n one]
        hash = M.fromList $ [ (identName n', ( 1::Int, Just (Val (IntVal 1)), Nothing ) ),
                              (identName m', ( 2::Int, Just (Val (IntVal 1)), Nothing ) ),
                              (identName i', ( 5::Int, Just (Val (IntVal 0)), Just (SMinus m one) ) ) ,
                              (identName j', ( 9::Int, Just (Val (IntVal 0)), Just min_j_nm1 ) )
                            ] -- M.Map VName (Int, Maybe ScalExp, Maybe ScalExp)
        ij_m_m = SMinus (STimes i j) m
        rel1 = RelExp LTH0 ij_m_m
--        rel3 = RelExp LTH0 (SMinus i (SPlus (STimes two n) j))
        m_ij_m_1 = SMinus (Val (IntVal (-1))) (STimes i j)
        rel2 = RelExp LTH0 m_ij_m_1
--        simpl_exp = SDiv (MaxMin True [SMinus (Val (IntVal 0)) (STimes i j), SNeg (STimes i n) ])
--                         (STimes i j)
--        rel4 = RelExp LTH0 simpl_exp
    in (hash, rel1, rel2)
mkRelExp _ = let hash = M.empty
                 rel = RelExp LTH0 (Val (IntVal (-1)))
             in (hash, rel, rel)
-}