{-# LANGUAGE FlexibleInstances, LambdaCase #-}
module Futhark.Analysis.AlgSimplify
  ( ScalExp
  , Error
  , simplify
  , mkSuffConds
  , RangesRep
  , ppRangesRep
  , linFormScalE
  , pickSymToElim
  )
  where

import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.List (sort, sortBy, genericReplicate)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State

import Futhark.Representation.AST hiding (SDiv, SMod, SQuot, SRem, SSignum)
import Futhark.Analysis.ScalExp
import qualified Futhark.Representation.Primitive as P

-- | For each variable: its loop nesting depth, followed by its lower
-- and upper bound, where 'Nothing' stands for an unknown bound.
-- Ranges are inclusive.
type RangesRep = M.Map VName (Int, Maybe ScalExp, Maybe ScalExp)

-- | Prettyprint a 'RangesRep', one variable per line, with the lines
-- sorted.  A variable whose two bounds are equal is shown as
-- @name: == bound@, any other as @name: [lower, upper]@; an unknown
-- bound is shown as @?@.  Do not rely on the format of this string.
-- The loop nesting depth information is not included.
ppRangesRep :: RangesRep -> String
ppRangesRep rangesrep =
  unlines $ sort [ showRange name lower upper
                 | (name, (_, lower, upper)) <- M.toList rangesrep ]
  where
    showRange name lower upper
      | lower == upper = prefix ++ "== " ++ showBound lower
      | otherwise      = prefix ++ "[" ++ showBound lower ++ ", " ++
                         showBound upper ++ "]"
      where prefix = pretty name ++ ": "

    showBound bound = maybe "?" pretty bound

-- | Read-only configuration of the algebraic simplifier.
data AlgSimplifyEnv = AlgSimplifyEnv
  { inSolveLTH0 :: Bool
    -- ^ Whether we are already inside gaussian elimination; used to
    -- avoid recursing into it again.
  , ranges      :: RangesRep
    -- ^ The known variable-to-range bindings.
  , maxSteps    :: Int
    -- ^ The number of simplifications to do before bailing out, to
    -- avoid spending too much time.
  }

-- | Why the simplifier failed to produce a result.
data Error
  = StepsExceeded
    -- ^ More than 'maxSteps' simplification steps were taken.
  | Error String
    -- ^ An unsupported case was encountered; the string describes it.

-- | The simplifier monad.  The 'Int' state counts the simplification
-- steps taken so far (incremented by 'step'), the reader environment is
-- an 'AlgSimplifyEnv', and failures are reported as 'Left' of 'Error'.
type AlgSimplifyM = StateT Int (ReaderT AlgSimplifyEnv (Either Error))

-- | Run a simplifier computation from a zero step count.  The flag
-- initialises 'inSolveLTH0'; the step budget is fixed at 100.
runAlgSimplifier :: Bool -> AlgSimplifyM a -> RangesRep -> Either Error a
runAlgSimplifier inGauss action rangesrep =
  runReaderT (evalStateT action stepsTaken) env
  where
    stepsTaken = 0
    env = AlgSimplifyEnv
      { inSolveLTH0 = inGauss
      , ranges      = rangesrep
      , maxSteps    = 100 -- heuristically chosen
      }

-- | Count one simplification step, and abort with 'StepsExceeded' as
-- soon as the count goes past 'maxSteps'.
step :: AlgSimplifyM ()
step = do
  modify (+ 1)
  taken <- get
  limit <- asks maxSteps
  when (taken > limit) stepsExceeded

-- | Abort the simplification because the step budget is exhausted.
stepsExceeded :: AlgSimplifyM a
stepsExceeded = lift . lift $ Left StepsExceeded

-- | Abort the simplification with an 'Error' carrying the given message.
badAlgSimplifyM :: String -> AlgSimplifyM a
badAlgSimplifyM msg = lift $ lift $ Left $ Error msg

-- | Set the 'inSolveLTH0' flag, recording that we are now inside
-- gaussian elimination.  All other fields are left unchanged.
markInSolve :: AlgSimplifyEnv -> AlgSimplifyEnv
markInSolve env = env { inSolveLTH0 = True }

-- | Run an action with 'inSolveLTH0' set to 'True'.  Because this uses
-- 'local', the flag is only changed for the duration of the action.
markGaussLTH0 :: AlgSimplifyM a -> AlgSimplifyM a
markGaussLTH0 action = local markInSolve action

-----------------------------------------------------------
-- A Scalar Expression, i.e., ScalExp, is simplified to: --
--   1. if numeric: to a normalized sum-of-products form,--
--      in which on the outermost level there are N-ary  --
--      Min/Max nodes, and the next two levels are a sum --
--      of products.                                     --
--   2. if boolean: to disjunctive normal form           --
--                                                       --
-- Corresponding Helper Representations are:             --
--   1. NNumExp, i.e., NSum of NProd of ScalExp          --
--   2. DNF                                              --
-----------------------------------------------------------

-- | Normalised numeric expression: a sum ('NSum') of products
-- ('NProd') of 'ScalExp' factors, each node tagged with its type.
data NNumExp
  = NSum  [NNumExp] PrimType
  | NProd [ScalExp] PrimType
  deriving (Eq, Ord, Show)

-- | A single literal of a boolean formula in disjunctive normal form.
data BTerm
  = NRelExp RelOp0 NNumExp
    -- ^ A numeric expression related to zero (e.g. @LTH0@: @e < 0@).
  | LogCt !Bool
    -- ^ A boolean constant.
  | PosId VName
    -- ^ A boolean variable.
  | NegId VName
    -- ^ Presumably a negated boolean variable (by name; confirm
    -- against 'toDNF').
  deriving (Eq, Ord, Show)
-- | A conjunction of 'BTerm's.
type NAnd    = [BTerm]
-- | A disjunction of conjunctions, i.e. disjunctive normal form.
type DNF     = [NAnd ]
--type NOr     = [BTerm]
--type CNF     = [NOr  ]

-- | Algebraically simplify a scalar expression, given the known ranges
-- of its variables.  If the step budget runs out, the expression is
-- returned unchanged; any other simplifier error is fatal.
simplify :: ScalExp -> RangesRep -> ScalExp
simplify e rangesrep =
  either recover id $ runAlgSimplifier False (simplifyScal e) rangesrep
  where
    recover StepsExceeded = e
    recover (Error err) =
      error $ "Error during algebraic simplification of: " ++ pretty e ++
              "\n" ++ err

-- | Given a symbol @i@ and a scalar expression @e@, decompose
-- @e = a*i + b@ and return @Just (a, b)@ if possible, otherwise
-- 'Nothing'.  Runs with 'inSolveLTH0' initially unset.
linFormScalE :: VName -> ScalExp -> RangesRep -> Either Error (Maybe (ScalExp,ScalExp))
linFormScalE i e rangesrep =
  runAlgSimplifier False (linearFormScalExp i e) rangesrep

-- | Extract sufficient conditions for an @LTH0@ relation to hold, as a
-- disjunction (outer list) of conjunctions (inner lists).  Runs with
-- 'inSolveLTH0' set from the start.
mkSuffConds :: ScalExp -> RangesRep -> Either Error [[ScalExp]]
mkSuffConds e rangesrep =
  runAlgSimplifier True (gaussElimRel e) rangesrep

{-
-- | Test if Simplification engine can handle this kind of expression
canSimplify :: Int -> Either Error ScalExp --[[ScalExp]]
canSimplify i = do
    let (h,_,e2) = mkRelExp i
    case e2 of
        (RelExp LTH0 _) -> do
              -- let e1' = trace (pretty e1) e1
              simplify e2 noLoc h
--            runAlgSimplifier False (gaussAllLTH0 False S.empty =<< toNumSofP =<< simplifyScal e) noLoc h
        _ -> simplify e2 noLoc h-- badAlgSimplifyM "canSimplify: unimplemented!"
-}
-------------------------------------------------------
--- Assumes the relational expression is simplified. --
--- All uses of gaussian elimination from simplify   --
---  must go through simplifyNRel, which calls       --
---  markGaussLTH0 to set the inSolveLTH0 environment--
---  flag, so that we do not enter infinite recursion!-
--- Returns True, False, or the input relation, i.e.,--
---    static-only simplification!                   --
-------------------------------------------------------
-- | Statically simplify a relation: the result is a 'LogCt' when the
-- relation can be decided, and the (cheaply simplified) input
-- otherwise.  Only @LTH0@ relations over integer types are attempted;
-- every other term is returned unchanged.
simplifyNRel :: BTerm -> AlgSimplifyM BTerm
simplifyNRel inp_term@(NRelExp LTH0 inp_sofp) = do
    term <- cheapSimplifyNRel inp_term
    in_gauss <- asks inSolveLTH0
    let tp = typeOfNAlg inp_sofp

    -- Skip gaussian elimination when we are already inside it (to avoid
    -- infinite recursion), when the relation is already decided, or
    -- when the relation is not over an integer type.
    if in_gauss || isTrivialNRel term || tp `notElem` map IntType allIntTypes
      then return term
      else do ednf <- markGaussLTH0 $ gaussAllLTH0 True S.empty inp_sofp
              return $ case ednf of
                Val (BoolValue c) -> LogCt c
                _                 -> term
  where
    -- A term needs no further work once it is a logical constant.
    -- 'cheapSimplifyNRel' folds every relation over a single constant
    -- into a 'LogCt' before this check runs, so the 'LogCt' case is the
    -- one that fires on 'term'; without it the check could never
    -- succeed and decided relations would still burn simplifier steps.
    isTrivialNRel (LogCt _)                     = True
    isTrivialNRel (NRelExp _ (NProd [Val _] _)) = True
    isTrivialNRel _                             = False

    -- Decide a relation over a single constant directly.
    cheapSimplifyNRel :: BTerm -> AlgSimplifyM BTerm
    cheapSimplifyNRel (NRelExp rel (NProd [Val v] _)) =
        LogCt <$> valLTHEQ0 rel v
    cheapSimplifyNRel e = return e
simplifyNRel inp_term =
    return inp_term --- TODO: handle more cases.

--gaussEliminateNRel :: BTerm -> AlgSimplifyM DNF
--gaussEliminateNRel _ =
--    badAlgSimplifyM "gaussElimNRel: unimplemented!"


-- | Derive sufficient conditions for @e < 0@ (given as @RelExp LTH0 e@)
-- by gaussian elimination.  The result is in disjunctive normal form:
-- the outer list is a disjunction, each inner list a conjunction.
-- Any other relation is rejected with an 'Error'.
gaussElimRel :: ScalExp -> AlgSimplifyM [[ScalExp]]
gaussElimRel (RelExp LTH0 e) = do
    e_sofp <- toNumSofP =<< simplifyScal e
    e_scal <- simplifyScal =<< gaussAllLTH0 False S.empty e_sofp
    e_dnf  <- toDNF e_scal
    mapM (mapM fromBTerm) e_dnf
  where
    -- Turn a DNF literal back into a 'ScalExp'.  Boolean variables are
    -- boolean-typed (the numeric type of @e@ does not apply to them),
    -- and a negated variable must stay negated; previously 'NegId' was
    -- rebuilt as a plain 'Id', turning @not x@ into @x@.
    fromBTerm :: BTerm -> AlgSimplifyM ScalExp
    fromBTerm (LogCt c)        = return $ Val (BoolValue c)
    fromBTerm (PosId i)        = return $ Id i Bool
    fromBTerm (NegId i)        = return $ SNot $ Id i Bool
    fromBTerm (NRelExp rel ee) = RelExp rel <$> fromNumSofP ee
gaussElimRel _ =
    badAlgSimplifyM "gaussElimRel: only LTH0 Int relations please!"

--ppSyms :: S.Set VName -> String
--ppSyms ss = foldl (\s x -> s ++ " " ++ (baseString x)) "ElimSyms: " (S.toList ss)


-- | Whether the expression is an integer literal that is strictly
-- negative (compared after conversion to 'Int64').  Every other
-- expression yields 'False'.
primScalExpLTH0 :: ScalExp -> Bool
primScalExpLTH0 se = case se of
  Val (IntValue v) -> P.intToInt64 v < 0
  _                -> False
-----------------------------------------------------------
-----------------------------------------------------------
-----------------------------------------------------------
---`gaussAllLTH0'                                       ---
---  `static_only': whether only a True/False answer is ---
---                 required, or an actual sufficient   ---
---                 condition                           ---
---     `el_syms': the set of already eliminated        ---
---                 symbols, initially empty            ---
---     `sofp':    the expression e in sum-of-product   ---
---                 form that is compared to 0,         ---
---                 i.e., e < 0. sofp assumed simplified---
---     Result:    is a ScalExp expression, which is    ---
---                 actually a predicate in DNF form,   ---
---                 that is a sufficient condition      ---
---                 for e < 0!                          ---
---                                                     ---
--- gaussAllLTH0 is implementing the tracking of Min/Max---
---              terms, and uses `gaussOneDefaultLTH0'  ---
---              to implement gaussian-like elimination ---
---              to solve the a*i + b < 0 problem.      ---
---                                                     ---
--- IMPORTANT: before calling gaussAllLTH0 from outside ---
---            make sure to set the inSolveLTH0 env     ---
---            member to True, via markGaussLTH0;       ---
---            otherwise infinite recursion might       ---
---            happen w.r.t. `simplifyScal'             ---
-----------------------------------------------------------
-----------------------------------------------------------
-----------------------------------------------------------
-- | A product, given as the list of its factors (the same shape as the
-- factor list of 'NProd').
type Prod = [ScalExp]
gaussAllLTH0 :: Bool -> S.Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 :: Bool -> Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 Bool
static_only Set VName
el_syms NNumExp
sofp = do
    AlgSimplifyM ()
step
    let tp :: PrimType
tp  = NNumExp -> PrimType
typeOfNAlg NNumExp
sofp
    RangesRep
rangesrep <- (AlgSimplifyEnv -> RangesRep)
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) RangesRep
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AlgSimplifyEnv -> RangesRep
ranges
    ScalExp
