{-# LANGUAGE OverloadedStrings #-}

module Futhark.IR.Mem.Interval
  ( Interval (..),
    distributeOffset,
    expandOffset,
    intervalOverlap,
    selfOverlap,
    primBool,
    intervalPairs,
    justLeafExp,
  )
where

import Data.Function (on)
import Data.List (maximumBy, minimumBy, (\\))
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Prop
import Futhark.IR.Syntax hiding (Result)
import Futhark.Util

-- | A strided run of indices. The covered indices are presumably
-- @lowerBound, lowerBound + 1, ..., lowerBound + numElements - 1@, each scaled
-- by 'stride' (this is how 'intervalOverlap' and 'selfOverlap' interpret it).
data Interval = Interval
  { -- | The first index of the run.
    lowerBound :: TPrimExp Int64 VName,
    -- | How many indices the run covers.
    numElements :: TPrimExp Int64 VName,
    -- | The multiplier applied to each index.
    stride :: TPrimExp Int64 VName
  }
  deriving (Eq, Show)

-- | The free variables of an interval are the union of those of its lower
-- bound, element count and stride.
instance FreeIn Interval where
  freeIn' i =
    freeIn' (lowerBound i)
      <> freeIn' (numElements i)
      <> freeIn' (stride i)

-- | Move the terms of an offset (a sum of products) into the lower bounds of
-- a list of intervals.
--
-- For each interval, a term of the offset that is divisible by the interval's
-- stride is removed from the offset, and the quotient is added to that
-- interval's lower bound. Divisibility is tried first against the stride as
-- written, then against its simplified form when that is a single product.
-- A final interval with stride 1 absorbs whatever remains of the offset.
-- Fails if offset terms remain once the intervals are exhausted.
distributeOffset :: MonadFail m => AlgSimplify.SofP -> [Interval] -> m [Interval]
distributeOffset [] intervals = pure intervals
distributeOffset offset [] =
  fail $ "Cannot distribute offset " <> show offset <> " across empty interval"
distributeOffset offset [Interval lb ne 1] =
  pure [Interval (lb + TPrimExp (AlgSimplify.sumToExp offset)) ne 1]
distributeOffset offset (Interval lb ne st0 : is) =
  case divideBy (AlgSimplify.Prod False [untyped st0]) of
    Just parts -> absorb parts
    Nothing ->
      case AlgSimplify.simplify0 (untyped st0) of
        [st] | Just parts <- divideBy st -> absorb parts
        _ -> (Interval lb ne st0 :) <$> distributeOffset offset is
  where
    -- Locate the first offset term that the given product divides.
    divideBy st = focusMaybe (`AlgSimplify.maybeDivide` st) offset

    -- Fold the quotient into this interval's lower bound and keep
    -- distributing the remaining terms over the same interval list.
    absorb (before, quotient, after) =
      distributeOffset (before <> after) $
        Interval (lb + TPrimExp (AlgSimplify.sumToExp [quotient])) ne st0 : is

-- | Split a sum of products into the product with the most atoms and the
-- remaining products (only one occurrence of the chosen product is removed).
--
-- Partial: the input must be non-empty, since 'maximumBy' fails on '[]'.
findMostComplexTerm :: AlgSimplify.SofP -> (AlgSimplify.Prod, AlgSimplify.SofP)
findMostComplexTerm prods = (most, prods \\ [most])
  where
    most = maximumBy (compare `on` (length . AlgSimplify.atoms)) prods

-- | Pick the interval stride that best matches an offset term.
--
-- Each stride is simplified into a sum of products. Its "closest" product is
-- the one leaving the fewest atoms of @offset_term@ uncovered (measured by
-- list difference). The stride whose closest product leaves the fewest atoms
-- uncovered is chosen.
--
-- Returns that stride, together with the atoms of @offset_term@ that its
-- closest product does not cover.
--
-- Strides whose simplification is an empty sum have no product to compare,
-- so they are skipped. These are presumably strides that simplify to zero.
-- If every stride is like that, or the interval list is empty, this function
-- is still partial.
findClosestStride :: [PrimExp VName] -> [Interval] -> (PrimExp VName, [PrimExp VName])
findClosestStride offset_term is =
  let p = minimumBy (compare `on` (termDifferenceLength . closestProd)) candidates
   in (p, offset_term \\ AlgSimplify.atoms (closestProd p))
  where
    strides :: [PrimExp VName]
    strides = map (untyped . stride) is

    -- Previously a single stride with an empty simplification made the inner
    -- 'minimumBy' crash. Fall back to all strides when none qualify, so
    -- behaviour is unchanged in that case.
    candidates :: [PrimExp VName]
    candidates =
      case filter (not . null . AlgSimplify.simplify0) strides of
        [] -> strides
        cs -> cs

    -- Number of offset-term atoms not covered by the given atoms.
    uncovered :: [PrimExp VName] -> Int
    uncovered xs = length (offset_term \\ xs)

    -- The product of a simplified stride that covers the most of the term.
    closestProd :: PrimExp VName -> AlgSimplify.Prod
    closestProd =
      minimumBy (compare `on` (uncovered . AlgSimplify.atoms)) . AlgSimplify.simplify0

    termDifferenceLength :: AlgSimplify.Prod -> Int
    termDifferenceLength (AlgSimplify.Prod _ xs) = uncovered xs

-- | Rewrite the most complex term of an offset in terms of an interval stride.
--
-- Notation: @gnb@ is the offset's most complex term (see
-- 'findMostComplexTerm'). @g(nb-b)@ is the target product, built from the
-- closest stride and the leftover atoms reported by 'findClosestStride'.
--
-- The term is replaced by @g(nb-b)@ plus the re-simplified difference
-- @gnb - g(nb-b)@. The remaining offset terms are appended unchanged.
--
-- Returns 'Nothing' for an empty offset, and for an empty interval list,
-- which has no stride to rewrite against. Otherwise the result is always
-- 'Just'.
expandOffset :: AlgSimplify.SofP -> [Interval] -> Maybe AlgSimplify.SofP
expandOffset [] _ = Nothing
-- Without intervals, 'findClosestStride' would apply 'minimumBy' to an empty
-- list. Previously this produced a 'Just' wrapping a crashing thunk.
expandOffset _ [] = Nothing
expandOffset offset i1
  | (AlgSimplify.Prod b term_to_add, offset_rest) <- findMostComplexTerm offset, -- Find gnb
    (closest_stride, first_term_divisor) <- findClosestStride term_to_add i1, -- find (nb-b, g)
    target <- [AlgSimplify.Prod b $ closest_stride : first_term_divisor], -- g(nb-b)
    diff <-
      AlgSimplify.sumOfProducts $
        AlgSimplify.sumToExp $
          AlgSimplify.Prod b term_to_add : map AlgSimplify.negate target, -- gnb - gnb + gb = gnb - g(nb-b)
    replacement <- target <> diff -- gnb = g(nb-b) + gnb - gnb + gb
    =
      Just (replacement <> offset_rest)

