module Futhark.Analysis.AlgSimplify
  ( Prod (..),
    SofP,
    simplify0,
    simplify,
    simplify',
    simplifySofP,
    simplifySofP',
    sumOfProducts,
    sumToExp,
    prodToExp,
    add,
    sub,
    negate,
    isMultipleOf,
    maybeDivide,
    removeLessThans,
    lessThanish,
    compareComplexity,
  )
where

import Data.Bits (xor)
import Data.Function ((&))
import Data.List (findIndex, intersect, partition, sort, (\\))
import Data.Maybe (mapMaybe)
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop.Names
import Futhark.IR.Syntax.Core (SubExp (..), VName)
import Futhark.Util
import Futhark.Util.Pretty
import Prelude hiding (negate)

-- | Primitive expressions whose leaves are variable names.
type Exp = PrimExp VName

-- | An 'Exp' tagged at the type level as having type 'Int64'.
type TExp = TPrimExp Int64 VName

-- | A single term of a sum of products: a product of atoms, optionally
-- negated as a whole.
data Prod = Prod
  { -- | Whether the entire product is negated.
    negated :: Bool,
    -- | The factors that are multiplied together.
    atoms :: [Exp]
  }
  deriving (Eq, Ord, Show)

-- | A sum of products. The empty list stands for zero.
type SofP = [Prod]

-- | Convert an expression into sum-of-products form and sort the atoms of
-- every product. Presumably the sorting is there so that products that differ
-- only in factor order compare equal.
sumOfProducts :: Exp -> SofP
sumOfProducts e = sortProduct <$> sumOfProducts' e

-- | Sort the atoms of a product, leaving its sign untouched.
sortProduct :: Prod -> Prod
sortProduct p = p {atoms = sort (atoms p)}

-- | Flatten an expression into a sum of products. 64-bit integer addition
-- and subtraction become sums; multiplication is distributed with 'mult'.
-- An 'Int64' literal is stored as its absolute value, with its sign moved
-- into 'negated'. Any other expression becomes a single-atom product.
sumOfProducts' :: Exp -> SofP
sumOfProducts' expr = case expr of
  BinOpExp (Add Int64 _) lhs rhs ->
    sumOfProducts' lhs <> sumOfProducts' rhs
  -- @0 - rhs@ is plain negation; matching it first avoids emitting a zero term.
  BinOpExp (Sub Int64 _) (ValueExp (IntValue (Int64Value 0))) rhs ->
    map negate (sumOfProducts' rhs)
  BinOpExp (Sub Int64 _) lhs rhs ->
    sumOfProducts' lhs <> map negate (sumOfProducts' rhs)
  BinOpExp (Mul Int64 _) lhs rhs ->
    sumOfProducts' lhs `mult` sumOfProducts' rhs
  ValueExp (IntValue (Int64Value k)) ->
    [Prod (k < 0) [ValueExp $ IntValue $ Int64Value $ abs k]]
  _ -> [Prod False [expr]]

-- | Multiply two sums of products by pairing every product of the first
-- with every product of the second. Atom lists are concatenated and signs
-- combined with 'xor'.
mult :: SofP -> SofP -> SofP
mult xs ys = do
  Prod negX factorsX <- xs
  Prod negY factorsY <- ys
  pure $ Prod (negX `xor` negY) (factorsX <> factorsY)

-- | Flip the sign of a product.
negate :: Prod -> Prod
negate (Prod isNeg factors) = Prod (not isNeg) factors

-- | Turn a sum of products back into an expression.
--
-- The empty sum becomes the literal @0@. Otherwise each product is converted
-- with 'prodToExp', and the results are joined by a left-nested chain of
-- 64-bit additions.
sumToExp :: SofP -> Exp
sumToExp prods = case map prodToExp prods of
  [] -> val 0
  first : others -> foldl (BinOpExp $ Add Int64 OverflowUndef) first others

