{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

-- | Some simplification rules for 'BasicOp'.
module Futhark.Optimise.Simplify.Rules.BasicOp
  ( basicOpRules,
  )
where

import Control.Monad
import Data.List (find, foldl', isSuffixOf, sort)
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Simple
import Futhark.Util

-- | Is this subexpression a constant whose value satisfies 'oneIsh'?
-- Anything that is not a 'Constant' yields 'False'; the value behind
-- a variable is not inspected here.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- | Is this subexpression a constant whose value satisfies 'zeroIsh'?
-- Anything that is not a 'Constant' yields 'False'; the value behind
-- a variable is not inspected here.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- | An operand of a concatenation, classified by the expression that
-- produced it (see 'toConcatArg').  Used to fuse adjacent operands.
data ConcatArg
  = ArgArrayLit [SubExp]
  -- ^ An array literal with the given elements.
  | ArgReplicate [SubExp] SubExp
  -- ^ A replicate of the given value.  The list holds one outer size
  -- per replicate fused into this operand; 'fromConcatArg' sums them.
  | ArgVar VName
  -- ^ Any other array, referred to by name.

-- | Classify an array variable as a 'ConcatArg' by looking up its
-- definition in the symbol table.
--
-- * An 'ArrayLit' becomes 'ArgArrayLit' with the literal's elements.
-- * A 'Replicate' becomes 'ArgReplicate' whose only width is the
--   outermost dimension of the replicate shape (@shapeSize 0 shape@).
-- * Everything else (including variables without a known 'BasicOp'
--   definition) becomes 'ArgVar'.
--
-- The certificates returned are those reported by the lookup; an
-- 'ArgVar' carries none.
toConcatArg :: ST.SymbolTable rep -> VName -> (ConcatArg, Certificates)
toConcatArg vtable v =
  case ST.lookupBasicOp v vtable of
    Just (ArrayLit ses _, cs) ->
      (ArgArrayLit ses, cs)
    Just (Replicate shape se, cs) ->
      (ArgReplicate [shapeSize 0 shape] se, cs)
    _ ->
      (ArgVar v, mempty)

-- | Turn a 'ConcatArg' back into the name of an array, binding new
-- statements where necessary.
--
-- The 'Type' argument is an array type from the concatenation being
-- rewritten (the caller in this module passes the type of the first
-- operand).  It supplies the row type for rebuilt literals and the
-- inner dimensions for rebuilt replicates.
--
-- * 'ArgArrayLit': binds a fresh @concat_lit@ array literal under the
--   operand's certificates.
-- * 'ArgReplicate': binds the sum of all collected widths as
--   @concat_rep_w@, then a @concat_rep@ replicate whose shape is the
--   given array shape with dimension 0 replaced by that sum.
-- * 'ArgVar': returns the variable unchanged and binds nothing.
fromConcatArg ::
  MonadBinder m =>
  Type ->
  (ConcatArg, Certificates) ->
  m VName
fromConcatArg t (ArgArrayLit ses, cs) =
  certifying cs $ letExp "concat_lit" $ BasicOp $ ArrayLit ses $ rowType t
fromConcatArg elem_type (ArgReplicate ws se, cs) = do
  let elem_shape = arrayShape elem_type
  certifying cs $ do
    -- The fused replicate is as wide as all of its constituents together.
    w <- letSubExp "concat_rep_w" =<< toExp (sum $ map pe64 ws)
    letExp "concat_rep" $ BasicOp $ Replicate (setDim 0 elem_shape w) se
fromConcatArg _ (ArgVar v, _) =
  pure v

-- | One step of a left fold that fuses neighbouring concat operands.
--
-- The first argument is the list of operands processed so far, most
-- recent first (i.e. in reverse order); the second is the next
-- operand.  The result is the updated reversed list.
--
-- * Empty array literals and replicates of constant width zero are
--   dropped.
-- * A replicate of constant width one is treated as a one-element
--   array literal.
-- * An array literal following an array literal is appended to it.
-- * A replicate following a replicate of the same value is merged
--   with it by collecting the widths.
-- * Anything else is pushed as a new operand.
--
-- Certificates of fused operands are combined.  A guarded equation
-- whose guards all fail falls through to the equations below it.
fuseConcatArg ::
  [(ConcatArg, Certificates)] ->
  (ConcatArg, Certificates) ->
  [(ConcatArg, Certificates)]
fuseConcatArg xs (ArgArrayLit [], _) =
  xs
fuseConcatArg xs (ArgReplicate [w] se, cs)
  | isCt0 w =
    xs
  | isCt1 w =
    fuseConcatArg xs (ArgArrayLit [se], cs)
fuseConcatArg ((ArgArrayLit x_ses, x_cs) : xs) (ArgArrayLit y_ses, y_cs) =
  (ArgArrayLit (x_ses ++ y_ses), x_cs <> y_cs) : xs
fuseConcatArg ((ArgReplicate x_ws x_se, x_cs) : xs) (ArgReplicate y_ws y_se, y_cs)
  | x_se == y_se =
    (ArgReplicate (x_ws ++ y_ws) x_se, x_cs <> y_cs) : xs
fuseConcatArg xs y =
  y : xs

-- | Bottom-up simplification rules for 'Concat'.  The equations are
-- tried in order; the final one yields 'Skip'.
simplifyConcat :: BinderOps rep => BottomUpRuleBasicOp rep
-- concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
--
-- Fires only when every operand is defined as a 'Rearrange' with
-- exactly the permutation @perm@ computed below.  The certificates of
-- those 'Rearrange's are attached to the new 'Concat'.
--
-- NOTE(review): the rewrite always concatenates the un-permuted
-- arrays along dimension 0.  That matches the equation above for
-- @i == 1@; confirm that it is also what is intended for larger @i@.
simplifyConcat (vtable, _) pat _ (Concat i x xs new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    -- The permutation that moves dimension i outermost.
    let perm = [i] ++ [0 .. i - 1] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- unzip <$> mapM (transposedBy perm) xs = Simplify $ do
    concat_rearrange <-
      certifying (x_cs <> mconcat xs_cs) $
        letExp "concat_rearrange" $ BasicOp $ Concat 0 x' xs' new_d
    letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- If v is defined as a rearrangement by exactly perm1, return the
    -- array being rearranged and the certificates of that definition.
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- Removing a concatenation that involves only a single array.  This
-- may be produced as a result of other simplification rules.  The
-- statement's auxiliary information is kept on the replacement.
simplifyConcat _ pat aux (Concat _ x [] _) =
  Simplify $
    -- Still need a copy because Concat produces a fresh array.
    auxing aux $ letBind pat $ BasicOp $ Copy x
-- concat xs (concat ys zs) == concat xs ys zs
--
-- Every operand that is itself a concatenation along the same
-- dimension is replaced by that concatenation's operands.  The rule
-- fires only if this changes the operand list.  The certificates of
-- the inlined concatenations are added to those of the statement.
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i x xs new_d)
  | x' /= x || concat xs' /= xs =
    Simplify $
      certifying (cs <> x_cs <> mconcat xs_cs) $
        attributing attrs $
          letBind pat $
            BasicOp $ Concat i x' (zs ++ concat xs') new_d
  where
    -- isConcat never returns an empty list, so this pattern cannot fail.
    (x' : zs, x_cs) = isConcat x
    (xs', xs_cs) = unzip $ map isConcat xs
    -- The operands (and certificates) of v if it is a concatenation
    -- along dimension i; otherwise just v itself.
    isConcat v = case ST.lookupBasicOp v vtable of
      Just (Concat j y ys _, v_cs) | j == i -> (y : ys, v_cs)
      _ -> ([v], mempty)

-- Fusing arguments to the concat when possible.  Only done when
-- concatenating along the outer dimension for now.
--
-- The rule fires only when fusion changed the number of operands
-- (the length comparison below); otherwise it would rewrite the
-- statement to itself.
simplifyConcat (vtable, _) pat aux (Concat 0 x xs outer_w)
  | -- We produce the to-be-concatenated arrays in reverse order, so
    -- reverse them back.
    y : ys <-
      forSingleArray $
        reverse $
          foldl' fuseConcatArg mempty $
            map (toConcatArg vtable) $ x : xs,
    length xs /= length ys =
    Simplify $ do
      elem_type <- lookupType x
      y' <- fromConcatArg elem_type y
      ys' <- mapM (fromConcatArg elem_type) ys
      auxing aux $ letBind pat $ BasicOp $ Concat 0 y' ys' outer_w
  where
    -- If we fuse so much that there is only a single input left, then
    -- it must have the right size.
    forSingleArray [(ArgReplicate _ v, cs)] =
      [(ArgReplicate [outer_w] v, cs)]
    forSingleArray ys = ys
-- No rule applies: leave the statement as it is.
simplifyConcat _ _ _ _ = Skip

-- | Top-down simplification rules for 'BasicOp'.  The equations are
-- tried in order.
--
-- This first equation tries the rules from 'applySimpleRules'.  On
-- success the pattern is rebound to the rewritten operation, certified
-- by both the certificates the rule reports and those of the original
-- statement.
ruleBasicOp :: BinderOps rep => TopDownRuleBasicOp rep
ruleBasicOp vtable pat aux op
  | Just (op', cs) <- applySimpleRules defOf seType op =
    Simplify $ certifying (cs <> stmAuxCerts aux) $ letBind pat $ BasicOp op'
  where
    -- Look up the defining expression of a variable.
    defOf = (`ST.lookupExp` vtable)
    -- The type of a subexpression: looked up for variables, read off
    -- the value for constants.
    seType (Var v) = ST.lookupType v vtable
    seType (Constant v) = Just $ Prim $ primValueType v
-- An update whose written value is defined by a 'Scratch' expression
-- is replaced by the destination array itself.
ruleBasicOp vtable pat _ (Update src _ (Var v))
  | Just (BasicOp Scratch {}, _) <- ST.lookupExp v vtable =
    Simplify $ letBind pat $ BasicOp $ SubExp $ Var src
-- If we are writing a single-element slice from some array, and the
-- element of that array can be computed as a PrimExp based on the
-- index, let's just write that instead.
--
-- The slice must have constant length one and constant stride one.
-- Element 0 of the written array is materialised as @update_elem@ and
-- written at the fixed index @i@, under the statement's auxiliary
-- information plus the certificates from the index lookup.
ruleBasicOp vtable pat aux (Update src [DimSlice i n s] (Var v))
  | isCt1 n,
    isCt1 s,
    Just (ST.Indexed cs e) <- ST.index v [intConst Int64 0] vtable =
    Simplify $ do
      e' <- toSubExp "update_elem" e
      auxing aux $
        certifying cs $
          letBind pat $ BasicOp $ Update src [DimFix i] e'
-- An update that writes back what the destination already contains at
-- that slice is replaced by the destination itself.
ruleBasicOp vtable pat _ (Update dest destis (Var v))
  | Just (e, _) <- ST.lookupExp v vtable,
    arrayFrom e =
    Simplify $ letBind pat $ BasicOp $ SubExp $ Var dest
  where
    -- Does this expression produce the contents dest already has at
    -- destis?  Guarded equations whose guards fail fall through to
    -- the final catch-all.
    --
    -- A copy is judged by the definition of the array it copies.
    arrayFrom (BasicOp (Copy copy_v))
      | Just (e', _) <- ST.lookupExp copy_v vtable =
        arrayFrom e'
    -- Reading the very same slice of the destination.
    arrayFrom (BasicOp (Index src srcis)) =
      src == dest && destis == srcis
    -- A replicate of the same value the destination was replicated
    -- from, with a shape that is a suffix of the destination's.
    arrayFrom (BasicOp (Replicate v_shape v_se))
      | Just (Replicate dest_shape dest_se, _) <- ST.lookupBasicOp dest vtable,
        v_se == dest_se,
        shapeDims v_shape `isSuffixOf` shapeDims dest_shape =
        True
    arrayFrom _ =
      False
-- An update that overwrites the whole destination (the slice is a
-- full slice of the destination's shape) does not depend on the old
-- contents, so it is replaced by an expression that builds the result
-- directly.
ruleBasicOp vtable pat _ (Update dest is se)
  | Just dest_t <- ST.lookupType dest vtable,
    isFullSlice (arrayShape dest_t) is = Simplify $
    case se of
      -- Writing an array: reshape it to the destination's dimensions
      -- and copy the result.
      Var v | not $ null $ sliceDims is -> do
        v_reshaped <-
          letExp (baseString v ++ "_reshaped") $
            BasicOp $ Reshape (map DimNew $ arrayDims dest_t) v
        letBind pat $ BasicOp $ Copy v_reshaped
      -- Otherwise the slice has no remaining dimensions: build a
      -- one-element array literal from the written value.
      _ -> letBind pat $ BasicOp $ ArrayLit [se] $ rowType dest_t
-- Fuse an update that goes through a copy of a slice back into the
-- array the slice was taken from.  The clause matches the chain
--
--   sliced = dest1[is1]
--   copied = copy sliced
--   v1     = copied with [inner_is] = val
--   pat    = dest1 with [is1] = v1
--
-- and rebinds the pattern to a single update of dest1 with val, using
-- the slice produced by applying 'sliceSlice' to is1 and inner_is.
-- The certificates of all four statements are attached to the result,
-- as are the attributes of the statement being simplified.
ruleBasicOp vtable pat (StmAux cs1 attrs _) (Update dest1 is1 (Var v1))
  | Just (Update copied inner_is val, cs2) <- ST.lookupBasicOp v1 vtable,
    Just (Copy sliced, cs3) <- ST.lookupBasicOp copied vtable,
    Just (Index src src_is, cs4) <- ST.lookupBasicOp sliced vtable,
    src == dest1,
    src_is == is1 =
    Simplify . certifying (mconcat [cs1, cs2, cs3, cs4]) $ do
      fused_is <-
        subExpSlice (sliceSlice (primExpSlice is1) (primExpSlice inner_is))
      attributing attrs . letBind pat . BasicOp $ Update dest1 fused_is val
-- Push an equality test into the 'If' that produced one of its
-- operands.  When one side of the comparison is a variable bound by
--
--   v = if p then ... y else ... z
--
-- and neither y nor z is bound inside its own branch, then @x == v@
-- is rewritten to
--
--   (p && x == y) || (!p && x == z)
--
-- Both operand orders are tried, first (se1, se2) and then (se2, se1).
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where
    -- @simplifyWith (Var v) x@ succeeds when v is the result of an
    -- 'If' whose per-branch results for v are usable outside the
    -- branches; the returned action emits the rewritten comparison.
    simplifyWith (Var v) x
      | Just bnd <- ST.lookupStm v vtable,
        If p tbranch fbranch _ <- stmExp bnd,
        Just (y, z) <- returns v (stmPattern bnd) tbranch fbranch,
        -- The branch results must not refer to names that only exist
        -- inside the branch bodies, since they are used outside them.
        not (boundInBody tbranch `namesIntersect` freeIn y),
        not (boundInBody fbranch `namesIntersect` freeIn z) = Just $ do
        eq_x_y <-
          letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
        eq_x_z <-
          letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
        p_and_eq_x_y <-
          letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
        not_p <-
          letSubExp "not_p" $ BasicOp $ UnOp Not p
        -- The name hint used to be a copy-paste of "p_and_eq_x_y";
        -- it now names what is actually computed.
        not_p_and_eq_x_z <-
          letSubExp "not_p_and_eq_x_z" $ BasicOp $ BinOp LogAnd not_p eq_x_z
        letBind pat $
          BasicOp $ BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
    simplifyWith _ _ =
      Nothing

    -- Find the pair (true-branch result, false-branch result) that
    -- corresponds to the pattern element named v, if any.
    --
    -- NOTE(review): this zips the pattern's *value* elements with the
    -- complete branch result lists.  If the branch results can carry a
    -- leading context part, the pairing would be shifted - TODO
    -- confirm against the representation of 'If' results.
    returns v ifpat tbranch fbranch =
      fmap snd $
        find ((== v) . patElemName . fst) $
          zip (patternValueElements ifpat) $
            zip (bodyResult tbranch) (bodyResult fbranch)
-- Replicating a constant with an empty shape is just the constant.
ruleBasicOp _ pat _ (Replicate (Shape []) c@Constant {}) =
  Simplify (letBind pat (BasicOp (SubExp c)))
-- Replicating a variable with an empty shape: the result is the
-- variable itself when it has primitive type, and a 'Copy' of it
-- otherwise.
ruleBasicOp _ pat _ (Replicate (Shape []) (Var v)) =
  Simplify $ do
    is_prim <- primType <$> lookupType v
    letBind pat . BasicOp $
      if is_prim
        then SubExp (Var v)
        else Copy v
-- A replicate of a replicate becomes one replicate whose shape is the
-- outer shape followed by the inner shape.  The certificates of the
-- inner replicate are carried over.
ruleBasicOp vtable pat _ (Replicate outer_shape (Var v))
  | Just (BasicOp (Replicate inner_shape elm), inner_cs) <-
      ST.lookupExp v vtable =
    Simplify . certifying inner_cs . letBind pat . BasicOp $
      Replicate (outer_shape <> inner_shape) elm
-- A non-empty array literal whose elements are all the same
-- subexpression becomes a one-dimensional replicate of that element.
ruleBasicOp _ pat _ (ArrayLit (se : ses) _)
  | all (== se) ses =
    Simplify . letBind pat . BasicOp $ Replicate (Shape [n]) se
  where
    -- Number of elements in the literal: the head plus the tail.
    n = constant (fromIntegral (length ses) + 1 :: Int64)
-- Indexing the result of a reshape.  Applies when the slice consists
-- only of fixed indices (per 'sliceIndices') and there is exactly one
-- index per dimension of the new shape.  The index is then applied to
-- the array that was reshaped instead.
ruleBasicOp vtable pat aux (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
    Simplify $ do
      slice' <- case shapeCoercion newshape of
        -- When 'shapeCoercion' succeeds, the slice is reused as-is.
        Just _ -> pure slice
        Nothing -> do
          -- Linearise indices and map to old index space.
          old_dims <- arrayDims <$> lookupType idd2
          let old_inds =
                reshapeIndex
                  (map pe64 old_dims)
                  (map pe64 (newDims newshape))
                  (map pe64 inds)
          map DimFix <$> mapM (toSubExp "new_index") old_inds
      certifying idd_cs . auxing aux . letBind pat . BasicOp $
        Index idd2 slice'
-- An integer power with constant base 2 becomes a left shift of 1.
ruleBasicOp _ pat _ (BinOp (Pow t) base expo)
  | intConst t 2 == base =
    Simplify (letBind pat (BasicOp (BinOp (Shl t) (intConst t 1) expo)))
-- Handle identity permutation: a permutation that is already sorted
-- leaves the array as it is, so the result is just the input.
ruleBasicOp _ pat _ (Rearrange perm v)
  | perm == sort perm =
    Simplify (letBind pat (BasicOp (SubExp (Var v))))
-- Rearranging a rearranging: compose the permutations and rearrange
-- the innermost array directly, keeping the inner certificates.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rearrange inner_perm inner), inner_cs) <-
      ST.lookupExp v vtable =
    Simplify . certifying inner_cs . auxing aux . letBind pat . BasicOp $
      Rearrange (rearrangeCompose perm inner_perm) inner
-- Rearrange of a rotate of a rearrange: perform the rotation on the
-- innermost array first (with the offsets permuted by the inverse of
-- the inner permutation), then apply both permutations at once.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rotate offsets v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rearrange perm3 v3), v2_cs) <- ST.lookupExp v2 vtable =
    Simplify $ do
      rotated <-
        letExp "rearrange_rotate" . BasicOp $
          Rotate (rearrangeShape (rearrangeInverse perm3) offsets) v3
      certifying (v_cs <> v2_cs) . auxing aux . letBind pat . BasicOp $
        Rearrange (rearrangeCompose perm perm3) rotated

-- Rearranging a replicate where the replicated (outer) dimensions are
-- left untouched: rearrange the replicated array itself, shifting the
-- remaining permutation down by the rank of the replicate shape, and
-- replicate the result.
ruleBasicOp vtable pat aux (Rearrange perm v1)
  | Just (BasicOp (Replicate dims (Var v2)), v1_cs) <- ST.lookupExp v1 vtable,
    let num_dims = shapeRank dims,
    let (rep_perm, rest_perm) = splitAt num_dims perm,
    not (null rest_perm),
    -- The leading part of the permutation must be 0, 1, 2, ...
    and (zipWith (==) rep_perm [0 ..]) =
    Simplify . certifying v1_cs . auxing aux $ do
      inner <-
        letSubExp "rearrange_replicate" . BasicOp $
          Rearrange (map (subtract num_dims) rest_perm) v2
      letBind pat . BasicOp $ Replicate dims inner

-- A zero-rotation is identity: when every offset is a zero constant,
-- the result is simply the input array.
ruleBasicOp _ pat _ (Rotate offsets v)
  | all isCt0 offsets =
    Simplify (letBind pat (BasicOp (SubExp (Var v))))
-- Rotate of a rearrange of a rotate: rearrange the innermost array
-- first, then perform a single rotation whose offsets are the sum of
-- the outer offsets and the inner offsets moved into the rearranged
-- dimension order.
--
-- Dimension d of @Rearrange perm v3@ is dimension @perm !! d@ of v3,
-- so the inner offset that applies to result dimension d is
-- @offsets2 !! (perm !! d)@, i.e. @rearrangeShape perm offsets2@.
-- (Permuting with the inverse of perm instead is only equivalent for
-- self-inverse permutations such as a two-dimensional transpose.)
ruleBasicOp vtable pat aux (Rotate offsets v)
  | Just (BasicOp (Rearrange perm v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rotate offsets2 v3), v2_cs) <- ST.lookupExp v2 vtable =
    Simplify $ do
      let offsets2' = rearrangeShape perm offsets2
          addOffsets x y =
            letSubExp "summed_offset" $
              BasicOp $ BinOp (Add Int64 OverflowWrap) x y
      offsets' <- zipWithM addOffsets offsets offsets2'
      rotate_rearrange <-
        auxing aux $ letExp "rotate_rearrange" $ BasicOp $ Rearrange perm v3
      certifying (v_cs <> v2_cs) $
        letBind pat $ BasicOp $ Rotate offsets' rotate_rearrange

-- Combining Rotates: a Rotate whose input is itself produced by a
-- Rotate is rewritten into a single Rotate of the innermost array,
-- with the two offset lists added pairwise.
ruleBasicOp vtable pat aux (Rotate offsets1 v)
  | Just (BasicOp (Rotate offsets2 v2), v_cs) <- ST.lookupExp v vtable =
    Simplify $ do
      -- One wrapping 64-bit addition is emitted per pair of offsets.
      summed <- forM (zip offsets1 offsets2) $ \(outer, inner) ->
        letSubExp "offset" (BasicOp (BinOp (Add Int64 OverflowWrap) outer inner))
      -- The replacement statement carries the certificates of the
      -- inner Rotate as well as the auxiliary data of this statement.
      let combined = BasicOp (Rotate summed v2)
      certifying v_cs (auxing aux (letBind pat combined))

-- If we see an Update with a scalar where the value to be written is
-- the result of indexing some other array, then we convert it into an
-- Update with a slice of that array.  This matters when the arrays
-- are far away (on the GPU, say), because it avoids a copy of the
-- scalar to and from the host.
--
-- The last index of both slices must be a 'DimFix'; each is replaced
-- by a 'DimSlice' of length one and stride one starting at that index.
ruleBasicOp vtable pat aux (Update arr_x slice_x (Var v))
  | Just _ <- sliceIndices slice_x,
    Just (Index arr_y slice_y, cs_y) <- ST.lookupBasicOp v vtable,
    ST.available arr_y vtable,
    -- XXX: we should check for proper aliasing here instead.
    arr_y /= arr_x,
    Just (x_prefix, DimFix i, []) <- focusNth (length slice_x - 1) slice_x,
    Just (y_prefix, DimFix j, []) <- focusNth (length slice_y - 1) slice_y =
    Simplify $ do
      let one = intConst Int64 1
          -- A one-element slice beginning at the given index.
          unitSliceAt start = DimSlice start one one
          slice_x' = x_prefix ++ [unitSliceAt i]
          slice_y' = y_prefix ++ [unitSliceAt j]
      -- Read the source element as a slice of the source array ...
      v' <- letExp (baseString v ++ "_slice") (BasicOp (Index arr_y slice_y'))
      -- ... and write that slice; the Update is bound under the
      -- certificates of the original Index statement.
      certifying cs_y . auxing aux $
        letBind pat (BasicOp (Update arr_x slice_x' (Var v')))

