{-# LANGUAGE OverloadedStrings #-}

module Futhark.IR.Mem.Interval
  ( Interval (..),
    distributeOffset,
    expandOffset,
    intervalOverlap,
    selfOverlap,
    primBool,
    intervalPairs,
    justLeafExp,
  )
where

import Data.Function (on)
import Data.List (maximumBy, minimumBy, (\\))
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop
import Futhark.IR.Syntax hiding (Result)
import Futhark.Util

-- | One dimension of an access: the positions @lowerBound@ through
-- @lowerBound + numElements - 1@, each scaled by @stride@.
data Interval = Interval
  { -- | First position covered, measured in units of 'stride'.
    lowerBound :: TPrimExp Int64 VName,
    -- | Number of consecutive positions covered.
    numElements :: TPrimExp Int64 VName,
    -- | Distance between consecutive positions.
    stride :: TPrimExp Int64 VName
  }
  deriving (Show, Eq)

-- | The free variables of an interval are those of its lower bound, element
-- count and stride.
instance FreeIn Interval where
  freeIn' (Interval lb ne st) = freeIn' lb <> freeIn' ne <> freeIn' st

-- | Fold a symbolic offset (a sum of products) into the lower bounds of a
-- list of intervals.
--
-- A term of the offset that is divisible by an interval's stride is removed
-- from the offset, and the quotient is added to that interval's
-- 'lowerBound'. Distribution then continues with the remaining terms over the
-- same interval list. When only a single interval with stride 1 remains, all
-- that is left of the offset is added to its lower bound. Fails (via 'fail')
-- if some part of the offset cannot be placed in any interval.
distributeOffset :: (MonadFail m) => AlgSimplify.SofP -> [Interval] -> m [Interval]
distributeOffset [] interval = pure interval
distributeOffset offset [] =
  fail $ "Cannot distribute offset " <> show offset <> " across empty interval"
distributeOffset offset [Interval lb ne 1] =
  pure [Interval (lb + TPrimExp (AlgSimplify.sumToExp offset)) ne 1]
distributeOffset offset (Interval lb ne st0 : is)
  -- First treat the stride as a single opaque atom; failing that, try its
  -- simplified form, but only when it simplifies to exactly one product.
  | Just split <- divideBy (AlgSimplify.Prod False [untyped st0]) = absorb split
  | [st] <- AlgSimplify.simplify0 $ untyped st0,
    Just split <- divideBy st =
      absorb split
  | otherwise = do
      rest <- distributeOffset offset is
      pure $ Interval lb ne st0 : rest
  where
    -- Use 'focusMaybe' to split off an offset term divisible by the given
    -- stride product, along with the terms before and after it.
    divideBy st = focusMaybe (`AlgSimplify.maybeDivide` st) offset
    -- Move the quotient into this interval's lower bound and keep
    -- distributing the remaining terms over the updated interval list.
    absorb (before, quotient, after) =
      distributeOffset (before <> after) $
        Interval (lb + TPrimExp (AlgSimplify.sumToExp [quotient])) ne st0 : is

-- | Split off the term of a sum of products that has the most atoms. Return
-- it together with the remaining terms (the input with that term removed
-- once).
--
-- Partial: the input must be non-empty ('maximumBy' is undefined on an empty
-- list). The caller, 'expandOffset', handles the empty offset beforehand.
findMostComplexTerm :: AlgSimplify.SofP -> (AlgSimplify.Prod, AlgSimplify.SofP)
findMostComplexTerm prods =
  let max_prod = maximumBy (compare `on` (length . AlgSimplify.atoms)) prods
   in (max_prod, prods \\ [max_prod])

-- | Given the atoms of an offset term, pick the interval stride that best
-- matches it.
--
-- Each stride is simplified to a sum of products. Its best product is the
-- one that leaves the fewest atoms of @offset_term@ uncovered (by list
-- difference). The stride whose best product leaves the fewest atoms wins.
-- Returns that stride together with the atoms of @offset_term@ not covered by
-- its best product.
--
-- Partial: requires a non-empty list of intervals, and every stride must
-- simplify to at least one product ('minimumBy' is undefined on empty lists).
findClosestStride :: [PrimExp VName] -> [Interval] -> (PrimExp VName, [PrimExp VName])
findClosestStride offset_term is =
  let closest = minimumBy (compare `on` score) strides
   in (closest, offset_term \\ AlgSimplify.atoms (bestProd closest))
  where
    strides :: [PrimExp VName]
    strides = map (untyped . stride) is

    -- Number of atoms of the offset term not accounted for by a product.
    termDifferenceLength :: AlgSimplify.Prod -> Int
    termDifferenceLength s = length (offset_term \\ AlgSimplify.atoms s)

    -- The product of a simplified stride that best covers the offset term.
    bestProd :: PrimExp VName -> AlgSimplify.Prod
    bestProd = minimumBy (compare `on` termDifferenceLength) . AlgSimplify.simplify0

    score :: PrimExp VName -> Int
    score = termDifferenceLength . bestProd

-- | Try to rewrite an offset so that its most complex term is expressed via
-- the closest matching interval stride, which can help 'distributeOffset'.
--
-- The most complex term @t@ is replaced by @target + (t - target)@. Here
-- @target@ is the chosen stride multiplied by the atoms of @t@ it does not
-- cover (with @t@'s sign), and the difference is re-simplified as a sum of
-- products. The result is therefore algebraically equal to the input offset.
--
-- Returns 'Nothing' when the offset or the interval list is empty.
expandOffset :: AlgSimplify.SofP -> [Interval] -> Maybe AlgSimplify.SofP
expandOffset [] _ = Nothing
-- Without intervals there is no stride to expand by. 'findClosestStride' would
-- otherwise call 'minimumBy' on an empty list and the result would crash when
-- forced.
expandOffset _ [] = Nothing
expandOffset offset i1
  | (AlgSimplify.Prod b term_to_add, offset_rest) <- findMostComplexTerm offset, -- Find gnb
    (closest_stride, first_term_divisor) <- findClosestStride term_to_add i1, -- find (nb-b, g)
    target <- [AlgSimplify.Prod b $ closest_stride : first_term_divisor], -- g(nb-b)
    diff <-
      AlgSimplify.sumOfProducts $
        AlgSimplify.sumToExp $
          AlgSimplify.Prod b term_to_add : map AlgSimplify.negate target, -- gnb - gnb + gb = gnb - g(nb-b)
    replacement <- target <> diff -- gnb = g(nb-b) + gnb - gnb + gb
    =
      Just (replacement <> offset_rest)

-- | Conservatively decide whether two intervals may overlap.
--
-- Returns 'False' only when the strides are equal (by '==') and one interval
-- provably ends before the other begins, according to
-- 'AlgSimplify.lessThanish' with the given less-than and non-negativity
-- facts. In every other case, including differing strides, returns 'True'.
intervalOverlap :: [(VName, PrimExp VName)] -> Names -> Interval -> Interval -> Bool
intervalOverlap less_thans non_negatives (Interval lb1 ne1 st1) (Interval lb2 ne2 st2)
  | st1 == st2 = not (endsBefore lb1 ne1 lb2 || endsBefore lb2 ne2 lb1)
  | otherwise = True
  where
    lessThan = AlgSimplify.lessThanish less_thans non_negatives
    -- The interval starting at @lb@ with @ne@ elements lies entirely below
    -- position @other@.
    endsBefore lb ne other = lessThan lb other && lessThan (lb + ne - 1) other

-- | Evaluate a boolean expression to a constant, if possible.
--
-- Variable lookups always fail (the lookup function is @const Nothing@), so
-- 'Just' can only come from an expression that evaluates without consulting
-- any variable. Any other outcome yields 'Nothing'.
primBool :: TPrimExp Bool VName -> Maybe Bool
primBool p =
  case evalPrimExp (const Nothing) (untyped p) of
    Just (BoolValue b) -> Just b
    _ -> Nothing

-- | Pair up the intervals of two lists in order, matching them by stride.
--
-- Heads with equal strides are paired directly. When the strides differ,
-- two alternatives are explored:
--
-- * pair the first list's head with a one-element placeholder interval;
-- * pair a placeholder with the second list's head.
--
-- The alternative giving the shorter list of pairs wins, and ties favour the
-- first. Intervals left over on either side are paired with placeholders.
-- A placeholder has the same lower bound and stride as its partner and a
-- 'numElements' of 1.
--
-- NOTE(review): both alternatives are computed recursively whenever strides
-- disagree, so the worst case is exponential in the list lengths.
intervalPairs :: [Interval] -> [Interval] -> [(Interval, Interval)]
intervalPairs = go []
  where
    go :: [(Interval, Interval)] -> [Interval] -> [Interval] -> [(Interval, Interval)]
    go acc [] [] = reverse acc
    go acc (i : is) [] = go ((i, unit i) : acc) is []
    go acc [] (i : is) = go ((unit i, i) : acc) [] is
    go acc (i1 : is1) (i2 : is2)
      | stride i1 == stride i2 = go ((i1, i2) : acc) is1 is2
      | otherwise =
          let res1 = go ((i1, unit i1) : acc) is1 (i2 : is2)
              res2 = go ((unit i2, i2) : acc) (i1 : is1) is2
           in if length res1 <= length res2 then res1 else res2

    -- A one-element interval at the same lower bound and stride.
    unit :: Interval -> Interval
    unit (Interval lb _ st) = Interval lb 1 st

-- | Check whether a list of intervals may be self-overlapping. This is the
-- case when, for some dimension @d@, the stride of @d@ cannot be shown to be
-- larger than the aggregate span of the dimensions inside it.
--
-- The list is scanned in reverse (presumably innermost dimension first).
-- The accumulated span starts at 0, and each visited interval adds
-- @(lowerBound + numElements - 1) * stride@. Returns the first interval whose
-- stride is not provably ('AlgSimplify.lessThanish') greater than the span
-- accumulated before it. Returns 'Nothing' if no such interval exists.
--
-- Special cases:
--
-- * A single interval never self-overlaps.
-- * If any non-negativity fact is not a plain variable, the check gives up
--   and conservatively returns the head of the list.
--
-- The @scope@ and @asserts@ arguments are currently unused.
selfOverlap :: scope -> asserts -> [(VName, PrimExp VName)] -> [PrimExp VName] -> [Interval] -> Maybe Interval
selfOverlap _ _ _ _ [_] = Nothing
selfOverlap _ _ less_thans non_negatives' is
  | Just non_negatives <- namesFromList <$> mapM justLeafExp non_negatives' =
      -- TODO: Do we need to do something clever using some ranges of known values?
      let lessThan = AlgSimplify.lessThanish less_thans non_negatives
          -- Walk outwards, accumulating the span covered so far; stop at the
          -- first interval whose stride is not provably larger than it.
          check _ [] = Nothing
          check acc (x : xs)
            | lessThan (AlgSimplify.simplify' acc) (AlgSimplify.simplify' (stride x)) =
                check (acc + (lowerBound x + numElements x - 1) * stride x) xs
            | otherwise = Just x
       in check 0 (reverse is)
selfOverlap _ _ _ _ (x : _) = Just x
selfOverlap _ _ _ _ [] = Nothing

-- | Extract the variable of an expression that is just a variable
-- ('LeafExp'); 'Nothing' for any other expression.
justLeafExp :: PrimExp VName -> Maybe VName
justLeafExp (LeafExp v _) = Just v
justLeafExp _ = Nothing