e_scal <- NNumExp -> AlgSimplifyM ScalExp
fromNumSofP NNumExp
sofp
    let mi :: Maybe VName
mi  = RangesRep -> Set VName -> ScalExp -> Maybe VName
pickSymToElim RangesRep
rangesrep Set VName
el_syms ScalExp
e_scal

    case Maybe VName
mi of
      Maybe VName
Nothing -> ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ if ScalExp -> Bool
primScalExpLTH0 ScalExp
e_scal
                          then PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
True)
                          else RelOp0 -> ScalExp -> ScalExp
RelExp RelOp0
LTH0 ScalExp
e_scal
      Just VName
i  -> do
        (Maybe ScalExp
jmm, [ScalExp]
fs0, [NNumExp]
terms) <- VName
-> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
findMinMaxTerm VName
i NNumExp
sofp
        -- i.e., sofp == fs0 * jmm + terms, where
        --       i appears in jmm and jmm = MinMax ...

        [ScalExp]
fs <- if Bool -> Bool
not ([ScalExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [ScalExp]
fs0) then [ScalExp]
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) [ScalExp]
forall (m :: * -> *) a. Monad m => a -> m a
return [ScalExp]
fs0
              else do PrimValue
one <- PrimType -> AlgSimplifyM PrimValue
getPos1 PrimType
tp; [ScalExp]
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) [ScalExp]
forall (m :: * -> *) a. Monad m => a -> m a
return [PrimValue -> ScalExp
Val PrimValue
one]

        case Maybe ScalExp
jmm of
          ------------------------------------------------------------------------
          -- A MinMax expression which uses to-be-eliminated symbol i was found --
          ------------------------------------------------------------------------
          Just (MaxMin Bool
_     []  ) ->
                String -> AlgSimplifyM ScalExp
forall a. String -> AlgSimplifyM a
badAlgSimplifyM String
"gaussAllLTH0: Empty MinMax Node!"
          Just (MaxMin Bool
ismin [ScalExp]
mmts) -> do
            PrimValue
mone <- PrimType -> AlgSimplifyM PrimValue
getNeg1 PrimType
tp

            -- fs_lth0 => fs < 0
--            fs_lth0 <- if null fs then return $ Val (BoolValue False)
--                       else gaussAllLTH0 static_only el_syms (NProd fs tp)
            NNumExp
fsm1    <- ScalExp -> AlgSimplifyM NNumExp
toNumSofP (ScalExp -> AlgSimplifyM NNumExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM NNumExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ScalExp -> AlgSimplifyM ScalExp
simplifyScal (ScalExp -> AlgSimplifyM ScalExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< NNumExp -> AlgSimplifyM ScalExp
fromNumSofP
                         ( [NNumExp] -> PrimType -> NNumExp
NSum [[ScalExp] -> PrimType -> NNumExp
NProd [ScalExp]
fs PrimType
tp, [ScalExp] -> PrimType -> NNumExp
NProd [PrimValue -> ScalExp
Val PrimValue
mone] PrimType
tp] PrimType
tp )
            ScalExp
fs_leq0 <- Bool -> Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 Bool
static_only Set VName
el_syms NNumExp
fsm1  -- fs <= 0
            -- mfsm1 = - fs - 1, fs_geq0 => (fs >= 0),
            --             i.e., fs_geq0 => (-fs - 1 < 0)
            NNumExp
mfsm1   <- ScalExp -> AlgSimplifyM NNumExp
toNumSofP (ScalExp -> AlgSimplifyM NNumExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM NNumExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ScalExp -> AlgSimplifyM ScalExp
simplifyScal (ScalExp -> AlgSimplifyM ScalExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< NNumExp -> AlgSimplifyM ScalExp
fromNumSofP
                         ( [NNumExp] -> PrimType -> NNumExp
NSum [[ScalExp] -> PrimType -> NNumExp
NProd (PrimValue -> ScalExp
Val PrimValue
moneScalExp -> [ScalExp] -> [ScalExp]
forall a. a -> [a] -> [a]
:[ScalExp]
fs) PrimType
tp,[ScalExp] -> PrimType -> NNumExp
NProd [PrimValue -> ScalExp
Val PrimValue
mone] PrimType
tp] PrimType
tp )
            ScalExp
fs_geq0 <- Bool -> Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 Bool
static_only Set VName
el_syms NNumExp
mfsm1

            -- mm_terms are the simplified terms of the MinMax obtained
            -- after inlining everything inside the MinMax, i.e., intially
            -- terms + fs * MinMax ismin [t1,..,tn] -> [fs*t1+terms, ..., fs*tn+terms]
            [NNumExp]
mm_terms<- (ScalExp -> AlgSimplifyM NNumExp)
-> [ScalExp]
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) [NNumExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\ScalExp
t -> ScalExp -> AlgSimplifyM NNumExp
toNumSofP (ScalExp -> AlgSimplifyM NNumExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM NNumExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ScalExp -> AlgSimplifyM ScalExp
simplifyScal (ScalExp -> AlgSimplifyM ScalExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< NNumExp -> AlgSimplifyM ScalExp
fromNumSofP
                                   ([NNumExp] -> PrimType -> NNumExp
NSum ( [ScalExp] -> PrimType -> NNumExp
NProd (ScalExp
tScalExp -> [ScalExp] -> [ScalExp]
forall a. a -> [a] -> [a]
:[ScalExp]
fs) PrimType
tpNNumExp -> [NNumExp] -> [NNumExp]
forall a. a -> [a] -> [a]
:[NNumExp]
terms ) PrimType
tp) ) [ScalExp]
mmts

            -- for every (simplified) `term_i' of the inline MinMax exp,
            --  get the sufficient conditions for `term_i < 0'
            [ScalExp]
mms     <- (NNumExp -> AlgSimplifyM ScalExp)
-> [NNumExp]
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) [ScalExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Bool -> Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 Bool
static_only Set VName
el_syms) [NNumExp]
mm_terms

            if Bool
static_only
            --------------------------------------------------------------------
            -- returns either Val (BoolValue True) or the original ScalExp relat --
            --------------------------------------------------------------------
            then if ( ScalExp
fs_geq0 ScalExp -> ScalExp -> Bool
forall a. Eq a => a -> a -> Bool
== PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
True) Bool -> Bool -> Bool
&&     Bool
ismin) Bool -> Bool -> Bool
||
                    ( ScalExp
fs_leq0 ScalExp -> ScalExp -> Bool
forall a. Eq a => a -> a -> Bool
== PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
True) Bool -> Bool -> Bool
&& Bool -> Bool
not Bool
ismin)
                 -- at least one term should be < 0!
                 then do let  is_one_true :: Bool
is_one_true  = PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
True ) ScalExp -> [ScalExp] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [ScalExp]
mms
                         let are_all_false :: Bool
are_all_false = (ScalExp -> Bool) -> [ScalExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all  (ScalExp -> ScalExp -> Bool
forall a. Eq a => a -> a -> Bool
== PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
False)) [ScalExp]
mms
                         ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ if       Bool
is_one_true  then PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
True)
                                  else if Bool
are_all_false then PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
False)
                                  else RelOp0 -> ScalExp -> ScalExp
RelExp RelOp0
LTH0 ScalExp
e_scal
                 -- otherwise all terms should be all true!
                 else do let are_all_true :: Bool
are_all_true = (ScalExp -> Bool) -> [ScalExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all  (ScalExp -> ScalExp -> Bool
forall a. Eq a => a -> a -> Bool
== PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
True )) [ScalExp]
mms
                         let is_one_false :: Bool
is_one_false = PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
False) ScalExp -> [ScalExp] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [ScalExp]
mms
                         ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ if      Bool
are_all_true then PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
True )
                                  else if Bool
is_one_false then PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue Bool
False)
                                  else RelOp0 -> ScalExp -> ScalExp
RelExp RelOp0
LTH0 ScalExp
e_scal
            --------------------------------------------------------------------
            -- returns sufficient conditions for the ScalExp relation to hold --
            --------------------------------------------------------------------
            else do
                let mm_fsgeq0 :: ScalExp
mm_fsgeq0 = (ScalExp -> ScalExp -> ScalExp) -> ScalExp -> [ScalExp] -> ScalExp
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (if Bool
ismin then ScalExp -> ScalExp -> ScalExp
SLogOr else ScalExp -> ScalExp -> ScalExp
SLogAnd)
                                      (PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue (Bool -> Bool
not Bool
ismin))) [ScalExp]
mms
                let mm_fslth0 :: ScalExp
mm_fslth0 = (ScalExp -> ScalExp -> ScalExp) -> ScalExp -> [ScalExp] -> ScalExp
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (if Bool
ismin then ScalExp -> ScalExp -> ScalExp
SLogAnd else ScalExp -> ScalExp -> ScalExp
SLogOr)
                                      (PrimValue -> ScalExp
Val (Bool -> PrimValue
BoolValue      Bool
ismin )) [ScalExp]
mms
                -- the sufficient condition for the original expression, e.g.,
                -- terms + fs * Min [t1,..,tn] < 0 is
                -- (fs >= 0 && (fs*t_1+terms < 0 || ... || fs*t_n+terms < 0) ) ||
                -- (fs <  0 && (fs*t_1+terms < 0 && ... && fs*t_n+terms < 0) )
                ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ ScalExp -> ScalExp -> ScalExp
SLogOr (ScalExp -> ScalExp -> ScalExp
SLogAnd ScalExp
fs_geq0 ScalExp
mm_fsgeq0) (ScalExp -> ScalExp -> ScalExp
SLogAnd ScalExp
fs_leq0 ScalExp
mm_fslth0)

          Just ScalExp
_ -> String -> AlgSimplifyM ScalExp
forall a. String -> AlgSimplifyM a
badAlgSimplifyM String
"gaussOneLTH0: (Just MinMax) invariant violated!"
          ------------------------------------------------------------------------
          -- A MinMax expression which uses (to-be-elim) symbol i was NOT found --
          ------------------------------------------------------------------------
          Maybe ScalExp
Nothing-> do
            Maybe ScalExp
m_sofp <- Bool
-> VName -> Set VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp)
gaussOneDefaultLTH0 Bool
static_only VName
i Set VName
el_syms NNumExp
sofp
            case Maybe ScalExp
m_sofp of
                Maybe ScalExp
Nothing -> Bool -> Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 Bool
static_only (VName -> Set VName -> Set VName
forall a. Ord a => a -> Set a -> Set a
S.insert VName
i Set VName
el_syms) NNumExp
sofp
                Just ScalExp
res_eofp -> ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return ScalExp
res_eofp
    where
        findMinMaxTerm :: VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, Prod, [NNumExp])
        findMinMaxTerm :: VName
-> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
findMinMaxTerm VName
_  (NSum  [] PrimType
_) = (Maybe ScalExp, [ScalExp], [NNumExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe ScalExp
forall a. Maybe a
Nothing, [], [])
        findMinMaxTerm VName
_  (NSum  [NProd [MaxMin Bool
ismin [ScalExp]
e] PrimType
_] PrimType
_) =
            (Maybe ScalExp, [ScalExp], [NNumExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> Maybe ScalExp
forall a. a -> Maybe a
Just (Bool -> [ScalExp] -> ScalExp
MaxMin Bool
ismin [ScalExp]
e), [], [])
        findMinMaxTerm VName
_  (NProd [MaxMin Bool
ismin [ScalExp]
e] PrimType
_) =
            (Maybe ScalExp, [ScalExp], [NNumExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> Maybe ScalExp
forall a. a -> Maybe a
Just (Bool -> [ScalExp] -> ScalExp
MaxMin Bool
ismin [ScalExp]
e), [], [])

        findMinMaxTerm VName
ii t :: NNumExp
t@NProd{} = do (Maybe ScalExp
mm, [ScalExp]
fs) <- VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp])
findMinMaxFact VName
ii NNumExp
t
                                         (Maybe ScalExp, [ScalExp], [NNumExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe ScalExp
mm, [ScalExp]
fs, [])
        findMinMaxTerm VName
ii (NSum (NNumExp
t:[NNumExp]
ts) PrimType
tp)= do
            RangesRep
rangesrep <- (AlgSimplifyEnv -> RangesRep)
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) RangesRep
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks AlgSimplifyEnv -> RangesRep
ranges
            case VName -> RangesRep -> Maybe (Int, Maybe ScalExp, Maybe ScalExp)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
ii RangesRep
rangesrep of
                Just (Int
_, Just ScalExp
_, Just ScalExp
_) -> do
                    (Maybe ScalExp, [ScalExp])
f <- VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp])
findMinMaxFact VName
ii NNumExp
t
                    case (Maybe ScalExp, [ScalExp])
