-- | A rearrangement is a generalisation of transposition, where the
-- dimensions are arbitrarily permuted.
module Futhark.IR.Prop.Rearrange
  ( rearrangeShape,
    rearrangeInverse,
    rearrangeReach,
    rearrangeCompose,
    isPermutationOf,
    transposeIndex,
    isMapTranspose,
  )
where

import Data.List (sortOn, tails)
import Futhark.Util

-- | Calculate the given permutation of the list.  It is an error if
-- the permutation goes out of bounds.
--
-- Element @i@ of the result is @l !! (perm !! i)@, so for example
-- @rearrangeShape [2, 0, 1] "abc" == "cab"@.  Only the bounds of the
-- indices are checked; repeated indices are not rejected.
rearrangeShape :: [Int] -> [a] -> [a]
rearrangeShape perm l = map pick perm
  where
    pick i
      | 0 <= i, i < n = l !! i
      | otherwise =
          error $ show perm ++ " is not a valid permutation for input."
    n = length l

-- | Produce the inverse permutation: if @perm !! i == j@, then
-- @rearrangeInverse perm !! j == i@.
rearrangeInverse :: [Int] -> [Int]
rearrangeInverse perm = map snd $ sortOn fst $ zip perm [0 ..]

-- | Return the first dimension not affected by the permutation.  For
-- example, the permutation @[1,0,2]@ would return @2@.
--
-- More precisely, this is the smallest @k@ such that every position
-- @i >= k@ is a fixed point (@perm !! i == i@).  Fixed points that are
-- followed by a moved dimension do not count, so @[2,1,0]@ gives @3@.
-- The identity permutation gives @0@.
rearrangeReach :: [Int] -> Int
rearrangeReach perm =
  -- Count the suffixes that still contain a moved dimension.  The final
  -- (empty) suffix trivially consists only of fixed points, so the
  -- count never exceeds the length of the permutation.
  length $ takeWhile (not . all (uncurry (==))) $ tails $ zip perm [0 ..]

-- | Compose two permutations, with the second given permutation being
-- applied first.  That is,
-- @rearrangeShape (rearrangeCompose p q) xs == rearrangeShape p (rearrangeShape q xs)@.
rearrangeCompose :: [Int] -> [Int] -> [Int]
rearrangeCompose = rearrangeShape

-- | Check whether the first list is a permutation of the second, and
-- if so, return the permutation.  This will also find identity
-- permutations (i.e. the lists are the same).  The implementation is
-- naive and slow.
--
-- The returned permutation @perm@ maps each position of the first list
-- to the index of a matching element in the second list, so that
-- @rearrangeShape perm l2 == l1@.
isPermutationOf :: (Eq a) => [a] -> [a] -> Maybe [Int]
isPermutationOf l1 l2 =
  case mapAccumLM (pick 0) (map Just l2) l1 of
    -- Every element of l2 must have been consumed by some element of l1.
    Just (l2', perm)
      | all (== Nothing) l2' -> Just perm
    _ -> Nothing
  where
    -- Find the first not-yet-consumed slot holding y, mark it consumed
    -- (replace it with Nothing), and return its index.
    pick :: (Eq a) => Int -> [Maybe a] -> a -> Maybe ([Maybe a], Int)
    pick _ [] _ = Nothing
    pick i (x : xs) y
      | Just y == x = Just (Nothing : xs, i)
      | otherwise = do
          (xs', v) <- pick (i + 1) xs y
          pure (x : xs', v)

-- | If @l@ is an index into the array @a@, then @transposeIndex k n
-- l@ is an index to the same element in the array @transposeArray k n
-- a@.
--
-- Concretely, the element at position @k@ is moved to position
-- @k + n@.  A target at or past the end wraps around modulo the length
-- of @l@, and a target before the start is clamped to position 0.  An
-- empty index is returned unchanged.
transposeIndex :: Int -> Int -> [a] -> [a]
transposeIndex k n l
  -- Without this guard the wrap-around case below would compute
  -- `mod` 0 and raise a divide-by-zero error.
  | null l = l
  | k + n >= len =
      let n' = ((k + n) `mod` len) - k
       in transposeIndex k n' l
  | n < 0,
    (pre, needle : end) <- splitAt k l,
    (beg, mid) <- splitAt (length pre + n) pre =
      beg ++ [needle] ++ mid ++ end
  | (beg, needle : post) <- splitAt k l,
    (mid, end) <- splitAt n post =
      beg ++ mid ++ [needle] ++ end
  | otherwise = l
  where
    len = length l

-- | If @perm@ is conceptually a map of a transposition,
-- @isMapTranspose perm@ returns the number of dimensions being mapped,
-- followed by the number of dimensions in each of the two groups being
-- transposed.  For example, we can consider the permutation
-- @[0,1,4,5,2,3]@ as a map of a transpose, by considering dimensions
-- @[0,1]@, @[4,5]@, and @[2,3]@ as single dimensions each, which gives
-- @Just (2, 2, 2)@.
--
-- A permutation with nothing to transpose (such as the identity) gives
-- 'Nothing'.
--
-- If the input is not a valid permutation, then the result is
-- undefined.
isMapTranspose :: [Int] -> Maybe (Int, Int, Int)
isMapTranspose perm
  | posttrans == [numMapped .. numMapped + length posttrans - 1],
    not $ null pretrans,
    not $ null posttrans =
      Just (numMapped, length pretrans, length posttrans)
  | otherwise =
      Nothing
  where
    -- Leading dimensions that are left in place.
    (mapped, notmapped) = findIncreasingFrom 0 perm
    numMapped = length mapped
    -- The remainder must be an increasing-by-one run followed by the
    -- dimensions that come directly after the mapped ones.
    (pretrans, posttrans) = findTransposed notmapped

    -- Longest prefix of the form [x, x+1, ...], and what follows it.
    findIncreasingFrom :: Int -> [Int] -> ([Int], [Int])
    findIncreasingFrom x (i : is)
      | i == x =
          let (js, ps) = findIncreasingFrom (x + 1) is
           in (i : js, ps)
    findIncreasingFrom _ is =
      ([], is)

    -- Split off the increasing-by-one run that starts at the head.
    findTransposed :: [Int] -> ([Int], [Int])
    findTransposed [] =
      ([], [])
    findTransposed (i : is) =
      findIncreasingFrom i (i : is)