{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

-- | Some simplification rules for t'BasicOp'.
module Futhark.Optimise.Simplify.Rules.BasicOp
  ( basicOpRules,
  )
where

import Control.Monad
import Data.List (find, foldl', isSuffixOf, sort)
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Simple
import Futhark.Util

-- | Is this subexpression a constant that 'oneIsh' accepts?
-- Variables are never considered to be one.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- | Is this subexpression a constant that 'zeroIsh' accepts?
-- Variables are never considered to be zero.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- | A classification of one operand of a concatenation, used when
-- trying to fuse neighbouring operands.
data ConcatArg
  = -- | The operand is an array literal with these elements.
    ArgArrayLit [SubExp]
  | -- | The operand is a replicate of the given value.  The list holds
    -- outer sizes; when the operand is rebuilt the sizes are summed.
    ArgReplicate [SubExp] SubExp
  | -- | Any other operand, referred to by name.
    ArgVar VName

-- | Classify a concatenation operand by looking up its definition in
-- the symbol table.  Array literals and replicates are recognised
-- (and returned together with the certificates of their defining
-- statement); everything else becomes an 'ArgVar' with no
-- certificates.
toConcatArg :: ST.SymbolTable rep -> VName -> (ConcatArg, Certs)
toConcatArg vtable v =
  case ST.lookupBasicOp v vtable of
    Just (ArrayLit ses _, cs) ->
      (ArgArrayLit ses, cs)
    Just (Replicate shape se, cs) ->
      -- Only the outermost dimension of the replicate is recorded.
      (ArgReplicate [shapeSize 0 shape] se, cs)
    _ ->
      (ArgVar v, mempty)

-- | Turn a (possibly fused) 'ConcatArg' back into an array variable,
-- emitting whatever statement is needed to construct it.
--
-- The 'Type' argument is the type of one of the arrays taking part in
-- the concatenation: its row type is used for rebuilt array literals,
-- and its shape (with the outer dimension replaced) for rebuilt
-- replicates.  The certificates paired with the argument are attached
-- to the statements that are produced; for 'ArgVar' nothing is
-- produced and the certificates are ignored.
fromConcatArg ::
  MonadBuilder m =>
  Type ->
  (ConcatArg, Certs) ->
  m VName
fromConcatArg t (ArgArrayLit ses, cs) =
  certifying cs $ letExp "concat_lit" $ BasicOp $ ArrayLit ses $ rowType t
fromConcatArg elem_type (ArgReplicate ws se, cs) = do
  let elem_shape = arrayShape elem_type
  certifying cs $ do
    -- The outer size of the fused replicate is the sum of the outer
    -- sizes of its constituents.
    w <- letSubExp "concat_rep_w" =<< toExp (sum $ map pe64 ws)
    -- NOTE(review): the replicate shape is the full shape of
    -- 'elem_type' with dimension 0 replaced.  That looks right when
    -- 'se' is a scalar; confirm the behaviour for array-valued 'se'.
    letExp "concat_rep" $ BasicOp $ Replicate (setDim 0 elem_shape w) se
fromConcatArg _ (ArgVar v, _) =
  pure v

-- | Folding step for fusing concatenation operands.  The accumulator
-- holds the operands processed so far in /reverse/ order (most recent
-- first); the second argument is the next operand.
--
--   * Empty array literals and replicates of constant size zero are
--     dropped.
--
--   * A replicate of constant size one is treated as a one-element
--     array literal.
--
--   * An array literal following an array literal is appended to it.
--
--   * A replicate following a replicate of the same value is merged
--     into it by collecting the sizes.
--
--   * Anything else is pushed onto the accumulator unchanged.
--
-- Certificates of merged operands are combined.
fuseConcatArg ::
  [(ConcatArg, Certs)] ->
  (ConcatArg, Certs) ->
  [(ConcatArg, Certs)]
fuseConcatArg xs (ArgArrayLit [], _) =
  xs
-- When neither guard holds, matching falls through to the clauses
-- below.
fuseConcatArg xs (ArgReplicate [w] se, cs)
  | isCt0 w =
    xs
  | isCt1 w =
    fuseConcatArg xs (ArgArrayLit [se], cs)
fuseConcatArg ((ArgArrayLit x_ses, x_cs) : xs) (ArgArrayLit y_ses, y_cs) =
  (ArgArrayLit (x_ses ++ y_ses), x_cs <> y_cs) : xs
fuseConcatArg ((ArgReplicate x_ws x_se, x_cs) : xs) (ArgReplicate y_ws y_se, y_cs)
  | x_se == y_se =
    (ArgReplicate (x_ws ++ y_ws) x_se, x_cs <> y_cs) : xs
fuseConcatArg xs y =
  y : xs

-- | Bottom-up simplification rules for 'Concat'.  The clauses are
-- tried in order; the last one gives up with 'Skip'.
simplifyConcat :: BuilderOps rep => BottomUpRuleBasicOp rep
-- concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
--
-- Fires only when every operand is defined as a 'Rearrange' by the
-- same permutation, namely the one that moves dimension @i@ to the
-- front.  The certificates of those 'Rearrange' statements are
-- attached to the new concatenation.  The statement's own 'StmAux' is
-- not consulted by this clause.
simplifyConcat (vtable, _) pat _ (Concat i x xs new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    let perm = [i] ++ [0 .. i - 1] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- unzip <$> mapM (transposedBy perm) xs = Simplify $ do
    concat_rearrange <-
      certifying (x_cs <> mconcat xs_cs) $
        letExp "concat_rearrange" $ BasicOp $ Concat 0 x' xs' new_d
    letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- If @v@ is defined as @Rearrange perm1 v'@, return @v'@ and the
    -- certificates of that definition.
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- Removing a concatenation that involves only a single array.  This
-- may be produced as a result of other simplification rules.
simplifyConcat _ pat aux (Concat _ x [] _) =
  Simplify $
    -- Still need a copy because Concat produces a fresh array.
    auxing aux $ letBind pat $ BasicOp $ Copy x
-- concat xs (concat ys zs) == concat xs ys zs
--
-- Operands that are themselves concatenations along the same
-- dimension are replaced by their own operands.  The rule fires only
-- if that changes the operand list.
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i x xs new_d)
  | x' /= x || concat xs' /= xs =
    Simplify $
      certifying (cs <> x_cs <> mconcat xs_cs) $
        attributing attrs $
          letBind pat $
            BasicOp $ Concat i x' (zs ++ concat xs') new_d
  where
    -- 'isConcat' always returns a non-empty list, so this pattern
    -- cannot fail.
    (x' : zs, x_cs) = isConcat x
    (xs', xs_cs) = unzip $ map isConcat xs
    -- The operands (and certificates) of @v@ if it is a concatenation
    -- along dimension @i@; otherwise just @v@ itself.
    isConcat v = case ST.lookupBasicOp v vtable of
      Just (Concat j y ys _, v_cs) | j == i -> (y : ys, v_cs)
      _ -> ([v], mempty)

-- Fusing arguments to the concat when possible.  Only done when
-- concatenating along the outer dimension for now.
simplifyConcat (vtable, _) pat aux (Concat 0 x xs outer_w)
  | -- We produce the to-be-concatenated arrays in reverse order, so
    -- reverse them back.
    y : ys <-
      forSingleArray $
        reverse $
          foldl' fuseConcatArg mempty $
            map (toConcatArg vtable) $ x : xs,
    -- Only fire if fusion actually reduced the number of operands.
    length xs /= length ys =
    Simplify $ do
      elem_type <- lookupType x
      y' <- fromConcatArg elem_type y
      ys' <- mapM (fromConcatArg elem_type) ys
      auxing aux $ letBind pat $ BasicOp $ Concat 0 y' ys' outer_w
  where
    -- If we fuse so much that there is only a single input left, then
    -- it must have the right size.
    forSingleArray [(ArgReplicate _ v, cs)] =
      [(ArgReplicate [outer_w] v, cs)]
    forSingleArray ys = ys
simplifyConcat _ _ _ _ = Skip

ruleBasicOp :: BuilderOps rep => TopDownRuleBasicOp rep
ruleBasicOp :: TopDownRuleBasicOp rep
ruleBasicOp TopDown rep
vtable Pat rep
pat StmAux (ExpDec rep)
aux BasicOp
op
  | Just (BasicOp
op', Certs
cs) <- VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs)
forall rep.
VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs)
applySimpleRules VarLookup rep
defOf TypeLookup
seType BasicOp
op =
    RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ Certs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (Certs
cs Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> StmAux (ExpDec rep) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec rep)
aux) (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat rep
Pat (Rep (RuleM rep))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp BasicOp
op'
  where
    defOf :: VarLookup rep
defOf = (VName -> TopDown rep -> Maybe (ExpT rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
`ST.lookupExp` TopDown rep
vtable)
    seType :: TypeLookup
seType (Var VName
v) = VName -> TopDown rep -> Maybe Type
forall rep. ASTRep rep => VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
v TopDown rep
vtable
    seType (Constant PrimValue
v) = Type -> Maybe Type
forall a. a -> Maybe a
Just (Type -> Maybe Type) -> Type -> Maybe Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v
ruleBasicOp TopDown rep
vtable Pat rep
pat StmAux (ExpDec rep)
_ (Update Safety
_ VName
src Slice SubExp
_ (Var VName
v))
  | Just (BasicOp Scratch {}, Certs
_) <- VName -> TopDown rep -> Maybe (ExpT rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v TopDown rep
vtable =
    RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat rep
Pat (Rep (RuleM rep))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
src
-- If we are writing a single-element slice from some array, and the
-- element of that array can be computed as a PrimExp based on the
-- index, let's just write that instead.
ruleBasicOp TopDown rep
vtable Pat rep
pat StmAux (ExpDec rep)
aux (Update Safety
safety VName
src (Slice [DimSlice SubExp
i SubExp
n SubExp
s]) (Var VName
v))
  | SubExp -> Bool
isCt1 SubExp
n,
    SubExp -> Bool
isCt1 SubExp
s,
    Just (ST.Indexed Certs
cs PrimExp VName
e) <- VName -> [SubExp] -> TopDown rep -> Maybe Indexed
forall rep.
ASTRep rep =>
VName -> [SubExp] -> SymbolTable rep -> Maybe Indexed
ST.index VName
v [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0] TopDown rep
vtable =
    RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
      SubExp
e' <- String -> PrimExp VName -> RuleM rep SubExp
forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"update_elem" PrimExp VName
e
      StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ())
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Certs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
        Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat rep
Pat (Rep (RuleM rep))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
src ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i]) SubExp
e'
ruleBasicOp TopDown rep
vtable Pat rep
pat StmAux (ExpDec rep)
_ (Update Safety
_ VName
dest Slice SubExp
destis (Var VName
v))
  | Just (ExpT rep
e, Certs
_) <- VName -> TopDown rep -> Maybe (ExpT rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v TopDown rep
vtable,
    ExpT rep -> Bool
arrayFrom ExpT rep
e =
    RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat rep
Pat (Rep (RuleM rep))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
dest
  where
    arrayFrom :: ExpT rep -> Bool
arrayFrom (BasicOp (Copy VName
copy_v))
      | Just (ExpT rep
e', Certs
_) <- VName -> TopDown rep -> Maybe (ExpT rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
copy_v TopDown rep
vtable =
        ExpT rep -> Bool
arrayFrom ExpT rep
e'
    arrayFrom (BasicOp (Index VName
src Slice SubExp
srcis)) =
      VName
src VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest Bool -> Bool -> Bool
&& Slice SubExp
destis Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp
srcis
    arrayFrom (BasicOp (Replicate Shape
v_shape SubExp
v_se))
      | Just (Replicate Shape
dest_shape SubExp
dest_se, Certs
_) <- VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
dest TopDown rep
vtable,
        SubExp
v_se SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
dest_se,
        Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
v_shape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isSuffixOf` Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape =
        Bool
True
    arrayFrom ExpT rep
_ =
      Bool
False
ruleBasicOp TopDown rep
vtable Pat rep
pat StmAux (ExpDec rep)
_ (Update Safety
_ VName
dest Slice SubExp
is SubExp
se)
  | Just Type
dest_t <- VName -> TopDown rep -> Maybe Type
forall rep. ASTRep rep => VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
dest TopDown rep
vtable,
    Shape -> Slice SubExp -> Bool
isFullSlice (Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
dest_t) Slice SubExp
is = RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$
    case SubExp
se of
      Var VName
v | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([SubExp] -> Bool) -> [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
is -> do
        VName
v_reshaped <-
          String -> Exp (Rep (RuleM rep)) -> RuleM rep VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
v String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_reshaped") (Exp (Rep (RuleM rep)) -> RuleM rep VName)
-> Exp (Rep (RuleM rep)) -> RuleM rep VName
forall a b. (a -> b) -> a -> b
$
            BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ((SubExp -> DimChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew ([SubExp] -> ShapeChange SubExp) -> [SubExp] -> ShapeChange SubExp
forall a b. (a -> b) -> a -> b
$ Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
dest_t) VName
v
        Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat rep
Pat (Rep (RuleM rep))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
v_reshaped
      SubExp
_ -> Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat rep
Pat (Rep (RuleM rep))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Type -> BasicOp
ArrayLit [SubExp
se] (Type -> BasicOp) -> Type -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> Type
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType Type
dest_t
ruleBasicOp TopDown rep
vtable Pat rep
pat (StmAux Certs
cs1 Attrs
attrs ExpDec rep
_) (Update Safety
safety1 VName
dest1 Slice SubExp
is1 (Var VName
v1))
  | Just (Update Safety
safety2 VName
dest2 Slice SubExp
is2 SubExp
se2, Certs
cs2) <- VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
v1 TopDown rep
vtable,
    Just (Copy VName
v3, Certs
cs3) <- VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
dest2 TopDown rep
vtable,
    Just (Index VName
v4 Slice SubExp
is4, Certs
cs4) <- VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
v3 TopDown rep
vtable,
    Slice SubExp
is4 Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp
is1,
    VName
v4 VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest1 =
    RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$
      Certs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (Certs
cs1 Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs2 Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs3 Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> Certs
cs4) (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ do
        Slice SubExp
is5 <- Slice (TPrimExp Int64 VName) -> RuleM rep (Slice SubExp)
forall (m :: * -> *).
MonadBuilder m =>
Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice (Slice (TPrimExp Int64 VName) -> RuleM rep (Slice SubExp))
-> Slice (TPrimExp Int64 VName) -> RuleM rep (Slice SubExp)
forall a b. (a -> b) -> a -> b
$ Slice (TPrimExp Int64 VName)
-> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
is1) (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
is2)
        Attrs -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) a. MonadBuilder m => Attrs -> m a -> m a
attributing Attrs
attrs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (Rep (RuleM rep)) -> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (Rep m) -> Exp (Rep m) -> m ()
letBind Pat rep
Pat (Rep (RuleM rep))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT rep
forall rep. BasicOp -> ExpT rep
BasicOp (BasicOp -> ExpT rep) -> BasicOp -> ExpT rep
forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update (Safety -> Safety -> Safety
forall a. Ord a => a -> a -> a
max Safety
safety1 Safety
safety2) VName
dest1 Slice SubExp
is5 SubExp
se2
-- Equality where one operand is a result of an 'If'.  If @v@ is bound by
-- @if p then ... y else ... z@, and neither @y@ nor @z@ mentions a name
-- bound inside its own branch, then @x == v@ is rewritten to
--
-- @(p && x == y) || (!p && x == z)@
--
-- Both operand orders are tried.  If neither operand qualifies, no guard
-- succeeds and matching falls through to the following equations.
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where
    -- @simplifyWith a x@ yields the rewrite when @a@ is a variable produced
    -- by an 'If' statement found in the symbol table, and 'Nothing'
    -- otherwise.
    simplifyWith (Var v) x
      | Just stm <- ST.lookupStm v vtable,
        If p tbranch fbranch _ <- stmExp stm,
        Just (y, z) <- returns v (stmPat stm) tbranch fbranch,
        -- The branch results must be usable outside the branches.
        not $ boundInBody tbranch `namesIntersect` freeIn y,
        not $ boundInBody fbranch `namesIntersect` freeIn z = Just $ do
        eq_x_y <-
          letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
        eq_x_z <-
          letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
        p_and_eq_x_y <-
          letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
        not_p <-
          letSubExp "not_p" $ BasicOp $ UnOp Not p
        -- The name hint now matches the binding; it used to repeat
        -- "p_and_eq_x_y", which gave two unrelated intermediates the same
        -- base name.
        not_p_and_eq_x_z <-
          letSubExp "not_p_and_eq_x_z" $ BasicOp $ BinOp LogAnd not_p eq_x_z
        letBind pat $
          BasicOp $ BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
    simplifyWith _ _ =
      Nothing

    -- Find the pattern element of the 'If' named @v@ and return the pair of
    -- (true-branch result, false-branch result) at the same position.
    returns v ifpat tbranch fbranch =
      fmap snd . find ((== v) . patElemName . fst) $
        zip (patElems ifpat) $
          zip
            (map resSubExp (bodyResult tbranch))
            (map resSubExp (bodyResult fbranch))
-- Replicating a constant with an empty shape is just the constant itself.
ruleBasicOp _ pat _ (Replicate (Shape []) (Constant k)) =
  Simplify . letBind pat . BasicOp $ SubExp (Constant k)
-- When the pattern binds exactly one accumulator-typed value, the
-- replicate is replaced by its operand, whatever the shape.
ruleBasicOp _ pat _ (Replicate _ operand)
  | [Acc {}] <- patTypes pat =
    Simplify (letBind pat (BasicOp (SubExp operand)))
-- Replicating a variable with an empty shape: a primitive-typed variable
-- is passed through as a plain sub-expression, anything else becomes a
-- 'Copy' of the variable.
ruleBasicOp _ pat _ (Replicate (Shape []) (Var src)) =
  Simplify (lookupType src >>= letBind pat . BasicOp . pick)
  where
    pick ty
      | primType ty = SubExp (Var src)
      | otherwise = Copy src
-- A replicate of a replicate collapses into one replicate whose shape is
-- the outer shape followed by the inner shape.
ruleBasicOp vtable pat _ (Replicate outer (Var v))
  | Just (BasicOp (Replicate inner elemSe), inner_cs) <- ST.lookupExp v vtable =
    Simplify . certifying inner_cs . letBind pat . BasicOp $
      Replicate (outer <> inner) elemSe
-- A non-empty array literal whose elements are all equal is a
-- one-dimensional replicate of that element.
ruleBasicOp _ pat _ (ArrayLit (first : rest) _)
  | all (== first) rest =
    Simplify . letBind pat . BasicOp $ Replicate (Shape [count]) first
  where
    -- Total number of elements: the head plus the tail.
    count = constant (fromIntegral (length rest) + 1 :: Int64)
-- Fully indexing the result of a reshape: index the array that was
-- reshaped instead.  The rule only applies when the slice consists solely
-- of fixed indices and there is one index per dimension of the new shape.
ruleBasicOp vtable pat aux (Index arr slice)
  | Just fixed <- sliceIndices slice,
    Just (BasicOp (Reshape change src), src_cs) <- ST.lookupExp arr vtable,
    length change == length fixed =
    Simplify $ do
      src_slice <- case shapeCoercion change of
        -- For a shape coercion the original slice is reused unchanged.
        Just _ -> pure slice
        -- Otherwise linearise the indices and map them to the old index
        -- space.
        Nothing -> do
          src_dims <- arrayDims <$> lookupType src
          let mapped =
                reshapeIndex
                  (map pe64 src_dims)
                  (map pe64 (newDims change))
                  (map pe64 fixed)
          Slice . map DimFix <$> mapM (toSubExp "new_index") mapped
      certifying src_cs . auxing aux . letBind pat . BasicOp $
        Index src src_slice

-- A copy of an iota is replaced by a fresh iota with the same arguments,
-- so the copy disappears.
ruleBasicOp vtable pat aux (Copy src)
  | Just (Iota len start stride int_ty, src_cs) <- ST.lookupBasicOp src vtable =
    Simplify $
      certifying src_cs $
        auxing aux $
          letBind pat (BasicOp (Iota len start stride int_ty))
-- A rearrange whose permutation is already in sorted order is the
-- identity, so the result is just the input array.
ruleBasicOp _ pat _ (Rearrange perm arr)
  | perm == sort perm =
    Simplify . letBind pat . BasicOp . SubExp $ Var arr
-- A rearrange of a rearrange: compose the two permutations into a single
-- rearrange of the innermost array.
ruleBasicOp vtable pat aux (Rearrange outer_perm v)
  | Just (BasicOp (Rearrange inner_perm src), inner_cs) <- ST.lookupExp v vtable =
    Simplify $
      certifying inner_cs $
        auxing aux $
          letBind pat . BasicOp $
            Rearrange (rearrangeCompose outer_perm inner_perm) src
-- A rearrange of a rotate of a rearrange: rotate the innermost array
-- directly, with the offsets permuted by the inverse of the inner
-- permutation, then apply one composed rearrange on top.
ruleBasicOp vtable pat aux (Rearrange outer_perm v)
  | Just (BasicOp (Rotate rots mid), rot_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rearrange inner_perm src), mid_cs) <- ST.lookupExp mid vtable =
    Simplify $ do
      -- The intermediate rotate is bound without the certificates or aux.
      rotated <-
        letExp "rearrange_rotate" . BasicOp $
          Rotate (rearrangeShape (rearrangeInverse inner_perm) rots) src
      certifying (rot_cs <> mid_cs) . auxing aux . letBind pat . BasicOp $
        Rearrange (rearrangeCompose outer_perm inner_perm) rotated

-- A rearrange of a replicated array where the permutation leaves the
-- replicated (leading) dimensions in place: rearrange the replicated array
-- itself and replicate the result.
ruleBasicOp vtable pat aux (Rearrange perm v1)
  | Just (BasicOp (Replicate rep_shape (Var inner)), v1_cs) <- ST.lookupExp v1 vtable,
    rank <- shapeRank rep_shape,
    (lead, trail) <- splitAt rank perm,
    not (null trail),
    -- The leading part must be the identity 0, 1, ..., k-1.
    lead == take (length lead) [0 ..] =
    Simplify . certifying v1_cs . auxing aux $ do
      -- Shift the trailing permutation so it indexes the dimensions of the
      -- replicated array.
      let inner_perm = map (\d -> d - rank) trail
      rearranged <-
        letSubExp "rearrange_replicate" (BasicOp (Rearrange inner_perm inner))
      letBind pat (BasicOp (Replicate rep_shape rearranged))

-- Rotating by zero along every dimension is the identity, so the result is
-- just the input array.
ruleBasicOp _ pat _ (Rotate rots arr)
  | all isCt0 rots =
    Simplify (letBind pat (BasicOp (SubExp (Var arr))))
-- A rotate of a rearrange of a rotate: rearrange the innermost array
-- first, then apply a single rotate whose offsets are the pairwise sums of
-- the outer offsets and the (permuted) inner offsets.
ruleBasicOp vtable pat aux (Rotate offsets v)
  | Just (BasicOp (Rearrange perm mid), perm_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rotate inner_offsets src), rot_cs) <- ST.lookupExp mid vtable =
    Simplify $ do
      -- NOTE(review): the inner offsets are permuted with the inverse of
      -- perm; confirm this is the intended direction for permutations
      -- that are not their own inverse.
      let inner_offsets' = rearrangeShape (rearrangeInverse perm) inner_offsets
      summed <- forM (zip offsets inner_offsets') $ \(a, b) ->
        letSubExp "summed_offset" (BasicOp (BinOp (Add Int64 OverflowWrap) a b))
      -- Only the intermediate rearrange carries aux; the final rotate
      -- carries the certificates of the two statements that were looked up.
      rearranged <-
        auxing aux (letExp "rotate_rearrange" (BasicOp (Rearrange perm src)))
      certifying (perm_cs <> rot_cs) $
        letBind pat (BasicOp (Rotate summed rearranged))

-- Combining Rotates.
--
-- When the array being rotated is itself bound to a Rotate, replace
-- the pair by a single Rotate of the innermost array whose offsets are
-- the pairwise sums of the two offset lists.  The certificates found
-- on the inner Rotate are attached to the replacement statement.
ruleBasicOp vtable pat aux (Rotate offsets1 v)
  | Just (BasicOp (Rotate offsets2 v2), v_cs) <- ST.lookupExp v vtable = Simplify $ do
    offsets <- zipWithM add offsets1 offsets2
    certifying v_cs $
      auxing aux $
        letBind pat $ BasicOp $ Rotate offsets v2
  where
    -- Emit a wrapping 64-bit addition of two offsets and return its
    -- result.
    add x y = letSubExp "offset" $ BasicOp $ BinOp (Add Int64 OverflowWrap) x y

-- If we see an Update with a scalar where the value to be written is
-- the result of indexing some other array, then we convert it into an
-- Update with a slice of that array.  This matters when the arrays
-- are far away (on the GPU, say), because it avoids a copy of the
-- scalar to and from the host.
--
-- The rule only fires when the update slice consists purely of fixed
-- indices, the source array is available in the symbol table, and the
-- source and destination arrays have different names.
ruleBasicOp vtable pat aux (Update safety arr_x (Slice slice_x) (Var v))
  | Just _ <- sliceIndices (Slice slice_x),
    Just (Index arr_y (Slice slice_y), cs_y) <- ST.lookupBasicOp v vtable,
    ST.available arr_y vtable,
    -- XXX: we should check for proper aliasing here instead.
    arr_y /= arr_x,
    -- Split off the last dimension of each slice; it must be a fixed
    -- index with nothing after it.
    Just (slice_x_bef, DimFix i, []) <- focusNth (length slice_x - 1) slice_x,
    Just (slice_y_bef, DimFix j, []) <- focusNth (length slice_y - 1) slice_y = Simplify $ do
    let -- Turn a fixed index into a 'DimSlice' with both remaining
        -- fields set to the constant 1 (presumably length and stride
        -- -- confirm against the 'DimIndex' definition).
        unitSlice k = DimSlice k (intConst Int64 1) (intConst Int64 1)
        slice_x' = Slice $ slice_x_bef ++ [unitSlice i]
        slice_y' = Slice $ slice_y_bef ++ [unitSlice j]
    v' <- letExp (baseString v ++ "_slice") $ BasicOp $ Index arr_y slice_y'
    certifying cs_y . auxing aux $
      letBind pat $ BasicOp $ Update safety arr_x slice_x' $ Var v'

-- Simplify away 0<=i when 'i' is from a loop of form 'for i < n'.
--
-- The zero operand is recognised with 'isCt0' (as the x<0 rule for
-- sizes does), so the rule is not tied to 64-bit loop indices.
ruleBasicOp vtable pat aux (CmpOp CmpSle {} x y)
  | isCt0 x,
    Var v <- y,
    Just _ <- ST.lookupLoopVar v vtable =
    Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ constant True
-- Simplify away i<n when 'i' is from a loop of form 'for i < n'.
--
-- Only fires when the right-hand operand is syntactically equal to the
-- bound recorded for the loop variable in the symbol table.
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} x y)
  | Var v <- x,
    Just n <- ST.lookupLoopVar v vtable,
    n == y =
    Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ constant True
-- Simplify away x<0 when 'x' has been used as array size.
--
-- A variable without a symbol table entry is never rewritten.
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} (Var x) y)
  | isCt0 y,
    maybe False ST.entryIsSize $ ST.lookup x vtable =
    Simplify $ auxing aux $ letBind pat $ BasicOp $ SubExp $ constant False
-- Remove certificates for variables whose definition already contain
-- that certificate.
--
-- The rule fires only when the copy statement carries at least one
-- certificate and filtering actually removes something; the statement
-- is then rebound with just the remaining certificates.
ruleBasicOp vtable pat aux (SubExp (Var v))
  | cs <- unCerts $ stmAuxCerts aux,
    not $ null cs,
    Just v_cs <- unCerts . stmCerts <$> ST.lookupStm v vtable,
    cs' <- filter (`notElem` v_cs) cs,
    cs' /= cs =
    Simplify . certifying (Certs cs') $
      letBind pat $ BasicOp $ SubExp $ Var v
-- Remove UpdateAccs that contribute the neutral value, which is
-- always a no-op.
--
-- The accumulator token is read from the type of the single pattern
-- element and looked up in the symbol table; the update is replaced by
-- the accumulator itself when the written values are syntactically
-- equal to the neutral elements recorded there.
ruleBasicOp vtable pat aux (UpdateAcc acc _ vs)
  | Pat [pe] <- pat,
    Acc token _ _ _ <- patElemType pe,
    Just (_, _, Just (_, ne)) <- ST.entryAccInput =<< ST.lookup token vtable,
    vs == ne =
    Simplify . auxing aux $ letBind pat $ BasicOp $ SubExp $ Var acc
-- No rule matched: leave the statement untouched.
ruleBasicOp _ _ _ _ =
  Skip

-- The top-down rules of this module: 'ruleBasicOp' wrapped as a
-- t'BasicOp' rule.
topDownRules :: BuilderOps rep => [TopDownRule rep]
topDownRules =
  [ RuleBasicOp ruleBasicOp
  ]

-- The bottom-up rules of this module: 'simplifyConcat' wrapped as a
-- t'BasicOp' rule.
bottomUpRules :: BuilderOps rep => [BottomUpRule rep]
bottomUpRules =
  [ RuleBasicOp simplifyConcat
  ]

-- | A set of simplification rules for t'BasicOp's.  Includes rules
-- from "Futhark.Optimise.Simplify.Rules.Simple".  The rule book is
-- built from this module's top-down and bottom-up rules, with
-- 'loopRules' appended.
basicOpRules :: (BuilderOps rep, Aliased rep) => RuleBook rep
basicOpRules = ruleBook topDownRules bottomUpRules <> loopRules