module Futhark.Analysis.AlgSimplify
  ( Prod (..),
    SofP,
    simplify0,
    simplify,
    simplify',
    simplifySofP,
    simplifySofP',
    sumOfProducts,
    sumToExp,
    prodToExp,
    add,
    sub,
    negate,
    isMultipleOf,
    maybeDivide,
    removeLessThans,
    lessThanish,
    compareComplexity,
  )
where

import Data.Bits (xor)
import Data.Function ((&))
import Data.List (findIndex, intersect, partition, sort, (\\))
import Data.Maybe (mapMaybe)
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop.Names
import Futhark.IR.Syntax.Core
import Futhark.Util
import Futhark.Util.Pretty
import Prelude hiding (negate)

-- | Untyped primitive expressions over variable names; the atoms of a 'Prod'.
type Exp = PrimExp VName

-- | Primitive expressions over variable names, tagged as 'Int64'-typed.
type TExp = TPrimExp Int64 VName

-- | A single product term of a sum-of-products expression.  Its value is the
-- product of its 'atoms', multiplied by @-1@ when 'negated' is set.
data Prod = Prod
  { -- | Whether the whole product is multiplied by @-1@.
    negated :: Bool,
    -- | The factors of the product.
    atoms :: [Exp]
  }
  deriving (Show, Eq, Ord)

-- | A sum of products: the list is read as the sum of its 'Prod' terms.
type SofP = [Prod]

-- | Convert an expression to sum-of-products form.  The atoms of every
-- resulting product are sorted, so that products with the same factors
-- compare equal regardless of the order in which the factors appeared.
sumOfProducts :: Exp -> SofP
sumOfProducts = map sortProduct . sumOfProducts'

-- | Sort the atoms of a product, leaving its sign untouched.
sortProduct :: Prod -> Prod
sortProduct (Prod n as) = Prod n $ sort as

-- | Worker for 'sumOfProducts' (does not sort the atoms).
--
-- Only 'Int64' additions, subtractions and multiplications are taken apart;
-- 'Int64' literals become a sign plus their absolute value; every other
-- expression is kept as a single opaque atom.
sumOfProducts' :: Exp -> SofP
sumOfProducts' (BinOpExp (Add Int64 _) e1 e2) =
  sumOfProducts' e1 <> sumOfProducts' e2
-- @0 - e@ is plain negation.
sumOfProducts' (BinOpExp (Sub Int64 _) (ValueExp (IntValue (Int64Value 0))) e) =
  map negate $ sumOfProducts' e
sumOfProducts' (BinOpExp (Sub Int64 _) e1 e2) =
  sumOfProducts' e1 <> map negate (sumOfProducts' e2)
-- Multiplication distributes over both sums.
sumOfProducts' (BinOpExp (Mul Int64 _) e1 e2) =
  sumOfProducts' e1 `mult` sumOfProducts' e2
sumOfProducts' (ValueExp (IntValue (Int64Value i))) =
  [Prod (i < 0) [ValueExp $ IntValue $ Int64Value $ abs i]]
sumOfProducts' e = [Prod False [e]]

