{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

-- | Some simplification rules for t'BasicOp'.
module Futhark.Optimise.Simplify.Rules.BasicOp
  ( basicOpRules,
  )
where

import Control.Monad
import Data.List (find, foldl', isSuffixOf, sort)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe (isNothing)
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Simple

-- | Is this subexpression a constant that 'oneIsh' considers to be one?
-- Variables are never treated as one.
isCt1 :: SubExp -> Bool
isCt1 se
  | Constant v <- se = oneIsh v
  | otherwise = False

-- | Is this subexpression a constant that 'zeroIsh' considers to be zero?
-- Variables are never treated as zero.
isCt0 :: SubExp -> Bool
isCt0 se =
  case se of
    Constant v -> zeroIsh v
    _ -> False

-- | An operand of a concatenation, classified by how it was defined so
-- that neighbouring operands can be fused by 'fuseConcatArg'.
data ConcatArg
  = ArgArrayLit [SubExp]
  -- ^ The elements of an array literal (or of several adjacent array
  -- literals that have been merged).
  | ArgReplicate [SubExp] SubExp
  -- ^ A one-dimensional replicate of the given value.  The list holds
  -- the widths of all replicates merged into this one; their sum is the
  -- total width.
  | ArgVar VName
  -- ^ Any other operand, kept as an opaque array variable.

-- | Classify a concatenation operand by looking up its defining statement
-- in the symbol table.  Array literals and one-dimensional replicates are
-- recognised, together with the certificates of their defining statement,
-- so that they can later be fused with neighbouring operands.  Anything
-- else becomes an opaque 'ArgVar' with no certificates.
toConcatArg :: ST.SymbolTable rep -> VName -> (ConcatArg, Certs)
toConcatArg vtable v =
  maybe opaque classify $ ST.lookupBasicOp v vtable
  where
    opaque = (ArgVar v, mempty)
    classify (ArrayLit ses _, cs) = (ArgArrayLit ses, cs)
    classify (Replicate (Shape [d]) se, cs) = (ArgReplicate [d] se, cs)
    classify _ = opaque

-- | Materialise a (possibly fused) concatenation operand as an array
-- variable.
--
-- The 'Type' argument is the type of an operand of the concatenation.
-- Only its row type is used, as the element type of array literals.
-- The operand's certificates are attached to the statements generated
-- for it.  Opaque variables are returned unchanged, and their
-- certificates are ignored.
fromConcatArg ::
  (MonadBuilder m) =>
  Type ->
  (ConcatArg, Certs) ->
  m VName
fromConcatArg t (arg, cs) =
  case arg of
    ArgVar v ->
      pure v
    ArgArrayLit ses ->
      certifying cs . letExp "concat_lit" . BasicOp $ ArrayLit ses (rowType t)
    ArgReplicate ws se -> certifying cs $ do
      -- A fused replicate is as wide as all its constituents together.
      total_w <- toExp (sum (map pe64 ws)) >>= letSubExp "concat_rep_w"
      letExp "concat_rep" . BasicOp $ Replicate (Shape [total_w]) se

-- | Add the next operand (in left-to-right order) to an accumulator of
-- already-processed operands.  The accumulator is kept in /reverse/
-- order, with the most recent operand at its head.
--
-- * Empty array literals and replicates whose width is the constant 0
--   are dropped.
-- * A replicate whose width is the constant 1 becomes a singleton array
--   literal, which may then merge with a preceding literal.
-- * An array literal following an array literal is merged into it.
-- * A replicate following a replicate of the same value is merged into
--   it by collecting both widths.
-- * Anything else is simply pushed onto the accumulator.
fuseConcatArg ::
  [(ConcatArg, Certs)] ->
  (ConcatArg, Certs) ->
  [(ConcatArg, Certs)]
fuseConcatArg acc next@(arg, cs) =
  case (acc, arg) of
    (_, ArgArrayLit []) ->
      acc
    (_, ArgReplicate [w] se)
      | isCt0 w -> acc
      | isCt1 w -> fuseConcatArg acc (ArgArrayLit [se], cs)
    ((ArgArrayLit prev_ses, prev_cs) : rest, ArgArrayLit ses) ->
      (ArgArrayLit (prev_ses ++ ses), prev_cs <> cs) : rest
    ((ArgReplicate prev_ws prev_se, prev_cs) : rest, ArgReplicate ws se)
      | prev_se == se ->
          (ArgReplicate (prev_ws ++ ws) prev_se, prev_cs <> cs) : rest
    _ ->
      next : acc

simplifyConcat :: (BuilderOps rep) => BottomUpRuleBasicOp rep
-- Concatenating rearranged arrays along dimension i is the same as
-- concatenating the original arrays along dimension 0 and rearranging
-- the result.  This requires that every operand is rearranged by the
-- same permutation, and that this permutation moves dimension 0 of the
-- operands to dimension i:
--
--   concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
--
-- Dimension k of @Rearrange perm v@ is dimension @perm !! k@ of @v@, so
-- we need @perm !! i == 0@.  The permutation @[1..i] ++ [0] ++
-- [i+1..r-1]@ has this property.  For i <= 1 it coincides with
-- @[i] ++ [0..i-1] ++ [i+1..r-1]@; for larger i the latter would
-- concatenate along the wrong dimension.
simplifyConcat (vtable, _) pat aux (Concat i (x :| xs) new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    let perm = [1 .. i] ++ [0] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- mapAndUnzipM (transposedBy perm) xs =
      -- Keep the certificates and attributes of the original statement,
      -- like the other rewrites of this rule do.
      Simplify . auxing aux $ do
        concat_rearrange <-
          certifying (x_cs <> mconcat xs_cs) $
            letExp "concat_rearrange" $
              BasicOp $
                Concat 0 (x' :| xs') new_d
        letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- If v is defined by a rearrangement with exactly perm1, return the
    -- array being rearranged and the certificates of that statement.
    transposedBy :: [Int] -> VName -> Maybe (VName, Certs)
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- A concatenation of a single array is redundant.  Such statements can
-- be left behind by the other simplification rules.  A copy is still
-- needed, because Concat always produces a fresh array.
simplifyConcat _ pat aux (Concat _ (x :| []) _) =
  Simplify . auxing aux . letBind pat . BasicOp . Replicate mempty $ Var x
-- Flatten operands that are themselves concatenations along the same
-- dimension:
--
--   concat xs (concat ys zs) == concat xs ys zs
--
-- The new statement carries the certificates and attributes of the outer
-- concatenation, plus the certificates of every inner one that was
-- flattened.
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i (x :| xs) new_d)
  | x' /= x || concat xs' /= xs =
      Simplify . certifying (cs <> x_cs <> mconcat xs_cs) . attributing attrs $
        letBind pat . BasicOp $
          Concat i (x' :| (zs ++ concat xs')) new_d
  where
    (x' : zs, x_cs) = innerOperands x
    (xs', xs_cs) = unzip $ map innerOperands xs
    -- The operands v contributes to the flattened concatenation.  If v
    -- is itself a concatenation along dimension i, these are its
    -- operands, paired with the certificates of its defining statement.
    -- Otherwise it is just v, with no certificates.
    innerOperands :: VName -> ([VName], Certs)
    innerOperands v
      | Just (Concat j (y :| ys) _, v_cs) <- ST.lookupBasicOp v vtable,
        j == i =
          (y : ys, v_cs)
      | otherwise =
          ([v], mempty)

-- Drop operands whose types show they are empty arrays.  The rule fires
-- only if all operand types are known, at least one operand is dropped,
-- and at least one operand remains.
simplifyConcat (vtable, _) pat aux (Concat i (x :| xs) new_d)
  | Just ts <- mapM (`ST.lookupType` vtable) vs,
    x' : xs' <- [v | (v, t) <- zip vs ts, isNothing (isEmptyArray t)],
    xs' /= xs =
      Simplify . auxing aux . letBind pat . BasicOp $ Concat i (x' :| xs') new_d
  where
    vs = x : xs
-- Fuse neighbouring operands where possible (see 'fuseConcatArg').  The
-- concatenation is rebuilt only if fusion changed the number of
-- operands.  For now this is only done when concatenating along the
-- outer dimension.
simplifyConcat (vtable, _) pat aux (Concat 0 (x :| xs) outer_w)
  | y : ys <- fused,
    length ys /= length xs =
      Simplify $ do
        elem_type <- lookupType x
        y' <- fromConcatArg elem_type y
        ys' <- traverse (fromConcatArg elem_type) ys
        auxing aux . letBind pat . BasicOp $ Concat 0 (y' :| ys') outer_w
  where
    -- 'fuseConcatArg' builds its result in reverse order, so reverse it
    -- back before use.
    fused =
      forSingleArray . reverse . foldl' fuseConcatArg [] $
        map (toConcatArg vtable) (x : xs)

    -- If everything fused into a single replicate, that replicate must
    -- span the full width of the concatenation.
    forSingleArray :: [(ConcatArg, Certs)] -> [(ConcatArg, Certs)]
    forSingleArray [(ArgReplicate _ v, cs)] = [(ArgReplicate [outer_w] v, cs)]
    forSingleArray args = args
-- None of the rules above apply.
simplifyConcat _ _ _ _ = Skip

ruleBasicOp :: (BuilderOps rep) => TopDownRuleBasicOp rep
ruleBasicOp :: forall rep. BuilderOps rep => TopDownRuleBasicOp rep
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux BasicOp
op
  | Just (BasicOp
op', Certs
cs) <- VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs)
forall rep.
VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs)
applySimpleRules VarLookup rep
defOf TypeLookup
seType BasicOp
op =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (Certs
cs Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> StmAux (ExpDec rep) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec rep)
aux) (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp BasicOp
op'
  where
    defOf :: VarLookup rep
defOf = (VName -> TopDown rep -> Maybe (Exp rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
`ST.lookupExp` TopDown rep
vtable)
    seType :: TypeLookup
seType (Var VName
v) = VName -> TopDown rep -> Maybe Type
forall rep. ASTRep rep => VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
v TopDown rep
vtable
    seType (Constant PrimValue
v) = Type -> Maybe Type
forall a. a -> Maybe a
Just (Type -> Maybe Type) -> Type -> Maybe Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (Update Safety
_ VName
src Slice SubExp
_ (Var VName
v))
  | Just (BasicOp Scratch {}, Certs
_) <- VName -> TopDown rep -> Maybe (Exp rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v TopDown rep
vtable =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
src
-- If we are writing a single-element slice from some array, and the
-- element of that array can be computed as a PrimExp based on the
-- index, let's just write that instead.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (Update Safety
safety VName
src (Slice [DimSlice SubExp
i SubExp
n SubExp
s]) (Var VName
v))
  | SubExp -> Bool
isCt1 SubExp
n,
    SubExp -> Bool
isCt1 SubExp
s,
    Just (ST.Indexed Certs
cs PrimExp VName
e) <- VName -> [SubExp] -> TopDown rep -> Maybe Indexed
forall rep.
ASTRep rep =>
VName -> [SubExp] -> SymbolTable rep -> Maybe Indexed
ST.index VName
v [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0] TopDown rep
vtable =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ do
        SubExp
e' <- String -> PrimExp VName -> RuleM rep SubExp
forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"update_elem" PrimExp VName
e
        StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ())
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
          Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
            BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
              Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
src ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i]) SubExp
e'
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (Update Safety
_ VName
dest Slice SubExp
destis (Var VName
v))
  | Just (Exp rep
e, Certs
_) <- VName -> TopDown rep -> Maybe (Exp rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v TopDown rep
vtable,
    Exp rep -> Bool
arrayFrom Exp rep
e =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
dest
  where
    arrayFrom :: Exp rep -> Bool
arrayFrom (BasicOp (Replicate (Shape []) (Var VName
copy_v)))
      | Just (Exp rep
e', Certs
_) <- VName -> TopDown rep -> Maybe (Exp rep, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
copy_v TopDown rep
vtable =
          Exp rep -> Bool
arrayFrom Exp rep
e'
    arrayFrom (BasicOp (Index VName
src Slice SubExp
srcis)) =
      VName
src VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest Bool -> Bool -> Bool
&& Slice SubExp
destis Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp
srcis
    arrayFrom (BasicOp (Replicate ShapeBase SubExp
v_shape SubExp
v_se))
      | Just (Replicate ShapeBase SubExp
dest_shape SubExp
dest_se, Certs
_) <- VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
dest TopDown rep
vtable,
        SubExp
v_se SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
dest_se,
        ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
v_shape [SubExp] -> [SubExp] -> Bool
forall a. Eq a => [a] -> [a] -> Bool
`isSuffixOf` ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
dest_shape =
          Bool
True
    arrayFrom Exp rep
_ =
      Bool
False
-- An unsafe update whose slice covers the whole destination array does not
-- need the old contents at all: the result is just the written value,
-- brought into the destination's shape.
--
-- NOTE(review): assumes 'isFullSlice' also accepts 'DimFix' indices on
-- dimensions of size one; the scalar branch below is written to stay
-- well-typed in that case for destinations of any rank.
ruleBasicOp vtable pat aux (Update Unsafe dest is se)
  | Just dest_t <- ST.lookupType dest vtable,
    isFullSlice (arrayShape dest_t) is =
      Simplify . auxing aux $
        case se of
          -- The slice has at least one ranged dimension, so 'v' is an
          -- array: reshape it to the destination shape and copy it.
          Var v | not $ null $ sliceDims is -> do
            v_reshaped <-
              letSubExp (baseString v ++ "_reshaped") . BasicOp $
                Reshape ReshapeArbitrary (arrayShape dest_t) v
            letBind pat $ BasicOp $ Replicate mempty v_reshaped
          -- No ranged dimensions: every index is fixed and 'se' is a single
          -- element.  For a one-dimensional destination that is a singleton
          -- array literal.  For higher ranks an 'ArrayLit' of
          -- 'rowType dest_t' would expect array elements and be ill-typed,
          -- so replicate the element over the full destination shape.
          _
            | arrayRank dest_t == 1 ->
                letBind pat $ BasicOp $ ArrayLit [se] $ rowType dest_t
            | otherwise ->
                letBind pat $ BasicOp $ Replicate (arrayShape dest_t) se
-- Writing back a slice that was read out of the same array, copied, and then
-- partially updated collapses into a single update of the original array:
--
--   read_out = arr[outer_slice]
--   copied   = copy read_out
--   updated  = copied with [inner_slice] = new_val
--   arr with [outer_slice] = updated
--
-- becomes @arr with [outer_slice composed with inner_slice] = new_val@.
-- Certificates from every statement in the chain are kept.  The result is
-- safe if either of the two updates was.
ruleBasicOp vtable pat (StmAux outer_cs outer_attrs _) (Update outer_safety arr outer_slice (Var updated))
  | Just (Update inner_safety copied inner_slice new_val, updated_cs) <-
      ST.lookupBasicOp updated vtable,
    Just (Replicate (Shape []) (Var read_out), copied_cs) <-
      ST.lookupBasicOp copied vtable,
    Just (Index src src_slice, read_out_cs) <-
      ST.lookupBasicOp read_out vtable,
    src_slice == outer_slice,
    src == arr =
      Simplify $ do
        let all_cs = outer_cs <> updated_cs <> copied_cs <> read_out_cs
            safety = max outer_safety inner_safety
        certifying all_cs $ do
          composed <-
            subExpSlice $
              sliceSlice (primExpSlice outer_slice) (primExpSlice inner_slice)
          attributing outer_attrs . letBind pat . BasicOp $
            Update safety arr composed new_val
-- Equality against the result of a two-way boolean 'Match' is pushed into
-- the branches.  Given
--
--   v = if p then y else z
--
-- a comparison of some @x@ with @v@ (in either operand order) is rewritten
-- to
--
--   (p && x == y) || (!p && x == z)
--
-- The rule only fires when neither @y@ nor @z@ is bound inside its own
-- branch, so both can be referenced outside the 'Match'.
ruleBasicOp vtable pat _ (CmpOp (CmpEq t) se1 se2)
  | Just m <- simplifyWith se1 se2 = Simplify m
  | Just m <- simplifyWith se2 se1 = Simplify m
  where
    -- Try to treat the first operand as the 'Match' result and the second
    -- as the value compared against it.
    simplifyWith :: SubExp -> SubExp -> Maybe (RuleM rep ())
    simplifyWith (Var v) x
      | Just stm <- ST.lookupStm v vtable,
        Match [p] [Case [Just (BoolValue True)] tbranch] fbranch _ <- stmExp stm,
        Just (y, z) <- returns v (stmPat stm) tbranch fbranch,
        not $ boundInBody tbranch `namesIntersect` freeIn y,
        not $ boundInBody fbranch `namesIntersect` freeIn z = Just $ do
          eq_x_y <-
            letSubExp "eq_x_y" $ BasicOp $ CmpOp (CmpEq t) x y
          eq_x_z <-
            letSubExp "eq_x_z" $ BasicOp $ CmpOp (CmpEq t) x z
          p_and_eq_x_y <-
            letSubExp "p_and_eq_x_y" $ BasicOp $ BinOp LogAnd p eq_x_y
          not_p <-
            letSubExp "not_p" $ BasicOp $ UnOp Not p
          -- Name hint matches the binding (it previously reused
          -- "p_and_eq_x_y", giving two different values the same hint).
          not_p_and_eq_x_z <-
            letSubExp "not_p_and_eq_x_z" $ BasicOp $ BinOp LogAnd not_p eq_x_z
          letBind pat $
            BasicOp $
              BinOp LogOr p_and_eq_x_y not_p_and_eq_x_z
    simplifyWith _ _ =
      Nothing

    -- The (true-branch, false-branch) results that the 'Match' binds to
    -- the pattern element named @v@, if any.
    returns :: VName -> Pat dec -> Body rep -> Body rep -> Maybe (SubExp, SubExp)
    returns v ifpat tbranch fbranch =
      fmap snd . find ((== v) . patElemName . fst) $
        zip (patElems ifpat) $
          zip (map resSubExp (bodyResult tbranch)) (map resSubExp (bodyResult fbranch))
-- A replicate whose single result has accumulator type is replaced by the
-- replicated value itself.
ruleBasicOp _ pat _ (Replicate _ se)
  | [Acc {}] <- patTypes pat =
      Simplify . letBind pat . BasicOp $ SubExp se
-- Copying (replicating with an empty shape) a value whose single result is
-- a primitive scalar is just that value.
ruleBasicOp _ pat _ (Replicate (Shape []) se)
  | [Prim _] <- patTypes pat =
      Simplify . letBind pat . BasicOp $ SubExp se
-- Replicating the result of a replicate merges the two shapes into one
-- replicate of the innermost value, provided that value is available here.
ruleBasicOp vtable pat _ (Replicate outer_shape (Var arr))
  | Just (BasicOp (Replicate inner_shape inner_se), arr_cs) <- ST.lookupExp arr vtable,
    ST.subExpAvailable inner_se vtable =
      let merged = Replicate (outer_shape <> inner_shape) inner_se
       in Simplify . certifying arr_cs . letBind pat $ BasicOp merged
-- An array literal whose elements are all the same is a one-dimensional
-- replicate of that element, with the literal's length as its size.
ruleBasicOp _ pat _ (ArrayLit (se : ses) _)
  | all (== se) ses =
      Simplify $ letBind pat $ BasicOp $ Replicate (Shape [len]) se
  where
    len = constant (fromIntegral (length (se : ses)) :: Int64)
-- Fully indexing a reshaped array (one fixed index per dimension of the new
-- shape) is turned into an index of the array that was reshaped.
ruleBasicOp vtable pat aux (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape k newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
      let bindIndex = certifying idd_cs . auxing aux . letBind pat . BasicOp . Index idd2
       in Simplify $ case k of
            -- A coercing reshape keeps the index space, so the same slice
            -- applies directly to the source array.
            ReshapeCoerce -> bindIndex slice
            -- Otherwise, linearise the indices with respect to the new
            -- shape and map them back into the source's index space.
            ReshapeArbitrary -> do
              oldshape <- arrayDims <$> lookupType idd2
              new_inds <-
                mapM (toSubExp "new_index") $
                  reshapeIndex
                    (map pe64 oldshape)
                    (map pe64 (shapeDims newshape))
                    (map pe64 inds)
              bindIndex $ Slice $ map DimFix new_inds

-- A copy of an iota is replaced by a fresh iota with the same parameters;
-- there is no need to copy the original array.
ruleBasicOp vtable pat aux (Replicate (Shape []) (Var v))
  | Just (iota_op@Iota {}, v_cs) <- ST.lookupBasicOp v vtable =
      Simplify $ certifying v_cs $ auxing aux $ letBind pat $ BasicOp iota_op
-- A rearrange whose permutation is already in ascending order is the
-- identity permutation, so the result is just the input array.
ruleBasicOp _ pat _ (Rearrange perm v)
  | perm == sort perm =
      Simplify . letBind pat . BasicOp . SubExp $ Var v
-- Rearranging the result of a rearrange composes the two permutations into
-- a single rearrange of the innermost array.
ruleBasicOp vtable pat aux (Rearrange outer_perm v)
  | Just (BasicOp (Rearrange inner_perm src), v_cs) <- ST.lookupExp v vtable =
      let composed = outer_perm `rearrangeCompose` inner_perm
       in Simplify $ certifying v_cs $ auxing aux $ letBind pat $ BasicOp $ Rearrange composed src
-- Rearranging a replicate where the replicated (outer) dimensions stay in
-- place: only the inner dimensions move.  The permutation is pushed inside,
-- so the replicated value is rearranged first and then replicated.
ruleBasicOp vtable pat aux (Rearrange perm v1)
  | Just (BasicOp (Replicate dims (Var v2)), v1_cs) <- ST.lookupExp v1 vtable,
    let num_dims = shapeRank dims,
    let (rep_perm, rest_perm) = splitAt num_dims perm,
    not (null rest_perm),
    and (zipWith (==) rep_perm [0 ..]) =
      Simplify . certifying v1_cs . auxing aux $ do
        -- Shift the inner permutation down so it indexes the dimensions of
        -- the replicated value itself.
        let inner_perm = map (subtract num_dims) rest_perm
        inner <-
          letSubExp "rearrange_replicate" $ BasicOp $ Rearrange inner_perm v2
        letBind pat $ BasicOp $ Replicate dims inner

-- Simplify away 0<=i when 'i' is from a loop of form 'for i < n'.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (CmpOp CmpSle {} SubExp
x SubExp
y)
  | Constant (IntValue (Int64Value Int64
0)) <- SubExp
x,
    Var VName
v <- SubExp
y,
    Just SubExp
_ <- VName -> TopDown rep -> Maybe SubExp
forall rep. VName -> SymbolTable rep -> Maybe SubExp
ST.lookupLoopVar VName
v TopDown rep
vtable =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
-- Simplify away i<n when 'i' is from a loop of form 'for i < n'.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (CmpOp CmpSlt {} SubExp
x SubExp
y)
  | Var VName
v <- SubExp
x,
    Just SubExp
n <- VName -> TopDown rep -> Maybe SubExp
forall rep. VName -> SymbolTable rep -> Maybe SubExp
ST.lookupLoopVar VName
v TopDown rep
vtable,
    SubExp
n SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SubExp
y =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
-- Simplify away x<0 when 'x' has been used as array size.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (CmpOp CmpSlt {} (Var VName
x) SubExp
y)
  | SubExp -> Bool
isCt0 SubExp
y,
    Bool -> (Entry rep -> Bool) -> Maybe (Entry rep) -> Bool
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
False Entry rep -> Bool
forall rep. Entry rep -> Bool
ST.entryIsSize (Maybe (Entry rep) -> Bool) -> Maybe (Entry rep) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> TopDown rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
x TopDown rep
vtable =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
False
-- Simplify away 0<=y when 'y' has been used as array size.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (CmpOp CmpSle {} SubExp
x (Var VName
y))
  | SubExp -> Bool
isCt0 SubExp
x,
    Bool -> (Entry rep -> Bool) -> Maybe (Entry rep) -> Bool
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
False Entry rep -> Bool
forall rep. Entry rep -> Bool
ST.entryIsSize (Maybe (Entry rep) -> Bool) -> Maybe (Entry rep) -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> TopDown rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
y TopDown rep
vtable =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Bool -> SubExp
forall v. IsValue v => v -> SubExp
constant Bool
True
-- Remove certificates for variables whose definition already contains
-- that certificate.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (SubExp (Var VName
v))
  | [VName]
cs <- Certs -> [VName]
unCerts (Certs -> [VName]) -> Certs -> [VName]
forall a b. (a -> b) -> a -> b
$ StmAux (ExpDec rep) -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec rep)
aux,
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
cs,
    Just [VName]
v_cs <- Certs -> [VName]
unCerts (Certs -> [VName]) -> (Stm rep -> Certs) -> Stm rep -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm rep -> Certs
forall rep. Stm rep -> Certs
stmCerts (Stm rep -> [VName]) -> Maybe (Stm rep) -> Maybe [VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> TopDown rep -> Maybe (Stm rep)
forall rep. VName -> SymbolTable rep -> Maybe (Stm rep)
ST.lookupStm VName
v TopDown rep
vtable,
    [VName]
cs' <- (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
v_cs) [VName]
cs,
    [VName]
cs' [VName] -> [VName] -> Bool
forall a. Eq a => a -> a -> Bool
/= [VName]
cs =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying ([VName] -> Certs
Certs [VName]
cs') (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$
        Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$
          BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$
            SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
              VName -> SubExp
Var VName
v
-- Remove UpdateAccs that contribute the neutral value, which is
-- always a no-op.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (UpdateAcc VName
acc [SubExp]
_ [SubExp]
vs)
  | Pat [PatElem (LetDec rep)
pe] <- Pat (LetDec rep)
pat,
    Acc VName
token ShapeBase SubExp
_ [Type]
_ NoUniqueness
_ <- PatElem (LetDec rep) -> Type
forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
pe,
    Just (ShapeBase SubExp
_, [VName]
_, Just (Lambda rep
_, [SubExp]
ne)) <- Entry rep -> Maybe (WithAccInput rep)
forall rep. Entry rep -> Maybe (WithAccInput rep)
ST.entryAccInput (Entry rep -> Maybe (WithAccInput rep))
-> Maybe (Entry rep) -> Maybe (WithAccInput rep)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> TopDown rep -> Maybe (Entry rep)
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
token TopDown rep
vtable,
    [SubExp]
vs [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp]
ne =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (RuleM rep () -> RuleM rep ()) -> RuleM rep () -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> Rule rep) -> RuleM rep () -> Rule rep
forall a b. (a -> b) -> a -> b
$ Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp (Rep (RuleM rep)) -> RuleM rep ())
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep (RuleM rep))
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep (RuleM rep)))
-> BasicOp -> Exp (Rep (RuleM rep))
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
acc
-- Manifest of a copy (or another Manifest) can be simplified to
-- manifesting the original array, if it is still available.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (Manifest [Int]
perm VName
v1)
  | Just (Replicate (Shape []) (Var VName
v2), Certs
cs) <- VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
v1 TopDown rep
vtable,
    VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