-- | Conservatively decide whether two intervals may overlap.
--
-- The answer is 'False' (definitely disjoint) only when both intervals have
-- syntactically equal strides, and one of them can be shown with
-- 'AlgSimplify.lessThanish' (using the given facts) to both start and end
-- before the other one starts. In every other case the answer is 'True'.
intervalOverlap :: [(VName, PrimExp VName)] -> Names -> Interval -> Interval -> Bool
intervalOverlap less_thans non_negatives (Interval lb1 ne1 st1) (Interval lb2 ne2 st2) =
  not provablyDisjoint
  where
    lt = AlgSimplify.lessThanish less_thans non_negatives

    -- The run starting at @lb@ with @ne@ elements lies wholly below @other@.
    entirelyBelow lb ne other = lt lb other && lt (lb + ne - 1) other

    provablyDisjoint =
      st1 == st2 && (entirelyBelow lb1 ne1 lb2 || entirelyBelow lb2 ne2 lb1)

-- | Evaluate a boolean expression without any variable bindings.
--
-- The variable lookup always fails, so 'Nothing' is returned if evaluation
-- needs a variable, fails for another reason, or produces a non-boolean
-- value.
primBool :: TPrimExp Bool VName -> Maybe Bool
primBool p =
  case evalPrimExp (const Nothing) (untyped p) of
    Just (BoolValue b) -> Just b
    _ -> Nothing