-- | Turn a single product back into an expression.
--
-- * The empty product is @1@, or @-1@ when negated. This matches
--   'removeOnes', which rewrites @Prod True []@ to @-[1]@. Such products can
--   arise, e.g. from dividing @-[x]@ by @[x]@.
-- * A negated single literal is folded into a negative literal.
-- * Any other negated product is a left-nested multiplication starting
--   from @-1@.
-- * A non-negated product is a left-nested multiplication of its atoms.
prodToExp :: Prod -> Exp
prodToExp (Prod False []) = val 1
prodToExp (Prod True []) = val (-1)
prodToExp (Prod True [ValueExp (IntValue (Int64Value i))]) =
  ValueExp $ IntValue $ Int64Value (-i)
prodToExp (Prod True as) =
  foldl (BinOpExp $ Mul Int64 OverflowUndef) (val (-1)) as
prodToExp (Prod False (a : as)) =
  foldl (BinOpExp $ Mul Int64 OverflowUndef) a as

-- | Simplify a sum of products, repeating until it stops changing.
--
-- Each round does the following, in order:
--
-- * cancels pairs of mutually negated products;
-- * constant-folds the products made only of literals;
-- * merges products whose non-constant atoms coincide;
-- * drops unit factors, then drops products that contain a literal zero.
simplifySofP :: SofP -> SofP
simplifySofP = fixPoint step
  where
    -- TODO: Maybe 'constFoldValueExps' is not necessary after adding scaleConsts
    step =
      mapMaybe (applyZero . removeOnes)
        . scaleConsts
        . constFoldValueExps
        . removeNegations

-- | Like 'simplifySofP', but without the constant-folding step.
-- Each round cancels negated pairs, merges products with identical
-- non-constant atoms, and removes unit factors and zero products.
simplifySofP' :: SofP -> SofP
simplifySofP' = fixPoint step
  where
    step = mapMaybe (applyZero . removeOnes) . scaleConsts . removeNegations

-- | Convert an expression into a simplified sum of products.
simplify0 :: Exp -> SofP
simplify0 e = simplifySofP (sumOfProducts e)

-- | Simplify an expression. It is converted to sum-of-products form and
-- simplified there, then converted back and constant-folded.
simplify :: Exp -> Exp
simplify e = constFoldPrimExp $ sumToExp $ simplify0 e

-- | 'simplify' for typed 64-bit integer expressions.
simplify' :: TExp -> TExp
simplify' te = TPrimExp $ simplify $ untyped te

-- | Drop a product (return 'Nothing') if one of its atoms is the literal
-- @0@; otherwise keep it unchanged.
applyZero :: Prod -> Maybe Prod
applyZero p@(Prod _ factors)
  | val 0 `notElem` factors = Just p
  | otherwise = Nothing

-- | Remove literal-one factors from a product. If nothing else remains, a
-- single @1@ is kept so the atom list is never empty.
removeOnes :: Prod -> Prod
removeOnes (Prod neg factors) =
  case filter (/= val 1) factors of
    [] -> Prod neg [ValueExp $ IntValue $ Int64Value 1]
    kept -> Prod neg kept

-- | Cancel pairs of products that are exact negations of each other.
-- Each product is matched against the first later occurrence of its
-- negation.
removeNegations :: SofP -> SofP
removeNegations [] = []
removeNegations (p : ps) =
  case break (== negate p) ps of
    (before, _ : after) -> removeNegations (before ++ after)
    (_, []) -> p : removeNegations ps

-- | Constant-fold the products whose atoms are all literal values.
--
-- Their sum is folded with 'constFoldPrimExp' and re-expanded, then placed
-- in front of the remaining products. When no such products exist, the
-- folded sum is a literal @0@, which becomes a zero product.
constFoldValueExps :: SofP -> SofP
constFoldValueExps prods = folded <> rest
  where
    (literalOnly, rest) = partition (all isPrimValue . atoms) prods
    folded = sumOfProducts $ constFoldPrimExp $ sumToExp literalOnly

-- | The value of an integer literal, converted to 'Int64', or 'Nothing' if
-- the expression is not an integer literal.
intFromExp :: Exp -> Maybe Int64
intFromExp e = case e of
  ValueExp (IntValue v) -> Just (valueIntegral v)
  _ -> Nothing

-- | Split a product into a signed constant factor and its remaining atoms.
-- All integer-literal atoms are multiplied together, and the product's sign
-- is applied to the result. Given @-[2, x]@ it returns @(-2, [x])@.
prodToScale :: Prod -> (Int64, [Exp])
prodToScale (Prod neg factors) = (withSign (product scalars), others)
  where
    (scalars, others) = partitionMaybe intFromExp factors
    withSign k = if neg then -k else k