f of
                        (Just ScalExp
mm, [ScalExp]
fs) -> (Maybe ScalExp, [ScalExp], [NNumExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> Maybe ScalExp
forall a. a -> Maybe a
Just ScalExp
mm, [ScalExp]
fs, [NNumExp]
ts)
                        (Maybe ScalExp
Nothing, [ScalExp]
_ ) -> do (Maybe ScalExp
mm, [ScalExp]
fs', [NNumExp]
ts') <- VName
-> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
findMinMaxTerm VName
ii ([NNumExp] -> PrimType -> NNumExp
NSum [NNumExp]
ts PrimType
tp)
                                            (Maybe ScalExp, [ScalExp], [NNumExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe ScalExp
mm, [ScalExp]
fs', NNumExp
tNNumExp -> [NNumExp] -> [NNumExp]
forall a. a -> [a] -> [a]
:[NNumExp]
ts')
                Maybe (Int, Maybe ScalExp, Maybe ScalExp)
_ -> (Maybe ScalExp, [ScalExp], [NNumExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp], [NNumExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe ScalExp
forall a. Maybe a
Nothing, [], NNumExp
tNNumExp -> [NNumExp] -> [NNumExp]
forall a. a -> [a] -> [a]
:[NNumExp]
ts)

        findMinMaxFact :: VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, Prod)
        findMinMaxFact :: VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp])
findMinMaxFact VName
_  (NProd []     PrimType
_ ) = (Maybe ScalExp, [ScalExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe ScalExp
forall a. Maybe a
Nothing, [])
        findMinMaxFact VName
ii (NProd (ScalExp
f:[ScalExp]
fs) PrimType
tp) =
            case ScalExp
f of
                MaxMin Bool
ismin [ScalExp]
ts -> do
                        let id_set :: Names
id_set = [Names] -> Names
forall a. Monoid a => [a] -> a
mconcat ([Names] -> Names) -> [Names] -> Names
forall a b. (a -> b) -> a -> b
$ (ScalExp -> Names) -> [ScalExp] -> [Names]
forall a b. (a -> b) -> [a] -> [b]
map ScalExp -> Names
forall a. FreeIn a => a -> Names
freeIn [ScalExp]
ts
                        if VName
ii VName -> Names -> Bool
`nameIn` Names
id_set
                        then (Maybe ScalExp, [ScalExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> Maybe ScalExp
forall a. a -> Maybe a
Just (Bool -> [ScalExp] -> ScalExp
MaxMin Bool
ismin [ScalExp]
ts), [ScalExp]
fs)
                        else do (Maybe ScalExp
mm, [ScalExp]
fs') <- VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp])
findMinMaxFact VName
ii ([ScalExp] -> PrimType -> NNumExp
NProd [ScalExp]
fs PrimType
tp)
                                (Maybe ScalExp, [ScalExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe ScalExp
mm, ScalExp
fScalExp -> [ScalExp] -> [ScalExp]
forall a. a -> [a] -> [a]
:[ScalExp]
fs')

                ScalExp
_ -> do (Maybe ScalExp
mm, [ScalExp]
fs') <- VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp])
findMinMaxFact VName
ii ([ScalExp] -> PrimType -> NNumExp
NProd [ScalExp]
fs PrimType
tp)
                        (Maybe ScalExp, [ScalExp])
-> AlgSimplifyM (Maybe ScalExp, [ScalExp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe ScalExp
mm, ScalExp
fScalExp -> [ScalExp] -> [ScalExp]
forall a. a -> [a] -> [a]
:[ScalExp]
fs')
        findMinMaxFact VName
ii (NSum [NNumExp
f] PrimType
_) = VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, [ScalExp])
findMinMaxFact VName
ii NNumExp
f
        findMinMaxFact VName
_  (NSum [NNumExp]
_ PrimType
_) =
            String -> AlgSimplifyM (Maybe ScalExp, [ScalExp])
forall a. String -> AlgSimplifyM a
badAlgSimplifyM String
"findMinMaxFact: NSum argument illegal!"



-- | Eliminate the symbol @i@ from the inequation @e < 0@ using the range
-- recorded for @i@ in the environment's 'ranges'.
--
-- 'linearForm' first splits @e@ into @a*i + b@; if it cannot, the result
-- is 'Nothing'.  From @a@ and @b@ the helper terms @a-1@, @-(a+1)@, @-a@
-- and @-(b+1)@ are built, and 'gaussAllLTH0' applied to @a-1@ and
-- @-(a+1)@ yields the two sign conditions on @a@ used below.
--
-- * @i@ absent from the ranges, or with neither bound: 'badAlgSimplifyM'.
-- * Exactly one bound @q@: @Just True@ if @a*q+b < 0@ and the matching
--   sign condition both simplify to @True@; otherwise @Just False@ if
--   @-a*q + -(b+1) < 0@ and the opposite sign condition both simplify to
--   @True@; otherwise 'Nothing'.
-- * Both bounds, @static_only@: a definite @Just True@ / @Just False@ when
--   one of the sign conditions is @True@ and the corresponding bound
--   checks succeed, else 'Nothing'.
-- * Both bounds, not @static_only@: @Just@ the simplified condition
--   @(c1 && a*lb+b<0) || (c2 && a*ub+b<0)@, where @c1@ / @c2@ are the
--   sign conditions obtained from @a-1@ and @-(a+1)@ respectively.
gaussOneDefaultLTH0 :: Bool -> VName -> S.Set VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp)
gaussOneDefaultLTH0 static_only i elsyms e =
    linearForm i e >>= \case
      Nothing     -> return Nothing
      Just (a, b) -> do
        rangesrep <- asks ranges
        one       <- getPos1 (typeOfNAlg e)
        ascal     <- fromNumSofP a
        mam1      <- toNumSofP =<< simplifyScal (SNeg (SPlus ascal (Val one)))  -- -(a+1)
        am1       <- toNumSofP =<< simplifyScal (SMinus ascal (Val one))        -- a-1
        ma        <- toNumSofP =<< simplifyScal (SNeg ascal)                    -- -a
        bscal     <- fromNumSofP b
        mbm1      <- toNumSofP =<< simplifyScal (SNeg (SPlus bscal (Val one)))  -- -(b+1)

        -- Sign conditions on the coefficient a, from a-1 < 0 and -(a+1) < 0.
        aleq0 <- simplifyScal =<< gaussAllLTH0 static_only elsyms am1
        ageq0 <- simplifyScal =<< gaussAllLTH0 static_only elsyms mam1

        let isTrue :: ScalExp -> Bool
            isTrue (Val (BoolValue True)) = True
            isTrue _                      = False

            -- Simplified form of  a*q + b < 0.
            elimAB q = simplifyScal =<< gaussElimHalf static_only elsyms q a b
            -- Simplified form of  -a*q + -(b+1) < 0.
            elimNeg q = simplifyScal =<< gaussElimHalf static_only elsyms q ma mbm1

            -- Only the bound bnd is known; sign_lt guards the first check
            -- and sign_ge the second.  The second check is only run when
            -- the first one does not settle the question.
            singleBound bnd sign_lt sign_ge = do
              lt_half <- elimAB bnd
              if isTrue lt_half && isTrue sign_lt
                then return (Just lt_half)
                else do
                  ge_half <- elimNeg bnd
                  return $ if isTrue ge_half && isTrue sign_ge
                             then Just (Val (BoolValue False))
                             else Nothing

            -- Both bounds known, static mode: try near for a definite
            -- True, then far for a definite False.
            decide near far = do
              lt_half <- elimAB near
              if isTrue lt_half
                then return $ Just (Val (BoolValue True))
                else do
                  ge_half <- elimNeg far
                  return $ if isTrue ge_half
                             then Just (Val (BoolValue False))
                             else Nothing

        case M.lookup i rangesrep of
          Nothing ->
            badAlgSimplifyM "gaussOneDefaultLTH0: sym not in ranges!"
          Just (_, Nothing, Nothing) ->
            badAlgSimplifyM "gaussOneDefaultLTH0: both bounds are undefined!"
          Just (_, Just lb, Nothing) -> singleBound lb aleq0 ageq0
          Just (_, Nothing, Just ub) -> singleBound ub ageq0 aleq0
          Just (_, Just lb, Just ub)
            | static_only, isTrue aleq0 -> decide lb ub
            | static_only, isTrue ageq0 -> decide ub lb
            | static_only               -> return Nothing
            | otherwise -> do
                -- The unsimplified halves are combined and simplified once.
                alpblth0 <- gaussElimHalf static_only elsyms lb a b
                aupblth0 <- gaussElimHalf static_only elsyms ub a b
                Just <$> simplifyScal (SLogOr (SLogAnd aleq0 alpblth0)
                                              (SLogAnd ageq0 aupblth0))
    where
      -- Substitute q for the eliminated symbol in coeff*sym + offset and
      -- hand the resulting sum of products to gaussAllLTH0.
      gaussElimHalf :: Bool -> S.Set VName -> ScalExp -> NNumExp -> NNumExp -> AlgSimplifyM ScalExp
      gaussElimHalf only_static elsyms0 q coeff offset = do
        coeff_scal  <- fromNumSofP coeff
        offset_scal <- fromNumSofP offset
        substituted <- toNumSofP =<< simplifyScal (SPlus (STimes coeff_scal q) offset_scal)
        gaussAllLTH0 only_static elsyms0 substituted

--    pos <- asks pos
--    badAlgSimplifyM "gaussOneDefaultLTH0: unimplemented!"

----------------------------------------------------------
--- Pick a Symbol to Eliminate & Bring To Linear Form  ---
----------------------------------------------------------

-- | Choose the next symbol to eliminate from @e_scal@.
--
-- Candidates are the free variables of @e_scal@ that are not yet in
-- @elsyms0@ and that have an entry in @rangesrep@.  The candidate with the
-- greatest nesting depth (the 'Int' component of its range entry) is
-- returned.  'sortBy' is stable, so among equally deep candidates the one
-- listed first by 'namesToList' wins.  Returns 'Nothing' if there are no
-- candidates.
pickSymToElim :: RangesRep -> S.Set VName -> ScalExp -> Maybe VName
pickSymToElim rangesrep elsyms0 e_scal =
    case sortBy deeperFirst candidates of
      []         -> Nothing
      (v, _) : _ -> Just v
  where
    -- Eligible free variables, each paired with its nesting depth.
    candidates :: [(VName, Int)]
    candidates =
      [ (v, depth)
      | v <- namesToList (freeIn e_scal)
      , not (v `S.member` elsyms0)
      , Just (depth, _, _) <- [M.lookup v rangesrep]
      ]

    -- Descending order of depth.
    deeperFirst :: (VName, Int) -> (VName, Int) -> Ordering
    deeperFirst (_, d1) (_, d2) = compare (negate d1) (negate d2)


-- | Simplify @scl_exp@ and split it, via 'linearForm', into a coefficient
-- of @sym@ and a remainder, both converted back to simplified 'ScalExp's.
-- Returns 'Nothing' when 'linearForm' cannot produce such a split.
linearFormScalExp :: VName -> ScalExp -> AlgSimplifyM (Maybe (ScalExp, ScalExp))
linearFormScalExp sym scl_exp = do
    normalised <- toNumSofP =<< simplifyScal scl_exp
    linearForm sym normalised >>= \case
      Nothing -> return Nothing
      Just (coeff_sofp, rest_sofp) -> do
        -- Convert both halves first, then simplify both.
        coeff  <- fromNumSofP coeff_sofp
        rest   <- fromNumSofP rest_sofp
        coeff' <- simplifyScal coeff
        rest'  <- simplifyScal rest
        return $ Just (coeff', rest')

linearForm :: VName -> NNumExp -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
linearForm :: VName -> NNumExp -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
linearForm VName
_ (NProd [] PrimType
_) =
    String -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
forall a. String -> AlgSimplifyM a
badAlgSimplifyM String
"linearForm: empty Prod!"
linearForm VName
idd ee :: NNumExp
ee@NProd{} = VName -> NNumExp -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
linearForm VName
idd ([NNumExp] -> PrimType -> NNumExp
NSum [NNumExp
ee] (NNumExp -> PrimType
typeOfNAlg NNumExp
ee))
linearForm VName
_ (NSum [] PrimType
_) =
    String -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
forall a. String -> AlgSimplifyM a
badAlgSimplifyM String
"linearForm: empty Sum!"
linearForm VName
idd (NSum [NNumExp]
terms PrimType
tp) = do
    [ScalExp]
terms_d_idd <- (NNumExp -> AlgSimplifyM ScalExp)
-> [NNumExp]
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) [ScalExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM  (\NNumExp
t -> do NNumExp
t0 <- case NNumExp
t of
                                            NProd (ScalExp
_:[ScalExp]
_) PrimType
_ -> NNumExp -> AlgSimplifyM NNumExp
forall (m :: * -> *) a. Monad m => a -> m a
return NNumExp
t
                                            NNumExp
_ -> String -> AlgSimplifyM NNumExp
forall a. String -> AlgSimplifyM a
badAlgSimplifyM String
"linearForm: ILLEGAL111!!!!"
                                   ScalExp
t_scal <- NNumExp -> AlgSimplifyM ScalExp
fromNumSofP NNumExp
t0
                                   ScalExp -> AlgSimplifyM ScalExp
simplifyScal (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ ScalExp -> ScalExp -> ScalExp
SDiv ScalExp
t_scal (VName -> PrimType -> ScalExp
Id VName
idd (ScalExp -> PrimType
scalExpType ScalExp
t_scal))
                         ) [NNumExp]
terms
    let myiota :: [Int]
myiota  = [Int
1..([NNumExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [NNumExp]
terms)]
    let ia_terms :: [(Int, ScalExp)]
ia_terms= ((Int, ScalExp) -> Bool) -> [(Int, ScalExp)] -> [(Int, ScalExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\(Int
_,ScalExp
t)-> case ScalExp
t of
                                     SDiv ScalExp
_ ScalExp
_ -> Bool
False
                                     ScalExp
_           -> Bool
True
                         ) ([Int] -> [ScalExp] -> [(Int, ScalExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int]
myiota [ScalExp]
terms_d_idd)
    let ([Int]
a_inds, [ScalExp]
a_terms) = [(Int, ScalExp)] -> ([Int], [ScalExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Int, ScalExp)]
ia_terms

    let ([Int]
_, [NNumExp]
b_terms) = [(Int, NNumExp)] -> ([Int], [NNumExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Int, NNumExp)] -> ([Int], [NNumExp]))
-> [(Int, NNumExp)] -> ([Int], [NNumExp])
forall a b. (a -> b) -> a -> b
$ ((Int, NNumExp) -> Bool) -> [(Int, NNumExp)] -> [(Int, NNumExp)]
forall a. (a -> Bool) -> [a] -> [a]
filter (\(Int
iii,NNumExp
_) -> Int
iii Int -> [Int] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [Int]
a_inds)
                                      ([Int] -> [NNumExp] -> [(Int, NNumExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int]
myiota [NNumExp]
terms)
    -- check that b_terms do not contain idd
    Bool
b_succ <- (Bool
 -> NNumExp
 -> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) Bool)
-> Bool
-> [NNumExp]
-> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) Bool
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (\Bool
acc NNumExp
x ->
                        case NNumExp
x of
                           NProd [ScalExp]
fs PrimType
_ -> do let fs_scal :: ScalExp
fs_scal = case [ScalExp]
fs of
                                                            [] -> PrimValue -> ScalExp
Val (PrimValue -> ScalExp) -> PrimValue -> ScalExp
forall a b. (a -> b) -> a -> b
$ IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> IntValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ Int32 -> IntValue
Int32Value Int32
1
                                                            ScalExp
f:[ScalExp]
fs' -> (ScalExp -> ScalExp -> ScalExp) -> ScalExp -> [ScalExp] -> ScalExp
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ScalExp -> ScalExp -> ScalExp
STimes ScalExp
f [ScalExp]
fs'
                                            let b_ids :: Names
b_ids = ScalExp -> Names
forall a. FreeIn a => a -> Names
freeIn ScalExp
fs_scal
                                            Bool -> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) Bool
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool -> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) Bool)
-> Bool -> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) Bool
forall a b. (a -> b) -> a -> b
$ Bool
acc Bool -> Bool -> Bool
&& Bool -> Bool
not (VName
idd VName -> Names -> Bool
`nameIn` Names
b_ids)
                           NNumExp
_          -> String -> StateT Int (ReaderT AlgSimplifyEnv (Either Error)) Bool
forall a. String -> AlgSimplifyM a
badAlgSimplifyM String
"linearForm: ILLEGAL222!!!!"
                    ) Bool
True [NNumExp]
b_terms

    case [ScalExp]
a_terms of
        ScalExp
t:[ScalExp]
ts | Bool
b_succ -> do
            let a_scal :: ScalExp
a_scal = (ScalExp -> ScalExp -> ScalExp) -> ScalExp -> [ScalExp] -> ScalExp
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl ScalExp -> ScalExp -> ScalExp
SPlus ScalExp
t [ScalExp]
ts
            NNumExp
a_terms_sofp <- ScalExp -> AlgSimplifyM NNumExp
toNumSofP (ScalExp -> AlgSimplifyM NNumExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM NNumExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
a_scal
            NNumExp
b_terms_sofp <- if [NNumExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [NNumExp]
b_terms
                            then do PrimValue
zero <- PrimType -> AlgSimplifyM PrimValue
getZero PrimType
tp; NNumExp -> AlgSimplifyM NNumExp
forall (m :: * -> *) a. Monad m => a -> m a
return (NNumExp -> AlgSimplifyM NNumExp)
-> NNumExp -> AlgSimplifyM NNumExp
forall a b. (a -> b) -> a -> b
$ [ScalExp] -> PrimType -> NNumExp
NProd [PrimValue -> ScalExp
Val PrimValue
zero] PrimType
tp
                            else NNumExp -> AlgSimplifyM NNumExp
forall (m :: * -> *) a. Monad m => a -> m a
return (NNumExp -> AlgSimplifyM NNumExp)
-> NNumExp -> AlgSimplifyM NNumExp
forall a b. (a -> b) -> a -> b
$ [NNumExp] -> PrimType -> NNumExp
NSum [NNumExp]
b_terms PrimType
tp
            Maybe (NNumExp, NNumExp) -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (NNumExp, NNumExp)
 -> AlgSimplifyM (Maybe (NNumExp, NNumExp)))
-> Maybe (NNumExp, NNumExp)
-> AlgSimplifyM (Maybe (NNumExp, NNumExp))
forall a b. (a -> b) -> a -> b
$ (NNumExp, NNumExp) -> Maybe (NNumExp, NNumExp)
forall a. a -> Maybe a
Just (NNumExp
a_terms_sofp, NNumExp
b_terms_sofp)
        [ScalExp]
_ -> Maybe (NNumExp, NNumExp) -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (NNumExp, NNumExp)
forall a. Maybe a
Nothing

------------------------------------------------
------------------------------------------------
-- Main Helper Function: takes a scalar exp,  --
-- normalizes and simplifies it               --
------------------------------------------------
------------------------------------------------

-- | Normalise and simplify a scalar expression.
--
-- Constants and variables are already in normal form and are returned
-- unchanged; every other constructor is handled by the equations below.
simplifyScal :: ScalExp -> AlgSimplifyM ScalExp
simplifyScal (Val v) = return $ Val v
simplifyScal (Id x t) = return $ Id x t

-- Boolean connectives and relational expressions are all simplified the
-- same way: convert to DNF, simplify the DNF, and convert back.  Any other
-- expression fails the guard and falls through to the equations below.
simplifyScal e
  | isLogical e = fromDNF =<< simplifyDNF =<< toDNF e
  where
    isLogical :: ScalExp -> Bool
    isLogical = \case
      SNot{}    -> True
      SLogAnd{} -> True
      SLogOr{}  -> True
      RelExp{}  -> True
      _         -> False

--------------------------------------
--- MaxMin related simplifications ---
--------------------------------------
simplifyScal (MaxMin _ []) =
    badAlgSimplifyM "Scalar MaxMin expression with empty arglist."
simplifyScal (MaxMin _ [e]) = simplifyScal e
simplifyScal (MaxMin ismin es) = do
    es0 <- mapM simplifyScal es
    -- Constant operands are folded into a single constant with min
    -- (resp. max); the remaining operands are kept symbolically.
    let vals = [ v | Val v <- es0 ]
        es'  = filter (not . isValue) es0
        myop = if ismin then min else max
        mvv  = case vals of
                 []     -> Nothing
                 v : vs -> Just $ Val $ foldl myop v vs
    -- flatten the result and remove duplicates,
    -- e.g., Max(Max(e1,e2), e1) -> Max(e1,e2)
    case (es', mvv) of
        ([], Just vv) -> return vv
        (_,  Just vv) -> return $ mkMaxMin $ vv : es'
        (_,  Nothing) -> return $ mkMaxMin es'
    -- ToDo: This can prove very expensive at compile time
    --       but, IF e2-e1 <= 0 simplifies to True THEN
    --       Min(e1,e2) = e2.   Code example:
    -- e1me2 <- if isMin
    --          then simplifyScal $ AlgOp MINUS e1 e2 pos
    --          else simplifyScal $ AlgOp MINUS e2 e1 pos
    -- e1me2leq0 <- simplifyNRel $ NRelExp LEQ0 e1me2 pos
    -- case e1me2leq0 of
    --    NAnd [LogCt True  _] _ -> simplifyAlgN e1
    --    NAnd [LogCt False _] _ -> simplifyAlgN e2
  where
    isValue :: ScalExp -> Bool
    isValue Val{} = True
    isValue _     = False

    -- Splice in the operands of a nested MaxMin of the same kind; any
    -- other operand is kept as a single element.
    flatten :: ScalExp -> [ScalExp]
    flatten (MaxMin ismin' ses) | ismin == ismin' = ses
    flatten e = [e]

    -- Build the flattened, deduplicated (and Set-ordered) MaxMin.  When only
    -- one operand survives deduplication it is returned directly, consistent
    -- with the @MaxMin _ [e]@ equation above.
    mkMaxMin :: [ScalExp] -> ScalExp
    mkMaxMin xs = case remDups $ concatMap flatten xs of
                    [x] -> x
                    xs' -> MaxMin ismin xs'

    remDups :: [ScalExp] -> [ScalExp]
    remDups = S.toList . S.fromList

---------------------------------------------------
--- Plus/Minus related simplifications          ---
--- NOTE: the MinMax pattern matching below is  ---
---       performed on the simplified subexps   ---
---------------------------------------------------
simplifyScal (SPlus e1o e2o) = do
    lhs <- simplifyScal e1o
    rhs <- simplifyScal e2o
    -- An operand that is a MaxMin goes to the dedicated helper; everything
    -- else is added via sum-of-products normalisation.
    if any isMaxMin [lhs, rhs]
      then helperPlusMinMax $ SPlus lhs rhs
      else normalPlus lhs rhs
  where
    -- Add two simplified, MaxMin-free expressions by merging the terms of
    -- their sum-of-products forms.
    normalPlus :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
    normalPlus e1 e2 = do
      sofp1 <- toNumSofP e1
      sofp2 <- toNumSofP e2
      let tp = scalExpType e1
          byTerm (t1, _) (t2, _) = compare t1 t2
      pieces <- mapM splitTerm $ getTerms sofp1 ++ getTerms sofp2
      -- discriminate adds together identical terms; the accumulated list
      -- is reversed to keep it in ascending order.
      merged <- reverse <$> foldM discriminate [] (sortBy byTerm pieces)
      -- Terms whose coefficient became zero are dropped; if nothing is
      -- left the sum is the constant zero.
      case filter (not . zeroIsh . snd) merged of
        [] -> do
          zero <- getZero tp
          fromNumSofP $ NProd [Val zero] tp
        nonZero -> do
          terms' <- mapM joinTerm nonZero
          fromNumSofP $ NSum terms' tp

simplifyScal (SMinus e1 e2)
  -- e - e cancels to zero without further work.
  | e1 == e2 = Val <$> getZero tp
  -- Otherwise rewrite e1 - e2 as e1 + (-1)*e2 and simplify the sum.
  | otherwise = do
      negOne <- getNeg1 tp
      simplifyScal $ SPlus e1 (STimes (Val negOne) e2)
  where
    tp = scalExpType e1

-- Negation is rewritten as multiplication by -1 and simplified as a product.
simplifyScal (SNeg e) =
    getNeg1 (scalExpType e) >>= \negOne -> simplifyScal (STimes (Val negOne) e)

-- Absolute value and sign are returned as they are; their operand is not
-- simplified here.
simplifyScal e@SAbs{} = return e

simplifyScal e@SSignum{} = return e

---------------------------------------------------
--- Times        related simplifications        ---
--- NOTE: the MinMax pattern matching below is  ---
---       performed on the simplified subexps   ---
---------------------------------------------------
simplifyScal (STimes e1o e2o) = do
    lhs <- simplifyScal e1o
    rhs <- simplifyScal e2o
    -- An operand that is a MaxMin goes to the dedicated helper; everything
    -- else is multiplied out in sum-of-products form.
    if any isMaxMin [lhs, rhs]
      then helperMultMinMax $ STimes lhs rhs
      else normalTimes lhs rhs
  where
    -- Multiply two simplified, MaxMin-free expressions by distributing the
    -- product over their sum-of-products forms.
    normalTimes :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
    normalTimes e1 e2 = do
      let tp = scalExpType e1
      sofp1 <- toNumSofP e1
      sofp2 <- toNumSofP e2
      case (sofp1, sofp2) of
        (NProd xs _, y@NProd{}) ->
          fromNumSofP =<< makeProds xs y
        (NProd xs _, y) ->
          sortedSum tp =<< mapM (makeProds xs) (getTerms y)
        (x, NProd ys _) ->
          sortedSum tp =<< mapM (makeProds ys) (getTerms x)
        (NSum xs _, NSum ys _) -> do
          xsMultChildren <- mapM getMultChildren xs
          prods <- forM xsMultChildren $ \factors -> mapM (makeProds factors) ys
          sortedSum tp $ concat prods

    -- Rebuild a sum from product terms, in sorted order.
    sortedSum :: PrimType -> [NNumExp] -> AlgSimplifyM ScalExp
    sortedSum tp prods = fromNumSofP $ NSum (sort prods) tp

    -- Multiply a list of factors into a single product term.  A leading
    -- constant is kept first, and two leading constants are multiplied.
    makeProds :: [ScalExp] -> NNumExp -> AlgSimplifyM NNumExp
    makeProds [] _ =
      badAlgSimplifyM " In simplifyAlgN, makeProds: 1st arg is the empty list! "
    makeProds _ (NProd [] _) =
      badAlgSimplifyM " In simplifyAlgN, makeProds: 2nd arg is the empty list! "
    makeProds _ NSum{} =
      badAlgSimplifyM " In simplifyAlgN, makeProds: e1 * e2: e2 is a sum of sums! "
    makeProds (Val c1 : rest) (NProd (Val c2 : others) ptp) = do
      c <- mulVals c1 c2
      return $ NProd (Val c : sort (others ++ rest)) ptp
    makeProds (Val c : rest) (NProd others ptp) =
      return $ NProd (Val c : sort (others ++ rest)) ptp
    makeProds rest (NProd (Val c : others) ptp) =
      return $ NProd (Val c : sort (others ++ rest)) ptp
    makeProds rest (NProd others ptp) =
      return $ NProd (sort (others ++ rest)) ptp

---------------------------------------------------
---------------------------------------------------
--- Divide        related simplifications       ---
---------------------------------------------------
---------------------------------------------------

simplifyScal (SDiv e1o e2o) = do
    numer <- simplifyScal e1o
    denom <- simplifyScal e2o
    -- An operand that is a MaxMin goes to the dedicated helper.
    if any isMaxMin [numer, denom]
      then helperMultMinMax $ SDiv numer denom
      else normalFloatDiv numer denom
  where
    -- Divide two simplified, MaxMin-free expressions.
    normalFloatDiv :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
    normalFloatDiv e1 e2
      -- x / x is folded to 1; no check that x is non-zero is made here.
      | e1 == e2 = Val <$> getPos1 (scalExpType e1)
--      | e1 == (negateSimplified e2) = do mone<- getNeg1 $ scalExpType e1
--                                         return $ Val mone
      | otherwise = do
          sofp1 <- toNumSofP e1
          sofp2 <- toNumSofP e2
          case sofp2 of
            NProd fs tp -> divByProduct sofp1 sofp2 fs tp
            _           -> turnBackAndDiv sofp1 sofp2

    -- Try to cancel the denominator factors @fs@ against the numerator's
    -- terms; if no factor is consumed, fall back to a plain division.
    divByProduct :: NNumExp -> NNumExp -> [ScalExp] -> PrimType
                 -> AlgSimplifyM ScalExp
    divByProduct sofp1 sofp2 fs tp = do
      numSplit <- mapM splitTerm $ getTerms sofp1
      if null numSplit
        then Val <$> getZero tp
        else do
          (fs', numSplit') <- trySimplifyDivRec fs [] numSplit
          if length fs' == length fs
            then turnBackAndDiv sofp1 sofp2 -- insuccess
            else do
              numTerms <- mapM joinTerm numSplit'
              numer' <- fromNumSofP $ NSum numTerms tp
              case fs' of
                [] -> return numer'
                _  -> SDiv numer' <$> fromNumSofP (NProd fs' tp)

    -- Convert both operands back to ScalExp and build an explicit division.
    turnBackAndDiv :: NNumExp -> NNumExp -> AlgSimplifyM ScalExp
    turnBackAndDiv ee1 ee2 = SDiv <$> fromNumSofP ee1 <*> fromNumSofP ee2

simplifyScal (SMod ScalExp
e1o ScalExp
e2o) =
    ScalExp -> ScalExp -> ScalExp
SMod (ScalExp -> ScalExp -> ScalExp)
-> AlgSimplifyM ScalExp
-> StateT
     Int (ReaderT AlgSimplifyEnv (Either Error)) (ScalExp -> ScalExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e1o StateT
  Int (ReaderT AlgSimplifyEnv (Either Error)) (ScalExp -> ScalExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM ScalExp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e2o

simplifyScal (SQuot ScalExp
e1o ScalExp
e2o) =
    ScalExp -> ScalExp -> ScalExp
SQuot (ScalExp -> ScalExp -> ScalExp)
-> AlgSimplifyM ScalExp
-> StateT
     Int (ReaderT AlgSimplifyEnv (Either Error)) (ScalExp -> ScalExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e1o StateT
  Int (ReaderT AlgSimplifyEnv (Either Error)) (ScalExp -> ScalExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM ScalExp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e2o

simplifyScal (SRem ScalExp
e1o ScalExp
e2o) =
    ScalExp -> ScalExp -> ScalExp
SRem (ScalExp -> ScalExp -> ScalExp)
-> AlgSimplifyM ScalExp
-> StateT
     Int (ReaderT AlgSimplifyEnv (Either Error)) (ScalExp -> ScalExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e1o StateT
  Int (ReaderT AlgSimplifyEnv (Either Error)) (ScalExp -> ScalExp)
-> AlgSimplifyM ScalExp -> AlgSimplifyM ScalExp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e2o

---------------------------------------------------
---------------------------------------------------
--- Power        related simplifications        ---
---------------------------------------------------
---------------------------------------------------

-- cannot handle 0^a, because if a < 0 then it's an error.
-- Could be extended to handle negative exponents better, if needed.
simplifyScal (SPow ScalExp
e1 ScalExp
e2) = do
    let tp :: PrimType
tp = ScalExp -> PrimType
scalExpType ScalExp
e1
    ScalExp
e1' <- ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e1
    ScalExp
e2' <- ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
e2

    if ScalExp -> Bool
isCt1 ScalExp
e1' Bool -> Bool -> Bool
|| ScalExp -> Bool
isCt0 ScalExp
e2'
    then PrimValue -> ScalExp
Val (PrimValue -> ScalExp)
-> AlgSimplifyM PrimValue -> AlgSimplifyM ScalExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> PrimType -> AlgSimplifyM PrimValue
getPos1 PrimType
tp
    else if ScalExp -> Bool
isCt1 ScalExp
e2'
    then ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return ScalExp
e1'
    else case (ScalExp
e1', ScalExp
e2') of
            (Val PrimValue
v1, Val PrimValue
v2)
              | Just PrimValue
v <- PrimValue -> PrimValue -> Maybe PrimValue
powVals PrimValue
v1 PrimValue
v2 -> ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ PrimValue -> ScalExp
Val PrimValue
v
            (ScalExp
_, Val (IntValue IntValue
n)) ->
                if IntValue -> Int64
P.intToInt64 IntValue
n Int64 -> Int64 -> Bool
forall a. Ord a => a -> a -> Bool
>= Int64
1
                then -- simplifyScal =<< fromNumSofP $ NProd (replicate n e1') tp
                        do ScalExp
new_e <- NNumExp -> AlgSimplifyM ScalExp
fromNumSofP (NNumExp -> AlgSimplifyM ScalExp)
-> NNumExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ [ScalExp] -> PrimType -> NNumExp
NProd (Int64 -> ScalExp -> [ScalExp]
forall i a. Integral i => i -> a -> [a]
genericReplicate (IntValue -> Int64
P.intToInt64 IntValue
n) ScalExp
e1') PrimType
tp
                           ScalExp -> AlgSimplifyM ScalExp
simplifyScal ScalExp
new_e
                else ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ ScalExp -> ScalExp -> ScalExp
SPow ScalExp
e1' ScalExp
e2'
            (ScalExp
_, ScalExp
_) -> ScalExp -> AlgSimplifyM ScalExp
forall (m :: * -> *) a. Monad m => a -> m a
return (ScalExp -> AlgSimplifyM ScalExp)
-> ScalExp -> AlgSimplifyM ScalExp
forall a b. (a -> b) -> a -> b
$ ScalExp -> ScalExp -> ScalExp
SPow ScalExp
e1' ScalExp
e2'

    where
        powVals :: PrimValue -> PrimValue -> Maybe PrimValue
        powVals :: PrimValue -> PrimValue -> Maybe PrimValue
powVals (IntValue IntValue
v1) (IntValue IntValue
v2) = IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> Maybe IntValue -> Maybe PrimValue
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> IntValue -> IntValue -> Maybe IntValue
P.doPow IntValue
v1 IntValue
v2
        powVals PrimValue
_ PrimValue
_ = Maybe PrimValue
forall a. Maybe a
Nothing

-----------------------------------------------------
--- Helpers for simplifyScal: MinMax related, etc ---
-----------------------------------------------------

-- | Is the root node of the expression a 'MaxMin'?
isMaxMin :: ScalExp -> Bool
isMaxMin e =
  case e of
    MaxMin{} -> True
    _        -> False

-- | Distribute an addition over a 'MaxMin' operand and re-simplify:
--
-- * @MaxMin es + e@ becomes @MaxMin [x + e | x <- es]@
-- * @e + MaxMin es@ becomes @MaxMin [e + x | x <- es]@
--
-- When both operands are 'MaxMin', the left one is distributed over.
-- Any other input is a caller error.
helperPlusMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperPlusMinMax (SPlus (MaxMin ismin es) e) =
  simplifyScal $ MaxMin ismin [SPlus x e | x <- es]
helperPlusMinMax (SPlus e (MaxMin ismin es)) =
  simplifyScal $ MaxMin ismin [SPlus e x | x <- es]
helperPlusMinMax _ =
  badAlgSimplifyM "helperPlusMinMax: Reached unreachable case!"

{-
helperMinusMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperMinusMinMax (SMinus (MaxMin ismin es) e) =
    simplifyScal $ MaxMin ismin $ map (\x -> SMinus x e) es
helperMinusMinMax (SMinus e (MaxMin ismin es)) =
    simplifyScal $ MaxMin ismin $ map (\x -> SMinus e x) es
helperMinusMinMax _ = do
    pos <- asks pos
    badAlgSimplifyM "helperMinusMinMax: Reached unreachable case!"
-}
-- | Route a product or quotient with a 'MaxMin' operand to
-- 'helperTimesDivMinMax'.
--
-- The first flag passed on says multiplication ('True') or division
-- ('False'). The second says whether the 'MaxMin' was the right-hand
-- operand. If both operands are 'MaxMin', the right-hand one is chosen.
-- Any other input is a caller error.
helperMultMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperMultMinMax expr =
  case expr of
    STimes e em@MaxMin{} -> helperTimesDivMinMax True  True  em e
    STimes em@MaxMin{} e -> helperTimesDivMinMax True  False em e
    SDiv   e em@MaxMin{} -> helperTimesDivMinMax False True  em e
    SDiv   em@MaxMin{} e -> helperTimesDivMinMax False False em e
    _ -> badAlgSimplifyM "helperMultMinMax: Reached unreachable case!"

-- | Push a multiplication or division through a 'MaxMin' operand.
--
-- Arguments:
--
-- * @isTimes@: 'STimes' when 'True', 'SDiv' when 'False'.
-- * @isRev@: 'True' when the 'MaxMin' was the right-hand operand of the
--   original expression. The operand order is swapped back when rebuilding.
-- * The 'MaxMin' operand.
-- * The other operand.
--
-- Behaviour:
--
-- * After simplification, if the other operand @e'@ has a sign that
--   'simplifyNRel' can decide for @e' < 0@, the operation is applied to
--   every element of the 'MaxMin', and min/max is kept or flipped
--   according to that sign.
-- * Otherwise, for a product whose other operand is a sum of products,
--   the 'MaxMin' is multiplied into every term of the sum.
-- * In all remaining cases the expression is simply rebuilt.
helperTimesDivMinMax :: Bool -> Bool -> ScalExp -> ScalExp -> AlgSimplifyM ScalExp
helperTimesDivMinMax isTimes isRev emo@MaxMin{} e = do
  em <- simplifyScal emo
  case em of
    MaxMin ismin es -> distribute em ismin es
    _               -> simplifyScal $ combine em e
  where
    -- Rebuild "x op other", restoring the original operand order.
    combine :: ScalExp -> ScalExp -> ScalExp
    combine e1 e2
      | isTimes   = if isRev then STimes e2 e1 else STimes e1 e2
      | isRev     = SDiv e2 e1
      | otherwise = SDiv e1 e2

    distribute :: ScalExp -> Bool -> [ScalExp] -> AlgSimplifyM ScalExp
    distribute em ismin es = do
      e' <- simplifyScal e
      e'_sop <- toNumSofP e'
      sign <- simplifyNRel $ NRelExp LTH0 e'_sop
      case sign of
        LogCt isNeg -> do
          -- The min/max kind is kept when e' is known non-negative, except
          -- for @e' / MaxMin@, where it is kept when e' is known negative.
          -- In every other decided case it is flipped.
          let keepKind = isNeg == (not isTimes && isRev)
              ismin'   = if keepKind then ismin else not ismin
          simplifyScal $ MaxMin ismin' [combine x e' | x <- es]
        _ | not isTimes -> return $ combine em e'
          | otherwise ->
              case e'_sop of
                -- simplifyScal =<< fromNumSofP (NProd (em:fs) tp)
                NProd{} -> return $ combine em e'
                NSum ts tp -> do
                  ts' <- mapM (prependFactor em tp) ts
                  simplifyScal =<< fromNumSofP (NSum ts' tp)

    -- Multiply one product term of a sum by the MaxMin factor.
    prependFactor :: ScalExp -> PrimType -> NNumExp -> AlgSimplifyM NNumExp
    prependFactor em tp term =
      case term of
        NProd fs _ -> return $ NProd (em : fs) tp
        _ -> badAlgSimplifyM "helperTimesDivMinMax: SofP invariant violated!"

helperTimesDivMinMax _ _ _ _ =
  badAlgSimplifyM "helperTimesDivMinMax: Reached unreachable case!"


---------------------------------------------------
---------------------------------------------------
--- translating to and simplifying the          ---
--- disjunctive normal form: toDNF, simplifyDNF ---
---------------------------------------------------
---------------------------------------------------
--isTrueDNF :: DNF -> Bool
--isTrueDNF [[LogCt True]] = True
--isTrueDNF _              = False
--
--getValueDNF :: DNF -> Maybe Bool
--getValueDNF [[LogCt True]]  = Just True
--getValueDNF [[LogCt False]] = Just False
--getValueDNF _               = Nothing


-- | Logically negate a single 'BTerm'.
--
-- * Constants are flipped.
-- * 'PosId' and 'NegId' are swapped.
-- * For an integer-typed @e < 0@, the result is @-(e + 1) < 0@.
-- * For every other relation/type combination:
--   @not (e < 0)@ becomes @-e <= 0@, and @not (e <= 0)@ becomes @-e < 0@.
negateBTerm :: BTerm -> AlgSimplifyM BTerm
negateBTerm (LogCt v) = return $ LogCt $ not v
negateBTerm (PosId i) = return $ NegId i
negateBTerm (NegId i) = return $ PosId i
negateBTerm (NRelExp rel e)
  | IntType it <- typeOfNAlg e, rel == LTH0 = do
      se <- fromNumSofP e
      let one = Val $ value $ P.intValue it (1 :: Int)
      ne <- toNumSofP =<< simplifyScal (SNeg (SPlus se one))
      return $ NRelExp LTH0 ne
  | otherwise = do
      let rel' = if rel == LEQ0 then LTH0 else LEQ0
      ne <- toNumSofP =<< negateSimplified =<< fromNumSofP e
      return $ NRelExp rel' ne

-- | Convert a 'BTerm' back into a boolean-valued 'ScalExp'.
--
-- 'PosId' and 'NegId' stand for boolean variables: 'toDNF' creates them
-- from an 'Id' in boolean position. They are therefore rebuilt as
-- 'Bool'-typed identifiers. Using an integer type here would change the
-- type of the variable on a 'toDNF'/'fromDNF' round trip.
bterm2ScalExp :: BTerm -> AlgSimplifyM ScalExp
bterm2ScalExp (LogCt v)       = return $ Val $ BoolValue v
bterm2ScalExp (PosId i)       = return $ Id i Bool
bterm2ScalExp (NegId i)       = return $ SNot $ Id i Bool
bterm2ScalExp (NRelExp rel e) = RelExp rel <$> fromNumSofP e

-- | Translate a DNF back into a 'ScalExp'.
--
-- Each disjunct's terms are combined with 'SLogAnd', and the disjuncts are
-- combined with 'SLogOr'. Both folds place each later element on the left
-- of the accumulator (@x `op` acc@).
--
-- An empty DNF, or an empty disjunct, is reported through
-- 'badAlgSimplifyM'.
fromDNF :: DNF -> AlgSimplifyM ScalExp
fromDNF = joinWith SLogOr "fromDNF: empty DNF!" translFact
  where
    translFact :: [BTerm] -> AlgSimplifyM ScalExp
    translFact =
      joinWith SLogAnd "fromDNF, translFact empty DNF factor!" bterm2ScalExp

    -- Convert every element in order, folding the results with the binary
    -- operator; fail with the given message on an empty list.
    joinWith :: (ScalExp -> ScalExp -> ScalExp) -> String
             -> (a -> AlgSimplifyM ScalExp) -> [a] -> AlgSimplifyM ScalExp
    joinWith _ msg _ [] = badAlgSimplifyM msg
    joinWith op _ conv (x : xs) = do
      x' <- conv x
      foldM (\acc y -> (`op` acc) <$> conv y) x' xs

-- | Translate a boolean 'ScalExp' into disjunctive normal form: a list of
-- disjuncts, each a list of conjoined 'BTerm's.
--
-- * Negation is pushed inwards using De Morgan's laws.
-- * The numeric operand of every relation is simplified.
-- * Integer relations are normalised to @e < 0@: @e <= 0@ becomes
--   @e + getNeg1 < 0@.
-- * Other relations keep their operator and are simplified under
--   'markGaussLTH0'.
-- * Anything that is not a boolean expression is reported through
--   'badAlgSimplifyM'.
toDNF :: ScalExp -> AlgSimplifyM DNF
toDNF expr =
  case expr of
    Val (BoolValue v)        -> return [[LogCt v]]
    Id idd _                 -> return [[PosId idd]]
    RelExp rel e             -> relation rel e
    SNot (SNot e)            -> toDNF e
    SNot (Val (BoolValue v)) -> return [[LogCt (not v)]]
    SNot (Id idd _)          -> return [[NegId idd]]
    SNot (RelExp rel e)      -> do
      -- not (e < 0) == -e <= 0;  not (e <= 0) == -e < 0
      neg_e <- simplifyScal (SNeg e)
      toDNF $ RelExp (flipRel rel) neg_e
    SLogOr e1 e2             -> orDNF e1 e2
    SLogAnd e1 e2            -> andDNF e1 e2
    SNot (SLogAnd e1 e2)     -> orDNF (SNot e1) (SNot e2)
    SNot (SLogOr e1 e2)      -> andDNF (SNot e1) (SNot e2)
    _                        -> badAlgSimplifyM "toDNF: not a boolean expression!"
  where
    flipRel :: RelOp0 -> RelOp0
    flipRel r = if r == LEQ0 then LTH0 else LEQ0

    relation :: RelOp0 -> ScalExp -> AlgSimplifyM DNF
    relation rel e =
      case scalExpType e of
        IntType it -> do
          -- For integers, e <= 0 is rewritten as e + getNeg1 < 0.
          e' <- if rel == LEQ0
                  then SPlus e . Val <$> getNeg1 (IntType it)
                  else return e
          ne <- toNumSofP =<< simplifyScal e'
          nrel <- simplifyNRel $ NRelExp LTH0 ne
          return [[nrel]]
        _ -> do
          ne <- toNumSofP =<< simplifyScal e
          nrel <- markGaussLTH0 $ simplifyNRel $ NRelExp rel ne
          return [[nrel]]

    -- Disjunction: concatenate the disjuncts of both sides.
    orDNF :: ScalExp -> ScalExp -> AlgSimplifyM DNF
    orDNF a b = do
      d1 <- toDNF a
      d2 <- toDNF b
      return $ sort $ d1 ++ d2

    -- Conjunction: cross product of the disjuncts of both sides.
    andDNF :: ScalExp -> ScalExp -> AlgSimplifyM DNF
    andDNF a b = do
      d1 <- toDNF a
      d2 <- toDNF b
      return $ sort [t1 ++ t2 | t2 <- d2, t1 <- d1]

------------------------------------------------------
--- Simplifying Boolean Expressions:               ---
---  0. p     AND p == p;       p     OR p == p    ---
---  1. False AND p == False;   True  OR p == True ---
---  2. True  AND p == p;       False OR p == p    ---
---  3.(not p)AND p == FALSE;  (not p)OR p == True ---
---  4. ToDo: p1 AND p2 == p1 if p1 => p2          ---
---           p1 AND p2 == False if p1 => not p2 or---
---                                 p2 => not p1   ---
---     Also: p1 OR p2 == p2 if p1 => p2           ---
---           p1 OR p2 == True if not p1 => p2 or  ---
---                               not p2 => p1     ---
---     This boils down to relations:              ---
---      e1 < 0 => e2 < 0 if e2 <= e1              ---
------------------------------------------------------
-- | Simplify a DNF in three steps:
--
-- 1. Simplify each disjunct as a conjunction ('simplifyAndOr' 'True').
-- 2. Collapse the whole DNF:
--    * it becomes true if any disjunct is @[LogCt True]@;
--    * otherwise @[LogCt False]@ disjuncts are dropped and duplicates are
--      removed;
--    * if nothing is left, the result is @[[LogCt False]]@.
-- 3. If every remaining disjunct is a single term, the disjunction itself
--    is simplified with 'simplifyAndOr' 'False'.
simplifyDNF :: DNF -> AlgSimplifyM DNF
simplifyDNF terms0 = do
  conjs <- mapM (simplifyAndOr True) terms0
  finish $ dedup conjs
  where
    dedup :: DNF -> DNF
    dedup conjs
      | [LogCt True] `elem` conjs = [[LogCt True]]
      | otherwise =
          S.toList $ S.fromList [c | c <- conjs, c /= [LogCt False]]

    finish :: DNF -> AlgSimplifyM DNF
    finish [] = return [[LogCt False]]
    finish conjs
      | all isSingleton conjs =
          map (: []) <$> simplifyAndOr False (concat conjs)
      | otherwise = return conjs

    isSingleton :: [a] -> Bool
    isSingleton [_] = True
    isSingleton _   = False

-- | Simplify a flat list of boolean terms that are either all conjuncts
-- (@is_and == True@) or all disjuncts (@is_and == False@).  This is the
-- big helper behind 'simplifyDNF'.
--
-- Writing /neutral/ for @LogCt is_and@ (True for AND, False for OR) and
-- /absorbing/ for @LogCt (not is_and)@ (False for AND, True for OR):
--
--   * if the absorbing constant occurs among the terms, the result is
--     @[absorbing]@;
--   * neutral constants and duplicate terms are dropped, and if nothing is
--     left the result is @[neutral]@;
--   * terms subsumed by another term (as decided by the local 'impliesRel')
--     are removed;
--   * if one remaining term contradicts another (AND) or complements it
--     (OR), the result collapses to @[absorbing]@.
--
-- Fails via 'badAlgSimplifyM' when given an empty list.
simplifyAndOr :: Bool -> [BTerm] -> AlgSimplifyM [BTerm]
simplifyAndOr _ [] = badAlgSimplifyM "simplifyAndOr: not a boolean expression!"
simplifyAndOr is_and fs
  -- False AND p == False  /  True OR p == True
  | absorbing `elem` fs = return [absorbing]
  -- p AND p == p, True AND p == p  (dually for OR)
  | null distinct = return [neutral]
  | otherwise = do
      -- Subsumption: if p1 => p2 then p1 AND p2 == p1
      -- (dually, p1 OR p2 == p2).
      kept <- foldM insertTerm [] distinct
      -- Contradiction: if p1 => not p2 then p1 AND p2 == False
      -- (dually, if not p1 => p2 then p1 OR p2 == True).
      -- Stops at the first hit, so later terms are never negated.
      let collapses [] = return False
          collapses (x:xs) = do
            notx <- negateBTerm x
            hit <- impliesAny is_and x notx kept
            if hit then return True else collapses xs
      contradictory <- collapses kept
      return $ if contradictory then [absorbing] else kept
  where
    neutral, absorbing :: BTerm
    neutral = LogCt is_and
    absorbing = LogCt (not is_and)

    -- Deduplicated terms with the neutral constant removed.
    distinct :: [BTerm]
    distinct = S.toList $ S.fromList $ filter (/= neutral) fs

    -- Add one term to the accumulated list, pruning whatever it subsumes
    -- and skipping it when it is itself subsumed.
    insertTerm :: [BTerm] -> BTerm -> AlgSimplifyM [BTerm]
    insertTerm acc x = do
      (keepX, acc') <- trimImplies is_and x acc
      return $ if keepX then x : acc' else acc'

    -- Does the first term imply the second?  Every rule below is only a
    -- sufficient condition, so 'False' presumably means "could not
    -- establish" rather than "does not hold".
    impliesRel :: BTerm -> BTerm -> AlgSimplifyM Bool
    impliesRel (LogCt False) _ = return True
    impliesRel _ (LogCt True) = return True
    impliesRel (LogCt True) e = return $ e == LogCt True
    impliesRel e (LogCt False) = (== LogCt True) <$> negateBTerm e
    impliesRel (NRelExp rel1 lhs) (NRelExp rel2 rhs)
      | tp /= typeOfNAlg rhs = return False
      | otherwise = do
          -- ToDo: implement the implies relation more completely.
          one <- getPos1 tp
          lhs' <- fromNumSofP lhs
          rhs' <- fromNumSofP rhs
          case (rel1, rel2, tp) of
            (LTH0, LTH0, IntType _) -> do
              -- lhs < 0 implies rhs < 0 when rhs <= lhs, which for
              -- integers is rhs - (lhs + 1) < 0.
              gap <- toNumSofP =<< simplifyScal (SMinus rhs' $ SPlus lhs' $ Val one)
              (== LogCt True) <$> simplifyNRel (NRelExp LTH0 gap)
            (_, _, IntType _) -> badAlgSimplifyM "impliesRel: LEQ0 for Int!"
            _ -> badAlgSimplifyM "impliesRel: exp of illegal type!"
      where tp = typeOfNAlg lhs
    impliesRel p1 p2 = return $ p1 == p2

    -- trimImplies True x l   (AND):  p1 AND p2 == p1 when p1 => p2, so
    --   every p in l with x => p is removed, and if some p => x then x is
    --   redundant (first component False).
    -- trimImplies False x l  (OR):   p1 OR p2 == p2 when p1 => p2, so
    --   every p in l with p => x is removed, and if x => some p then x is
    --   redundant (first component False).
    -- The first component says whether x should still be added.
    trimImplies :: Bool -> BTerm -> [BTerm] -> AlgSimplifyM (Bool, [BTerm])
    trimImplies _ _ [] = return (True, [])
    trimImplies and_case x (p:ps) = do
      xImpliesP <- impliesRel x p
      if xImpliesP
        then if and_case
               then trimImplies and_case x ps
               else return (False, p:ps)
        else do
          pImpliesX <- impliesRel p x
          case (pImpliesX, and_case) of
            (True, True)  -> return (False, p:ps)
            (True, False) -> trimImplies and_case x ps
            (False, _)    -> do
              (addx, rest) <- trimImplies and_case x ps
              return (addx, p:rest)

    -- impliesAny True  x notx l:  x AND p == False if p => not x, p in l.
    -- impliesAny False x notx l:  x OR p == True   if not x => p, p in l.
    -- Entries equal to x itself are skipped, so x is never compared
    -- against its own negation.
    impliesAny :: Bool -> BTerm -> BTerm -> [BTerm] -> AlgSimplifyM Bool
    impliesAny and_case x notx = go . filter (/= x)
      where
        go [] = return False
        go (p:ps) = do
          hit <- if and_case then impliesRel p notx else impliesRel notx p
          if hit then return True else go ps

------------------------------------------------
--- Syntax-Directed (Brainless) Translators  ---
---    scalExp <-> NNumExp                   ---
--- and negating a scalar expression         ---
------------------------------------------------

-- | Negate an already simplified scalar expression.  Pushing the negation
-- inwards directly is presumably cheaper than building @SNeg e@ and
-- simplifying the result again.
--
--   * @SNeg e@ and @SNot e@ are unwrapped.
--   * Literals are multiplied by the constant @-1@ of their type.
--   * Sums are negated term-wise, and @-(a - b)@ becomes @(-a) + b@.
--   * Products and quotients negate exactly one operand (see
--     'helperNegateMult'), and @-signum e@ becomes @signum (-e)@.
--   * @max@/@min@ swap to the dual operator over negated arguments.
--   * Relations are logically negated: @not (e <= 0)@ is @-e < 0@ and
--     @not (e < 0)@ is @-e <= 0@.
--   * Terms the negation cannot be pushed into (identifiers, powers,
--     absolute values, @mod@ and @rem@) are multiplied by @-1@.
--
-- Fails via 'badAlgSimplifyM' on 'SLogAnd' and 'SLogOr'.
negateSimplified :: ScalExp -> AlgSimplifyM ScalExp
negateSimplified expr = case expr of
  SNeg e -> return e
  SNot e -> return e
  -- -|e| is not |e|; keep the absolute value as an opaque factor.
  SAbs{} -> scaleByMinusOne
  SSignum e -> SSignum <$> negateSimplified e
  Val v -> do
    m1 <- getNeg1 $ scalExpType expr
    Val <$> mulVals m1 v
  Id{} -> scaleByMinusOne
  SMinus e1 e2 -> do
    e1' <- negateSimplified e1
    return $ SPlus e1' e2
  SPlus e1 e2 -> do
    e1' <- negateSimplified e1
    e2' <- negateSimplified e2
    return $ SPlus e1' e2'
  SPow{} -> scaleByMinusOne
  STimes e1 e2 -> uncurry STimes <$> helperNegateMult e1 e2
  -- NOTE(review): this assumes (-a)/b == -(a/b).  That holds for truncating
  -- and real division, but not for flooring integer division
  -- (-(7 div 2) = -3, (-7) div 2 = -4); confirm which semantics SDiv has.
  SDiv e1 e2 -> uncurry SDiv <$> helperNegateMult e1 e2
  -- -(a mod b) is not (a mod b) in general; keep it as an opaque factor.
  SMod{} -> scaleByMinusOne
  SQuot e1 e2 -> uncurry SQuot <$> helperNegateMult e1 e2
  -- Same reasoning as for SMod.
  SRem{} -> scaleByMinusOne
  MaxMin ismin ts -> MaxMin (not ismin) <$> mapM negateSimplified ts
  RelExp LEQ0 e -> RelExp LTH0 <$> negateSimplified e
  RelExp LTH0 e -> RelExp LEQ0 <$> negateSimplified e
  SLogAnd{} -> badAlgSimplifyM "negateSimplified: SLogAnd unimplemented!"
  SLogOr{} -> badAlgSimplifyM "negateSimplified: SLogOr  unimplemented!"
  where
    -- Multiply the whole expression by the -1 of its type.
    scaleByMinusOne :: AlgSimplifyM ScalExp
    scaleByMinusOne = do
      m1 <- getNeg1 $ scalExpType expr
      return $ STimes (Val m1) expr

-- | Negate a product-like pair of operands by negating exactly one of them.
-- A literal operand, or an operand of the form @STimes (Val c) _@, is
-- preferred (left before right), so the sign lands on a constant.
-- Otherwise the left operand is negated.
helperNegateMult :: ScalExp -> ScalExp -> AlgSimplifyM (ScalExp, ScalExp)
helperNegateMult left right = case (left, right) of
  (Val _, _) -> negateLeft
  (STimes (Val c) rest, _) ->
    (\c' -> (STimes c' rest, right)) <$> negateSimplified (Val c)
  (_, Val _) ->
    (\r' -> (left, r')) <$> negateSimplified right
  (_, STimes (Val c) rest) ->
    (\c' -> (left, STimes c' rest)) <$> negateSimplified (Val c)
  _ -> negateLeft
  where
    negateLeft = (\l' -> (l', right)) <$> negateSimplified left


-- | Convert a simplified 'ScalExp' into sum-of-products form.
--
-- Nested 'SPlus' chains are flattened into a single 'NSum'.  In
-- @STimes f rest@, the factor @f@ is kept as-is and prepended to the
-- factors of @rest@, which must itself become an 'NProd'.  Any other
-- expression becomes a one-factor 'NProd'.  'SMinus' and 'SNeg' are
-- rejected, because a simplified input should no longer contain them.
toNumSofP :: ScalExp -> AlgSimplifyM NNumExp
toNumSofP SMinus{} = badAlgSimplifyM "toNumSofP: SMinus is not in SofP form!"
toNumSofP SNeg{} = badAlgSimplifyM "toNumSofP: SNeg is not in SofP form!"
toNumSofP (STimes factor rest) =
  toNumSofP rest >>= \case
    NProd factors tp -> return $ NProd (factor : factors) tp
    _ -> badAlgSimplifyM "toNumSofP: STimes nor in SofP form!"
toNumSofP (SPlus lhs rhs) = do
  lhs' <- toNumSofP lhs
  rhs' <- toNumSofP rhs
  return $ NSum (summands lhs' ++ summands rhs') (scalExpType lhs)
  where
    summands (NSum terms _) = terms
    summands prod@NProd{} = [prod]
toNumSofP atom = return $ NProd [atom] (scalExpType atom)


-- | Convert sum-of-products form back into a 'ScalExp'.
--
-- Sums become right-nested 'SPlus' chains and products become right-nested
-- 'STimes' chains.  An empty sum becomes the zero of its type.  An empty
-- product is rejected via 'badAlgSimplifyM'.
fromNumSofP :: NNumExp -> AlgSimplifyM ScalExp
fromNumSofP (NSum summands tp) = case summands of
  [] -> Val <$> getZero tp
  [only] -> fromNumSofP only
  first : rest -> do
    restExp <- fromNumSofP (NSum rest tp)
    firstExp <- fromNumSofP first
    return $ SPlus firstExp restExp
fromNumSofP (NProd factors tp) = case factors of
  [] -> badAlgSimplifyM " In fromNumSofP, empty NProd expression! "
  [only] -> return only
  first : rest -> STimes first <$> fromNumSofP (NProd rest tp)
------------------------------------------------------------
--- Helpers for simplifyScal: getTerms, getMultChildren, ---
---   splitTerm, joinTerm, discriminate
------------------------------------------------------------


-- | The additive terms of an expression.  A sum yields its summands, and a
-- product is a single term.
-- NOTE(review): an older comment flagged NMinMax as unhandled here; only
-- NSum and NProd are matched, so confirm no other constructor can reach
-- this function.
getTerms :: NNumExp -> [NNumExp]
getTerms = \case
  NSum summands _ -> summands
  term@NProd{} -> [term]

-- | The factors of a product term.  A sum is rejected, since terms are not
-- expected to contain nested sums.
getMultChildren :: NNumExp -> AlgSimplifyM [ScalExp]
getMultChildren = \case
  NProd factors _ -> return factors
  NSum{} ->
    badAlgSimplifyM "getMultChildren, NaryPlus should not be nested 2 levels deep "

-- | Split a term into its remaining factors and a multiplicative literal
-- coefficient.
--
-- Only a literal in leading position is recognised.  When there is none,
-- the coefficient is the constant 1 of the type and the term is returned
-- unchanged.  A product consisting of a single literal becomes the
-- product @[1]@ with that literal as coefficient.  Empty products are
-- rejected.
splitTerm :: NNumExp -> AlgSimplifyM (NNumExp, PrimValue)
splitTerm term = case term of
  NProd [] _ -> badAlgSimplifyM "splitTerm: Empty n-ary list of factors."
  NProd [factor] tp -> do
    one <- getPos1 tp
    return $ case factor of
      Val v -> (NProd [Val one] tp, v)
      _     -> (NProd [factor] tp, one)
  NProd (Val v : rest) tp -> return (NProd rest tp, v)
  NProd _ tp -> withUnit tp
  _ -> withUnit (typeOfNAlg term)
  where
    -- Pair the untouched term with the constant 1 of the given type.
    withUnit tp = (\one -> (term, one)) <$> getPos1 tp

-- | Fold a literal coefficient back into a product term (the inverse of
-- 'splitTerm').
--
-- If the product already starts with a literal, the coefficient is
-- multiplied into it.  Otherwise a coefficient that is numerically one
-- leaves the term untouched, and any other coefficient is prepended.
-- Whenever a literal is placed in front, the remaining factors are sorted,
-- presumably to give terms a canonical factor order.  Sums and empty
-- products are rejected.
joinTerm :: (NNumExp, PrimValue) -> AlgSimplifyM NNumExp
joinTerm (term, coeff) = case term of
  NSum{} -> badAlgSimplifyM "joinTerm: NaryPlus two levels deep."
  NProd [] _ -> badAlgSimplifyM "joinTerm: Empty NaryProd."
  NProd (Val lit : rest) tp -> do
    merged <- mulVals coeff lit
    return $ NProd (Val merged : sort rest) tp
  NProd factors tp
    | P.oneIsh coeff -> return term
    | otherwise -> return $ NProd (Val coeff : sort factors) tp

-- | Push a @(term, coefficient)@ pair onto an accumulated list of terms.
--
-- Only the head of the accumulator is inspected.  If its term is equal to
-- the new term, the two coefficients are summed ('addVals') in place of the
-- head.  Otherwise the new pair is consed onto the front.  Duplicate terms
-- are therefore merged only when they arrive next to each other; presumably
-- callers feed a sorted list (TODO confirm).
discriminate :: [(NNumExp, PrimValue)] -> (NNumExp, PrimValue)
             -> AlgSimplifyM [(NNumExp, PrimValue)]
discriminate acc new =
  case acc of
    (key, coeff) : rest | key == fst new -> do
      summed <- addVals coeff (snd new)
      return $ (key, summed) : rest
    _ -> return $ new : acc

------------------------------------------------------
--- Trivial Utility Functions                      ---
------------------------------------------------------

-- | The value @0@ of an integral 'PrimType'.  Fails for any other type.
getZero :: PrimType -> AlgSimplifyM PrimValue
getZero tp =
  case tp of
    IntType it -> return . value $ intValue it (0 :: Int)
    _          -> badAlgSimplifyM $ "getZero for type: " ++ pretty tp

-- | The value @1@ of an integral 'PrimType'.  Fails for any other type.
getPos1 :: PrimType -> AlgSimplifyM PrimValue
getPos1 tp =
  case tp of
    IntType it -> return . value $ intValue it (1 :: Int)
    _          -> badAlgSimplifyM $ "getOne for type: " ++ pretty tp

-- | The value @-1@ of an integral 'PrimType'.  Fails for any other type,
-- with an error message naming this function.
getNeg1 :: PrimType -> AlgSimplifyM PrimValue
getNeg1 (IntType t) = return $ value $ intValue t (-1 :: Int)
getNeg1 tp          = badAlgSimplifyM $ "getNeg1 for type: " ++ pretty tp

-- | Evaluate a sign test on an integer literal: @v < 0@ for 'LTH0' and
-- @v <= 0@ for 'LEQ0'.  The comparison is done on the value widened with
-- 'P.intToInt64'.  Fails for values that are not 'IntValue's.
valLTHEQ0 :: RelOp0 -> PrimValue -> AlgSimplifyM Bool
valLTHEQ0 rel (IntValue iv)
  | LEQ0 <- rel = return $ n <= 0
  | LTH0 <- rel = return $ n < 0
  where n = P.intToInt64 iv
valLTHEQ0 _ _ = badAlgSimplifyM "valLTHEQ0 for non-numeric type!"

-- | Is the expression a literal that is one (according to 'P.oneIsh')?
isCt1 :: ScalExp -> Bool
isCt1 = \case
  Val v -> P.oneIsh v
  _     -> False

-- | Is the expression a literal that is zero (according to 'P.zeroIsh')?
isCt0 :: ScalExp -> Bool
isCt0 = \case
  Val v -> P.zeroIsh v
  _     -> False


-- | Add two integer literals with 'P.doAdd'.  Fails unless both operands
-- are 'IntValue's.  Whether their integer widths agree is not checked here;
-- that is left to 'P.doAdd'.
addVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
addVals x y =
  case (x, y) of
    (IntValue a, IntValue b) -> return . IntValue $ P.doAdd a b
    _ -> badAlgSimplifyM "addVals: operands not of (the same) numeral type! "

-- | Multiply two integer literals with 'P.doMul'.  Fails unless both
-- operands are 'IntValue's; the error message includes both operands.
mulVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
mulVals x y =
  case (x, y) of
    (IntValue a, IntValue b) -> return . IntValue $ P.doMul a b
    _ -> badAlgSimplifyM $
           "mulVals: operands not of (the same) numeral type! " ++
           unwords [pretty x, pretty y]

-- | Divide two integer literals with 'P.doSDiv'.  Fails with
-- "Division by zero" when 'P.doSDiv' yields 'Nothing', and fails when
-- either operand is not an 'IntValue'.
divVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
divVals (IntValue a) (IntValue b) =
  maybe (badAlgSimplifyM "Division by zero") (return . IntValue) (P.doSDiv a b)
divVals _ _ =
  badAlgSimplifyM "divVals: operands not of (the same) numeral type! "

-- | Does the first integer literal divide evenly by the second?  The
-- answer is 'True' exactly when 'P.doSMod' yields a zero remainder.  It is
-- 'False' when 'P.doSMod' yields 'Nothing' (presumably a zero divisor).
-- Fails unless both operands are 'IntValue's.
canDivValsEvenly :: PrimValue -> PrimValue -> AlgSimplifyM Bool
canDivValsEvenly (IntValue a) (IntValue b) =
  return $ maybe False (P.zeroIsh . IntValue) (P.doSMod a b)
canDivValsEvenly _ _ =
  badAlgSimplifyM "canDivValsEvenly: operands not of (the same) numeral type!"

-------------------------------------------------------------
-------------------------------------------------------------
---- Helpers for the ScalExp and NNumRelLogExp Datatypes ----
-------------------------------------------------------------
-------------------------------------------------------------

-- | The primitive type recorded in a sum or product node.
typeOfNAlg :: NNumExp -> PrimType
typeOfNAlg = \case
  NSum _ t  -> t
  NProd _ t -> t

----------------------------------------
---- Helpers for Division and Power ----
----------------------------------------
-- | Try to divide every split term by each candidate factor in turn.
--
-- The candidate factors in the first list are taken left to right.  For
-- each factor @f@, every @(term, coefficient)@ pair is tried with
-- 'tryDivProdByOneFact'.
--
-- * If /all/ of the pairs succeed, the divided terms replace the current
--   ones and @f@ is consumed.
-- * Otherwise the terms are kept as they were and @f@ is appended to the
--   second list, which collects the factors that could not be cancelled.
--
-- Returns the final list of uncancelled factors together with the
-- resulting terms.
trySimplifyDivRec :: [ScalExp] -> [ScalExp] -> [(NNumExp, PrimValue)] ->
                     AlgSimplifyM ([ScalExp], [(NNumExp, PrimValue)])
trySimplifyDivRec facts kept terms = foldM step (kept, terms) facts
  where
    step (leftover, cur) f = do
      (oks, divided) <- unzip <$> mapM (tryDivProdByOneFact f) cur
      return $ if and oks
               then (leftover, divided)
               else (leftover ++ [f], cur)


-- | Try to divide one split term @(product, coefficient)@ by the factor
-- @f@.  Returns whether the division succeeded, together with the
-- (possibly) updated term.
--
-- * A literal divisor divides the coefficient, provided it does so evenly
--   ('canDivValsEvenly').
-- * Otherwise the factors are scanned left to right with 'tryDivTriv', and
--   the first factor that absorbs @f@ is replaced by the quotient.
-- * A literal quotient is multiplied into the coefficient instead.  If that
--   leaves no factors, the product becomes @[1]@.
-- * A literal found at the front of a divided remainder is also folded into
--   the coefficient.
-- * An empty product, or a product in which no factor absorbs @f@, is
--   returned unchanged with 'False'.
--
-- For a non-literal divisor, @1@ is built via 'getPos1' before any factor
-- is examined.  This fails for non-integral product types.  An 'NSum' in
-- the product position is reported as an internal error.
tryDivProdByOneFact :: ScalExp -> (NNumExp, PrimValue)
                    -> AlgSimplifyM (Bool, (NNumExp, PrimValue))
tryDivProdByOneFact (Val f) (e, v) = do
  divisible <- canDivValsEvenly v f
  if not divisible
    then return (False, (e, v))
    else do q <- divVals v f
            return (True, (e, q))
tryDivProdByOneFact _ pev@(NProd [] _, _) =
  return (False, pev)
tryDivProdByOneFact f pev@(NProd (t:tfs) tp, v) = do
  (absorbed, newt) <- tryDivTriv t f
  one <- getPos1 tp
  if absorbed
    then divideHead newt one
    else divideTail
  where
    -- The head factor absorbed the divisor; a literal quotient is moved
    -- into the coefficient.
    divideHead (Val c) one = do
      coeff <- mulVals c v
      let factors = if null tfs then [Val one] else tfs
      return (True, (NProd factors tp, coeff))
    divideHead newt _ =
      return (True, (NProd (newt : tfs) tp, v))

    -- Keep the head factor and try to divide the remaining factors instead.
    divideTail = do
      (ok, (rest, v')) <- tryDivProdByOneFact f (NProd tfs tp, v)
      if not ok
        then return (False, pev)
        else case rest of
          NProd (Val c : rest') _ -> do
            coeff <- mulVals v' c
            return (True, (NProd (t : rest') tp, coeff))
          NProd rest' _ ->
            return (True, (NProd (t : rest') tp, v'))
          _ ->
            return (False, pev)
tryDivProdByOneFact _ (NSum _ _, _) =
  badAlgSimplifyM "tryDivProdByOneFact: unreachable case NSum reached!"


-- | Try to cancel the divisor @f@ against the single factor @t@, i.e.
-- simplify @t / f@ without leaving a division behind.
--
-- Returns @(True, q)@ when @t / f@ was simplified to @q@.  Otherwise it
-- returns 'False' together with the unsimplified division.
--
-- Handled cases:
--
-- * @t == f@: the quotient is @1@.
-- * @a^e1 / a^e2@ with an integral base, where @e1 - e2@ is simplified
--   first:
--
--     * @1@ when @e1 - e2@ simplifies to zero;
--     * @a@ when @e1 - e2@ simplifies to one;
--     * @a^(e1-e2)@ when 'simplifyNRel' reduces @e2 - e1 < 0@ to
--       @'LogCt' 'True'@.
--
--   Bases of any other type are left unsimplified.
-- * @a^e1 / a@ and @a / a^e1@ are retried as @a^e1 / a^1@ and
--   @a^1 / a^e1@ respectively.
--
-- Building the literal @1@ goes through 'getPos1', which fails for
-- non-integral types.
tryDivTriv :: ScalExp -> ScalExp -> AlgSimplifyM (Bool, ScalExp)
tryDivTriv (SPow a e1) (SPow d e2)
  | a == d && e1 == e2 = do
      one <- getPos1 $ scalExpType a
      return (True, Val one)
  -- Only integral bases are handled.  The unit is built on demand so that
  -- other base types reach the "not simplified" result below, rather than
  -- aborting inside 'getPos1' before the type is even inspected.
  | a == d, IntType _ <- scalExpType a = do
      e1me2 <- simplifyScal $ SMinus e1 e2
      case e1me2 of
        Val v
          | P.zeroIsh v -> do
              one <- getPos1 $ scalExpType a
              return (True, Val one)
          | P.oneIsh v ->
              return (True, a)
        _ -> do
          -- a^(e1-e2) is only a valid quotient if e1 - e2 > 0,
          -- i.e. if e2 - e1 < 0.
          e2me1 <- negateSimplified e1me2
          e2me1_sop <- toNumSofP e2me1
          p' <- simplifyNRel $ NRelExp LTH0 e2me1_sop
          return $ if p' == LogCt True
                   then (True, SPow a e1me2)
                   else (False, unsimplified)
  | otherwise = return (False, unsimplified)
  where unsimplified = SDiv (SPow a e1) (SPow d e2)
tryDivTriv (SPow a e1) b
  | a == b = do
      one <- getPos1 $ scalExpType a
      tryDivTriv (SPow a e1) (SPow a (Val one))
  | otherwise = return (False, SDiv (SPow a e1) b)
tryDivTriv b (SPow a e1)
  | a == b = do
      one <- getPos1 $ scalExpType a
      tryDivTriv (SPow a (Val one)) (SPow a e1)
  | otherwise = return (False, SDiv b (SPow a e1))
tryDivTriv t f
  | t == f = do
      one <- getPos1 $ scalExpType t
      return (True, Val one)
  | otherwise = return (False, SDiv t f)


{-
mkRelExp :: Int -> (RangesRep, ScalExp, ScalExp)
mkRelExp 1 =
    let (i',j',n',p',m') = (tident "int i", tident "int j", tident "int n", tident "int p", tident "int m")
        (i,j,n,p,m) = (Id i', Id j', Id n', Id p', Id m')
        one = Val (IntVal 1)
        min_p_nm1 = MaxMin True [p, SMinus n one]
        hash = M.fromList $ [ (identName n', ( 1::Int, Just (Val (IntVal 1)), Nothing ) ),
                               (identName p', ( 1::Int, Just (Val (IntVal 0)), Nothing ) ),
                               (identName i', ( 5::Int, Just (Val (IntVal 0)), Just min_p_nm1 ) )
                             , (identName j', ( 9::Int, Just (Val (IntVal 0)), Just i ) )
                             ] -- M.Map VName (Int, Maybe ScalExp, Maybe ScalExp)
        ij_p_j_p_1_m_m = SMinus (SPlus (STimes i j) (SPlus j one)) m
        rel1 = RelExp LTH0 ij_p_j_p_1_m_m
        m_ij_m_j_m_2 = SNeg ( SPlus (STimes i j) (SPlus j (Val (IntVal 2))) )
        rel2 = RelExp LTH0 m_ij_m_j_m_2

    in (hash, rel1, rel2)
mkRelExp 2 =
    let (i',a',b',l',u') = (tident "int i", tident "int a", tident "int b", tident "int l", tident "int u")
        (i,a,b,l,u) = (Id i', Id a', Id b', Id l', Id u')
        hash = M.fromList $ [ (identName i', ( 5::Int, Just l, Just u ) ) ]
        ai_p_b = SPlus (STimes a i) b
        rel1 = RelExp LTH0 ai_p_b

    in (hash, rel1, rel1)
mkRelExp 3 =
    let (i',j',n',m') = (tident "int i", tident "int j", tident "int n", tident "int m")
        (i,j,n,m) = (Id i', Id j', Id n', Id m')
        one = Val (IntVal 1)
        two = Val (IntVal 2)
        min_j_nm1 = MaxMin True [MaxMin False [Val (IntVal 0), SMinus i (STimes two n)], SMinus n one]
        hash = M.fromList $ [ (identName n', ( 1::Int, Just (Val (IntVal 1)), Nothing ) ),
                               (identName m', ( 2::Int, Just (Val (IntVal 1)), Nothing ) ),
                               (identName i', ( 5::Int, Just (Val (IntVal 0)), Just (SMinus m one) ) )
                             , (identName j', ( 9::Int, Just (Val (IntVal 0)), Just min_j_nm1 ) )
                             ] -- M.Map VName (Int, Maybe ScalExp, Maybe ScalExp)
        ij_m_m = SMinus (STimes i j) m
        rel1 = RelExp LTH0 ij_m_m
--        rel3 = RelExp LTH0 (SMinus i (SPlus (STimes two n) j))
        m_ij_m_1 = SMinus (Val (IntVal (-1))) (STimes i j)
        rel2 = RelExp LTH0 m_ij_m_1

--        simpl_exp = SDiv (MaxMin True [SMinus (Val (IntVal 0)) (STimes i j), SNeg (STimes i n) ])
--                         (STimes i j)
--        rel4 = RelExp LTH0 simpl_exp

    in (hash, rel1, rel2)

mkRelExp _ = let hash = M.empty
                 rel = RelExp LTH0 (Val (IntVal (-1)))
             in (hash, rel, rel)
-}