-- | Align two lists of intervals by stride.
--
-- Both lists are walked front to back. Heads with equal strides are paired
-- with each other. When the heads' strides differ, either head may instead be
-- paired with a one-element copy of itself (same lower bound and stride,
-- 'numElements' 1). Both choices are explored and the shorter result is kept;
-- on a tie, the choice that consumes the first list's head wins. Intervals
-- left over on either side are paired with their one-element copies.
--
-- NOTE(review): both alternatives are explored recursively, so the worst case
-- is exponential in the number of mismatched strides. This is presumably
-- acceptable because the lists are short; confirm with callers.
intervalPairs :: [Interval] -> [Interval] -> [(Interval, Interval)]
intervalPairs = align
  where
    align :: [Interval] -> [Interval] -> [(Interval, Interval)]
    align [] [] = []
    align (x : xs) [] = (x, single x) : align xs []
    align [] (y : ys) = (single y, y) : align [] ys
    align xs0@(x : xs) ys0@(y : ys)
      | stride x == stride y = (x, y) : align xs ys
      | length padX <= length padY = padX
      | otherwise = padY
      where
        padX = (x, single x) : align xs ys0
        padY = (single y, y) : align xs0 ys

    -- A one-element interval with the same lower bound and stride.
    single :: Interval -> Interval
    single (Interval lb _ st) = Interval lb 1 st

-- | Check whether a list of intervals may be self-overlapping.
--
-- The intervals are visited from last to first; the last one is presumably
-- the innermost dimension. Each interval's stride must be shown, via
-- 'AlgSimplify.lessThanish', to exceed the accumulated span
-- @(lowerBound + numElements - 1) * stride@ of the intervals visited before
-- it.
--
-- Returns the first interval for which this cannot be shown. If it holds for
-- every interval, returns 'Nothing' (no self-overlap).
--
-- Lists with fewer than two intervals yield 'Nothing'. If any of the
-- non-negativity expressions is not a plain variable, nothing is checked and
-- the first interval is returned conservatively. The first two arguments are
-- unused.
selfOverlap :: scope -> asserts -> [(VName, PrimExp VName)] -> [PrimExp VName] -> [Interval] -> Maybe Interval
selfOverlap _ _ _ _ [_] = Nothing
selfOverlap _ _ less_thans non_negatives' is
  | Just non_negatives <- namesFromList <$> mapM justLeafExp non_negatives' =
      -- TODO: Do we need to do something clever using some ranges of known values?
      let lt = AlgSimplify.lessThanish less_thans non_negatives
          -- @acc@ is the summed span of the intervals already accepted.
          go _ [] = Nothing
          go acc (x : xs)
            | lt (AlgSimplify.simplify' acc) (AlgSimplify.simplify' $ stride x) =
                go (acc + (lowerBound x + numElements x - 1) * stride x) xs
            | otherwise = Just x
       in go 0 $ reverse is
selfOverlap _ _ _ _ (x : _) = Just x
selfOverlap _ _ _ _ [] = Nothing

-- | Extract the variable name when the expression is a bare variable leaf,
-- and 'Nothing' for any other expression.
justLeafExp :: PrimExp VName -> Maybe VName
justLeafExp e =
  case e of
    LeafExp v _ -> Just v
    _ -> Nothing