ST.available VName
v2 TopDown rep
vtable =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (BasicOp -> RuleM rep ()) -> BasicOp -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ())
-> (BasicOp -> RuleM rep ()) -> BasicOp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ())
-> (BasicOp -> RuleM rep ()) -> BasicOp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp rep -> RuleM rep ())
-> (BasicOp -> Exp rep) -> BasicOp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Rule rep) -> BasicOp -> Rule rep
forall a b. (a -> b) -> a -> b
$
        [Int] -> VName -> BasicOp
Manifest [Int]
perm VName
v2
  | Just (Manifest [Int]
_ VName
v2, Certs
cs) <- VName -> TopDown rep -> Maybe (BasicOp, Certs)
forall rep. VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
v1 TopDown rep
vtable,
    VName -> TopDown rep -> Bool
forall rep. VName -> SymbolTable rep -> Bool
ST.available VName
v2 TopDown rep
vtable =
      RuleM rep () -> Rule rep
forall rep. RuleM rep () -> Rule rep
Simplify (RuleM rep () -> Rule rep)
-> (BasicOp -> RuleM rep ()) -> BasicOp -> Rule rep
forall b c a. (b -> c) -> (a -> b) -> a -> c
. StmAux (ExpDec rep) -> RuleM rep () -> RuleM rep ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux (RuleM rep () -> RuleM rep ())
-> (BasicOp -> RuleM rep ()) -> BasicOp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Certs -> RuleM rep () -> RuleM rep ()
forall a. Certs -> RuleM rep a -> RuleM rep a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (RuleM rep () -> RuleM rep ())
-> (BasicOp -> RuleM rep ()) -> BasicOp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec (Rep (RuleM rep)))
-> Exp (Rep (RuleM rep)) -> RuleM rep ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
Pat (LetDec (Rep (RuleM rep)))
pat (Exp rep -> RuleM rep ())
-> (BasicOp -> Exp rep) -> BasicOp -> RuleM rep ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Rule rep) -> BasicOp -> Rule rep
forall a b. (a -> b) -> a -> b
$
        [Int] -> VName -> BasicOp