-- Simplify away 0<=i when 'i' is from a loop of form 'for i < n'.
-- The comparison is replaced by the constant True whenever the
-- symbol table reports the right-hand variable as a loop variable.
ruleBasicOp vtable pat aux (CmpOp CmpSle {} (Constant (IntValue (Int64Value 0))) (Var v))
  | Just _ <- ST.lookupLoopVar v vtable =
    Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant True
-- Simplify away i<n when 'i' is from a loop of form 'for i < n'.
-- Fires only when the value recorded for the loop variable in the
-- symbol table is exactly the right-hand operand.
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} (Var v) y)
  | ST.lookupLoopVar v vtable == Just y =
    Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant True
-- Simplify away x<0 when 'x' has been used as array size.
-- The comparison is replaced by the constant False when the
-- right-hand operand is zero and the symbol table entry for 'x'
-- is marked as a size.
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} (Var x) y)
  | isCt0 y,
    Just entry <- ST.lookup x vtable,
    ST.entryIsSize entry =
    Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant False
-- Fallback: no rule above matched, so leave the statement alone.
ruleBasicOp _vtable _pat _aux _op = Skip

-- | The top-down rules of this module: just 'ruleBasicOp', wrapped
-- as a 'RuleBasicOp'.
topDownRules :: BinderOps rep => [TopDownRule rep]
topDownRules = [RuleBasicOp ruleBasicOp]

-- | The bottom-up rules of this module: just 'simplifyConcat',
-- wrapped as a 'RuleBasicOp'.
bottomUpRules :: BinderOps rep => [BottomUpRule rep]
bottomUpRules = [RuleBasicOp simplifyConcat]

-- | A set of simplification rules for 'BasicOp's.  Includes rules
-- from "Futhark.Optimise.Simplify.Rules.Simple".
--
-- The rule book is built from 'topDownRules' and 'bottomUpRules',
-- with 'loopRules' appended after it.
basicOpRules :: (BinderOps rep, Aliased rep) => RuleBook rep
basicOpRules = ownRules <> loopRules
  where
    ownRules = ruleBook topDownRules bottomUpRules