-- | Counterpart to 'prodToScale'. Builds a product whose first atom is the
-- absolute value of the constant; the constant's sign goes into 'negated'.
-- Given @(-2, [x])@ it returns @-[2, x]@.
scaleToProd :: (Int64, [Exp]) -> Prod
scaleToProd (i, exps) =
  Prod (i < 0) $ ValueExp (IntValue $ Int64Value $ abs i) : exps

-- | Merge products that share the same non-constant atoms by adding their
-- constant factors. The merged products keep the order of their first
-- occurrence.
--
-- For example, @[[2, x], -[x]]@ becomes @[[1, x]]@. The unit factor is
-- dropped later by 'removeOnes', giving @[[x]]@.
scaleConsts :: SofP -> SofP
scaleConsts prods = go [] (map prodToScale prods)
  where
    go :: [Prod] -> [(Int64, [Exp])] -> [Prod]
    go done [] = reverse done
    go done ((k, es) : todo) =
      case findIndex (\(_, es') -> es == es') todo >>= (`focusNth` todo) of
        -- Fold the matching entry into this one and look for more matches.
        Just (pre, (k', _), post) -> go done ((k + k', es) : (pre <> post))
        Nothing -> go (scaleToProd (k, es) : done) todo

-- | Is the expression a literal value of any primitive type?
isPrimValue :: Exp -> Bool
isPrimValue e = case e of
  ValueExp {} -> True
  _ -> False

-- | A 64-bit integer literal expression.
val :: Int64 -> Exp
val i = ValueExp (IntValue (Int64Value i))

-- | Add two sums of products and simplify the result.
add :: SofP -> SofP -> SofP
add xs ys = simplifySofP (xs ++ ys)

-- | Subtract the second sum of products from the first and simplify.
sub :: SofP -> SofP -> SofP
sub xs ys = add xs (negate <$> ys)

-- | Does the product contain every atom of @term@, counting multiplicity?
-- The sign of the product is ignored.
isMultipleOf :: Prod -> [Exp] -> Bool
isMultipleOf (Prod _ factors) term =
  sort factors == sort ((factors \\ term) ++ term)

-- | Try to divide one product by another.
--
-- 1. Exact cancellation. If every atom of the divisor occurs in the dividend
--    (counting multiplicity), the quotient is the remaining atoms, with the
--    signs combined by 'xor'.
--
-- 2. Constant division. Both products are split with 'prodToScale'. This
--    succeeds when all of the following hold:
--
--    * the divisor's constant is non-zero;
--    * it divides the dividend's constant exactly;
--    * every non-constant atom of the divisor occurs in the dividend.
--
--    As with every literal in this representation, the quotient constant is
--    stored as its absolute value, with its sign in 'negated'.
--
-- Returns 'Nothing' when neither strategy applies.
maybeDivide :: Prod -> Prod -> Maybe Prod
maybeDivide dividend divisor
  | Prod dividend_b dividend_factors <- dividend,
    Prod divisor_b divisor_factors <- divisor,
    quotient <- dividend_factors \\ divisor_factors,
    sort (quotient <> divisor_factors) == sort dividend_factors =
      Just $ Prod (dividend_b `xor` divisor_b) quotient
  | (dividend_scale, dividend_rest) <- prodToScale dividend,
    (divisor_scale, divisor_rest) <- prodToScale divisor,
    -- A literal zero factor in the divisor would make 'mod'/'div' throw.
    divisor_scale /= 0,
    dividend_scale `mod` divisor_scale == 0,
    null $ divisor_rest \\ dividend_rest,
    scale <- dividend_scale `div` divisor_scale =
      -- Store |scale| and put the sign in 'negated' only. Storing the signed
      -- value as well would apply the sign twice.
      Just $
        Prod
          (scale < 0)
          ( ValueExp (IntValue $ Int64Value $ abs scale)
              : (dividend_rest \\ divisor_rest)
          )
  | otherwise = Nothing

-- | Given a list of 'Names' that we know are non-negative (>= 0), determine
-- whether we can say for sure that the given 'SofP' is non-negative.
-- It succeeds only when every product is known non-negative; the empty sum
-- (zero) counts as non-negative. Conservatively returns 'False' if there is
-- any doubt.
--
-- TODO: We need to expand this to be able to handle cases such as @i*n + g < (i
-- + 1) * n@, if it is known that @g < n@, eg. from a 'SegSpace' or a loop form.
nonNegativeish :: Names -> SofP -> Bool
nonNegativeish non_negatives sofp = all (nonNegativeishProd non_negatives) sofp

-- | A product is known non-negative when it is not negated and every one
-- of its atoms is known non-negative.
nonNegativeishProd :: Names -> Prod -> Bool
nonNegativeishProd non_negatives (Prod neg factors) =
  not neg && all (nonNegativeishExp non_negatives) factors

-- | An atom is known non-negative in two cases: it is a literal for which
-- 'negativeIsh' is false, or it is a variable listed in @non_negatives@.
-- Everything else is conservatively treated as possibly negative.
nonNegativeishExp :: Names -> PrimExp VName -> Bool
nonNegativeishExp non_negatives e = case e of
  ValueExp v -> not (negativeIsh v)
  LeafExp name _ -> name `nameIn` non_negatives
  _ -> False

-- | Is e1 symbolically less than or equal to e2?
--
-- @e2 - e1@ is simplified to a sum of products; if it is zero the answer is
-- 'True'. Otherwise each @(i, bound)@ pair yields two removal rules for
-- 'removeLessThans': one using the variable @i@, and one using the constant
-- @0@. Presumably each pair records that @i < bound@. These rules are applied
-- until a fixed point, and the remainder must be shown non-negative by
-- 'nonNegativeish'.
lessThanOrEqualish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool
lessThanOrEqualish less_thans0 non_negatives e1 e2
  | null diff = True
  | otherwise =
      nonNegativeish non_negatives $
        fixPoint (`removeLessThans` less_thans) diff
  where
    diff = simplify0 $ untyped $ e2 - e1
    less_thans = concatMap expand less_thans0
    expand (i, bound) =
      [(Var i, bound), (Constant $ IntValue $ Int64Value 0, bound)]

-- | Is e1 symbolically strictly less than e2? Decided by checking
-- @e1 + 1 <= e2@ with 'lessThanOrEqualish'.
lessThanish :: [(VName, PrimExp VName)] -> Names -> TPrimExp Int64 VName -> TPrimExp Int64 VName -> Bool
lessThanish less_thans non_negatives e1 e2 =
  lessThanOrEqualish less_thans non_negatives (e1 + 1) e2

-- | Apply each @(i, bound)@ pair in turn, left to right.
--
-- For each pair, the simplified sum of @-i@ and @bound@ is computed. If every
-- one of its products occurs in the current sum of products, one occurrence
-- of each is removed; otherwise the sum is left unchanged.
removeLessThans :: SofP -> [(SubExp, PrimExp VName)] -> SofP
removeLessThans = foldl step
  where
    step sofp (i, bound)
      | (to_remove `intersect` sofp) == to_remove = sofp \\ to_remove
      | otherwise = sofp
      where
        to_remove =
          simplifySofP $
            Prod True [primExpFromSubExp (IntType Int64) i] : simplify0 bound

-- | Order sums of products by a rough measure of complexity.
--
-- Fewer products is simpler. With equal counts, the products are compared
-- pairwise, moving on to the next pair only on a tie:
--
-- * two purely constant products compare by value;
-- * a purely constant product is simpler than one with atoms;
-- * otherwise the product with fewer non-constant atoms is simpler.
compareComplexity :: SofP -> SofP -> Ordering
compareComplexity xs0 ys0 =
  compare (length xs0) (length ys0) <> go xs0 ys0
  where
    go :: SofP -> SofP -> Ordering
    go [] [] = EQ
    go [] _ = LT
    go _ [] = GT
    go (px : xs) (py : ys) =
      case (prodToScale px, prodToScale py) of
        ((ix, []), (iy, [])) -> compare ix iy <> go xs ys
        ((_, []), _) -> LT
        (_, (_, [])) -> GT
        ((_, x), (_, y)) -> compare (length x) (length y) <> go xs ys