Manifest [Int]
perm VName
v2
ruleBasicOp TopDown rep
_ Pat (LetDec rep)
_ StmAux (ExpDec rep)
_ BasicOp
_ =
  Rule rep
forall rep. Rule rep
Skip

-- | The top-down simplification rules for t'BasicOp's.  This is a
-- single 'RuleBasicOp' wrapping 'ruleBasicOp', whose equations are
-- tried in order (falling back to 'Skip' when none apply).
topDownRules :: (BuilderOps rep) => [TopDownRule rep]
topDownRules =
  [ RuleBasicOp ruleBasicOp
  ]

-- | The bottom-up simplification rules for t'BasicOp's.  Currently
-- this is only 'simplifyConcat', wrapped as a 'RuleBasicOp'.
bottomUpRules :: (BuilderOps rep) => [BottomUpRule rep]
bottomUpRules =
  [ RuleBasicOp simplifyConcat
  ]

-- | A set of simplification rules for t'BasicOp's.  Includes rules
-- from "Futhark.Optimise.Simplify.Rules.Simple".  The rule book is
-- built from 'topDownRules' and 'bottomUpRules', and the loop rules
-- ('loopRules') are appended to it.
basicOpRules :: (BuilderOps rep, Aliased rep) => RuleBook rep
basicOpRules = ruleBook topDownRules bottomUpRules <> loopRules