-- | Multiply two sums of products by distributing: every product of the
-- first sum is combined with every product of the second.  Signs are combined
-- with 'xor' and atom lists are concatenated.
mult :: SofP -> SofP -> SofP
mult xs ys = [Prod (b `xor` b') (x <> y) | Prod b x <- xs, Prod b' y <- ys]

-- | Flip the sign of a product.  Note that this shadows the Prelude's
-- @negate@, which this module hides.
negate :: Prod -> Prod
negate p = p {negated = not $ negated p}

-- | Convert a sum of products back into an expression.  The empty sum is
-- @0@; otherwise the products are joined left-to-right with 'Int64'
-- additions using 'OverflowUndef'.
sumToExp :: SofP -> Exp
sumToExp [] = val 0
sumToExp [x] = prodToExp x
sumToExp (x : xs) =
  foldl (BinOpExp $ Add Int64 OverflowUndef) (prodToExp x) $
    map prodToExp xs

-- | Convert a single product into an expression.
--
-- * An empty product is @1@, or @-1@ if it is negated.
-- * A negated product consisting of a single 'Int64' literal becomes the
--   negated literal.
-- * Any other negated product is multiplied out starting from @-1@.
-- * A non-negated product is the left-to-right multiplication of its atoms.
--
-- Multiplications are 'Int64' with 'OverflowUndef'.
prodToExp :: Prod -> Exp
-- Only the non-negated empty product is 1; a negated empty product falls
-- through to the general negated case below and yields -1.
prodToExp (Prod False []) = val 1
prodToExp (Prod True [ValueExp (IntValue (Int64Value i))]) = val (-i)
prodToExp (Prod True as) =
  foldl (BinOpExp $ Mul Int64 OverflowUndef) (val (-1)) as
prodToExp (Prod False (a : as)) =
  foldl (BinOpExp $ Mul Int64 OverflowUndef) a as

-- | Simplify a sum of products.  One pass cancels terms against their
-- negations, folds the purely constant terms, merges terms that differ only
-- in their integer scale, strips factors of @1@ and drops terms containing a
-- factor of @0@.  The pass is iterated with 'fixPoint'.
simplifySofP :: SofP -> SofP
simplifySofP =
  -- TODO: Maybe 'constFoldValueExps' is not necessary after adding scaleConsts
  fixPoint
    ( mapMaybe (applyZero . removeOnes)
        . scaleConsts
        . constFoldValueExps
        . removeNegations
    )

-- | Like 'simplifySofP', but without the 'constFoldValueExps' step.
simplifySofP' :: SofP -> SofP
simplifySofP' =
  fixPoint
    ( mapMaybe (applyZero . removeOnes)
        . scaleConsts
        . removeNegations
    )

-- | Convert an expression to sum-of-products form and simplify it with
-- 'simplifySofP'.
simplify0 :: Exp -> SofP
simplify0 = simplifySofP . sumOfProducts

-- | Simplify an expression by going through sum-of-products form
-- ('simplify0'), converting back with 'sumToExp', and finally applying
-- 'constFoldPrimExp' to the result.
simplify :: Exp -> Exp
simplify = constFoldPrimExp . sumToExp . simplify0

-- | 'simplify' lifted to typed 'Int64' expressions: the type tag is dropped
-- with 'untyped' and re-attached with 'TPrimExp' afterwards.
simplify' :: TExp -> TExp
simplify' = TPrimExp . simplify . untyped

-- | Drop a product that has the 'Int64' literal @0@ among its atoms (it
-- contributes nothing to a sum); any other product is returned unchanged.
applyZero :: Prod -> Maybe Prod
applyZero p@(Prod _ as)
  | val 0 `elem` as = Nothing
  | otherwise = Just p

-- | Remove every 'Int64' literal @1@ from the atoms of a product.  If that
-- would leave no atoms at all, a single @1@ is kept so that the product is
-- never empty.
removeOnes :: Prod -> Prod
removeOnes (Prod neg as) =
  Prod neg $ if null as' then [val 1] else as'
  where
    as' = filter (/= val 1) as

-- | Cancel each product against the first later product that is its exact
-- negation (same atoms, opposite sign); both are removed from the sum.
removeNegations :: SofP -> SofP
removeNegations [] = []
removeNegations (t : ts) =
  case break (== negate t) ts of
    -- Found the negation of @t@: drop both and keep going.
    (start, _ : rest) -> removeNegations $ start <> rest
    _ -> t : removeNegations ts

-- | Fold all purely constant products (those whose atoms are all literal
-- values) into one: they are summed with 'sumToExp', constant-folded with
-- 'constFoldPrimExp', and converted back with 'sumOfProducts'.  The folded
-- constants are placed in front of the remaining products.
constFoldValueExps :: SofP -> SofP
constFoldValueExps prods =
  let (value_exps, others) = partition (all isPrimValue . atoms) prods
      value_exps' = sumOfProducts $ constFoldPrimExp $ sumToExp value_exps
   in value_exps' <> others

-- | The value of an integer literal (of any 'IntValue' constructor) as an
-- 'Int64', or 'Nothing' if the expression is not an integer literal.
intFromExp :: Exp -> Maybe Int64
intFromExp (ValueExp (IntValue x)) = Just $ valueIntegral x
intFromExp _ = Nothing

-- | Split a product into its signed integer scale and its remaining atoms.
-- All integer literals are multiplied together (giving @1@ if there are
-- none) and the sign of the product is folded into the result.
--
-- Given @-[2, x]@ returns @(-2, [x])@
prodToScale :: Prod -> (Int64, [Exp])
prodToScale (Prod b exps) =
  let (scalars, exps') = partitionMaybe intFromExp exps
      scale = product scalars
   in (if b then -scale else scale, exps')

-- | Build a product from a signed integer scale and a list of atoms.  The
-- sign of the scale becomes the 'negated' flag and its absolute value is
-- prepended to the atoms as a literal.
--
-- Given @(-2, [x])@ returns @-[2, x]@
scaleToProd :: (Int64, [Exp]) -> Prod
scaleToProd (i, exps) =
  Prod (i < 0) $ ValueExp (IntValue $ Int64Value $ abs i) : exps

-- | Merge products that have the same non-constant atoms by adding up their
-- integer scales.  Atom lists are compared with '==', so the order of the
-- factors matters.
--
-- Given @[[2, x], -[x]]@ returns @[[1, x]]@
scaleConsts :: SofP -> SofP
scaleConsts =
  helper [] . map prodToScale
  where
    helper :: [Prod] -> [(Int64, [Exp])] -> [Prod]
    -- The accumulator is built in reverse.
    helper acc [] = reverse acc
    helper acc ((scale, exps) : rest) =
      -- Look for the first later term with identical atoms.
      case flip focusNth rest =<< findIndex ((==) exps . snd) rest of
        -- None left: this term is final.
        Nothing -> helper (scaleToProd (scale, exps) : acc) rest
        -- Merge the two terms and re-examine the merged one.
        Just (before, (scale', _), after) ->
          helper acc $ (scale + scale', exps) : (before <> after)

-- | Is the expression a literal value?
isPrimValue :: Exp -> Bool
isPrimValue (ValueExp _) = True
isPrimValue _ = False

-- | Construct an 'Int64' literal expression.
val :: Int64 -> Exp
val = ValueExp . IntValue . Int64Value

-- | Add two sums of products and simplify the result with 'simplifySofP'.
add :: SofP -> SofP -> SofP
add ps1 ps2 = simplifySofP $ ps1 <> ps2

-- | Subtract the second sum of products from the first: every product of
-- the second is negated and the result is passed to 'add'.
sub :: SofP -> SofP -> SofP
sub ps1 ps2 = add ps1 $ map negate ps2

-- | Does the product contain all the given atoms as factors (respecting
-- multiplicity)?  The sign of the product is ignored.
isMultipleOf :: Prod -> [Exp] -> Bool
isMultipleOf (Prod _ as) term =
  -- If removing @term@ from @as@ and adding it back yields the same
  -- multiset, then every atom of @term@ was present in @as@.
  let quotient = as \\ term
   in sort (quotient <> term) == sort as

-- | Try to divide the first product by the second.
--
-- Two strategies are attempted in order:
--
-- 1. Symbolic: if every atom of the divisor occurs in the dividend, those
--    atoms are removed and the signs are combined with 'xor'.
--
-- 2. Scaled: both products are split into an integer scale and remaining
--    atoms with 'prodToScale'.  If the divisor's scale is non-zero and
--    divides the dividend's scale evenly, and all remaining divisor atoms
--    occur in the dividend, the quotient is built with 'scaleToProd', so its
--    sign lives only in the 'negated' flag and the literal is non-negative.
--
-- Returns 'Nothing' if neither strategy applies.
maybeDivide :: Prod -> Prod -> Maybe Prod
maybeDivide dividend divisor
  | Prod dividend_b dividend_factors <- dividend,
    Prod divisor_b divisor_factors <- divisor,
    quotient <- dividend_factors \\ divisor_factors,
    sort (quotient <> divisor_factors) == sort dividend_factors =
      Just $ Prod (dividend_b `xor` divisor_b) quotient
  | (dividend_scale, dividend_rest) <- prodToScale dividend,
    (divisor_scale, divisor_rest) <- prodToScale divisor,
    -- 'mod' and 'div' throw on a zero divisor, so rule it out first.
    divisor_scale /= 0,
    dividend_scale `mod` divisor_scale == 0,
    null $ divisor_rest \\ dividend_rest =
      Just $
        scaleToProd
          ( dividend_scale `div` divisor_scale,
            dividend_rest \\ divisor_rest
          )
  | otherwise = Nothing

-- | Given a set of 'Names' that we know are non-negative (>= 0), determine
-- whether we can say for sure that the given 'SofP' is non-negative.  This
-- holds when every product in the sum passes 'nonNegativeishProd'.
-- Conservatively returns 'False' if there is any doubt.
--
-- TODO: We need to expand this to be able to handle cases such as
-- @i*n + g < (i + 1) * n@, if it is known that @g < n@, e.g. from a
-- 'SegSpace' or a loop form.
nonNegativeish :: Names -> SofP -> Bool
nonNegativeish non_negatives = all (nonNegativeishProd non_negatives)

-- | Can we say for sure that the product is non-negative?  A negated
-- product is always rejected; a non-negated product is accepted only if
-- every one of its atoms passes 'nonNegativeishExp'.
nonNegativeishProd :: Names -> Prod -> Bool
nonNegativeishProd _ (Prod True _) = False
nonNegativeishProd non_negatives (Prod False as) =
  all (nonNegativeishExp non_negatives) as

-- | Can we say for sure that the atom is non-negative?  A literal is
-- accepted unless 'negativeIsh' holds for it, a variable is accepted if it is
-- one of the known non-negative names, and anything else is rejected.
nonNegativeishExp :: Names -> PrimExp VName -> Bool
nonNegativeishExp _ (ValueExp v) = not $ negativeIsh v
nonNegativeishExp non_negatives (LeafExp vname _) = vname `nameIn` non_negatives
nonNegativeishExp _ _ = False

-- | Is e1 symbolically less than or equal to e2?
--
-- The difference @e2 - e1@ is simplified to sum-of-products form.  If it
-- simplifies to the empty sum the answer is 'True'.  Otherwise known
-- less-than facts are removed from it with 'removeLessThans' (iterated with
-- 'fixPoint'), and what remains must pass 'nonNegativeish'.
--
-- The first argument lists pairs @(i, bound)@; each pair is passed on to
-- 'removeLessThans' both for the variable @i@ and for the constant @0@.
lessThanOrEqualish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool
lessThanOrEqualish less_thans0 non_negatives e1 e2 =
  case e2 - e1 & untyped & simplify0 of
    [] -> True
    simplified ->
      nonNegativeish non_negatives $
        fixPoint (`removeLessThans` less_thans) simplified
  where
    less_thans =
      concatMap
        (\(i, bound) -> [(Var i, bound), (Constant $ IntValue $ Int64Value 0, bound)])
        less_thans0

-- | Is e1 symbolically strictly less than e2?  Decided by asking
-- 'lessThanOrEqualish' whether @e1 + 1@ is less than or equal to e2.
lessThanish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool
lessThanish less_thans non_negatives e1 =
  lessThanOrEqualish less_thans non_negatives (e1 + 1)

-- | For every pair @(i, bound)@, left to right: compute the simplified sum
-- of products for @bound - i@, and if all of its terms occur in the given sum,
-- remove them from it.  Otherwise the sum is left unchanged for that pair.
removeLessThans :: SofP -> [(SubExp, PrimExp VName)] -> SofP
removeLessThans =
  foldl
    ( \sofp (i, bound) ->
        let to_remove =
              simplifySofP $
                Prod True [primExpFromSubExp (IntType Int64) i]
                  : simplify0 bound
         in case to_remove `intersect` sofp of
              -- Only remove if every term of @to_remove@ is present.
              to_remove' | to_remove' == to_remove -> sofp \\ to_remove
              _ -> sofp
    )

-- | Order two sums of products by a rough notion of complexity.
--
-- Sums are first compared by their number of terms.  For equally long sums
-- the terms are compared pairwise, in order, after splitting each with
-- 'prodToScale':
--
-- * two constant terms are compared by their scale;
-- * a constant term is smaller than a non-constant one;
-- * two non-constant terms are compared by their number of non-constant
--   atoms.
--
-- The first pair that is not 'EQ' decides.
compareComplexity :: SofP -> SofP -> Ordering
compareComplexity xs0 ys0 =
  case length xs0 `compare` length ys0 of
    EQ -> helper xs0 ys0
    c -> c
  where
    helper [] [] = EQ
    helper [] _ = LT
    helper _ [] = GT
    helper (px : xs) (py : ys) =
      case (prodToScale px, prodToScale py) of
        ((ix, []), (iy, [])) -> case ix `compare` iy of
          EQ -> helper xs ys
          c -> c
        ((_, []), (_, _)) -> LT
        ((_, _), (_, [])) -> GT
        ((_, x), (_, y)) -> case length x `compare` length y of
          EQ -> helper xs ys
          c -> c