{-# LANGUAGE TupleSections #-}

-- | Particularly simple simplification rules.
module Futhark.Optimise.Simplify.Rules.Simple
  ( TypeLookup,
    VarLookup,
    applySimpleRules,
  )
where

import Control.Monad
import Data.List (isSuffixOf)
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR

-- | Look up the definition of a variable.  When the definition is
-- known, it is returned together with the certificates attached to it.
type VarLookup lore = VName -> Maybe (Exp lore, Certificates)

-- | Look up the type of a subexpression, if it is known.
type TypeLookup = SubExp -> Maybe Type

-- | A simple rule is a top-down rule that can be expressed as a pure
-- function.  Given lookups for variable definitions and subexpression
-- types, it either produces a replacement 'BasicOp' together with the
-- certificates the replacement must carry, or 'Nothing' when the rule
-- does not apply.
type SimpleRule lore =
  VarLookup lore ->
  TypeLookup ->
  BasicOp ->
  Maybe (BasicOp, Certificates)

-- | Is the subexpression a constant accepted by 'oneIsh'?  Variables
-- are never considered to be one.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  _ -> False

-- | Is the subexpression a constant accepted by 'zeroIsh'?  Variables
-- are never considered to be zero.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  _ -> False

-- | Simplify comparison operators:
--
-- * comparing a subexpression with itself yields a constant result;
--
-- * comparing two constants is folded with 'doCmpOp';
--
-- * @k == btoi(b)@ for an integer constant @k@ becomes @b@ when @k@ is
--   1, @not b@ when @k@ is 0, and @false@ otherwise, keeping the
--   certificates of the @btoi@ definition.
simplifyCmpOp :: SimpleRule lore
simplifyCmpOp _ _ (CmpOp cmp x y)
  | x == y = constRes $ BoolValue $ selfComparison cmp
  where
    -- The result of comparing a subexpression with itself.
    selfComparison CmpEq {} = True
    selfComparison CmpSlt {} = False
    selfComparison CmpUlt {} = False
    selfComparison CmpSle {} = True
    selfComparison CmpUle {} = True
    selfComparison FCmpLt {} = False
    selfComparison FCmpLe {} = True
    selfComparison CmpLlt = False
    selfComparison CmpLle = True
simplifyCmpOp _ _ (CmpOp cmp (Constant x) (Constant y)) =
  doCmpOp cmp x y >>= constRes . BoolValue
simplifyCmpOp look _ (CmpOp CmpEq {} (Constant (IntValue k)) (Var v))
  | Just (BasicOp (ConvOp BToI {} b), cs) <- look v =
    let replacement = case valueIntegral k :: Int of
          1 -> SubExp b
          0 -> UnOp Not b
          _ -> SubExp $ Constant $ BoolValue False
     in Just (replacement, cs)
simplifyCmpOp _ _ _ = Nothing

-- | Algebraic simplification of binary operators.
--
-- Covers constant folding via 'doBinOp', identity and absorbing
-- elements (e.g. @x + 0@, @x * 1@, @x * 0@), operations on two equal
-- operands (e.g. @x xor x@, @x rem x@), and a handful of rewrites that
-- look through the definition of an operand (e.g. @x + (y - x) => y@).
-- A rewrite that looks through a definition carries over that
-- definition's certificates; every other rewrite attaches none.
simplifyBinOp :: SimpleRule lore
-- Constant folding.
simplifyBinOp _ _ (BinOp op (Constant v1) (Constant v2))
  | Just res <- doBinOp op v1 v2 =
    constRes res
simplifyBinOp look _ (BinOp Add {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  -- x+(y-x) => y
  | Var v2 <- e2,
    Just (BasicOp (BinOp Sub {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
    Just (SubExp e2_a, cs)
simplifyBinOp _ _ (BinOp FAdd {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
simplifyBinOp look _ (BinOp sub@(Sub t _) e1 e2)
  | isCt0 e2 = subExpRes e1
  -- Cases for simplifying (a+b)-b and permutations.
  -- (e1_a+e1_b)-e1_a == e1_b
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_a == e2 =
    Just (SubExp e1_b, cs)
  -- (e1_a+e1_b)-e1_b == e1_a
  | Var v1 <- e1,
    Just (BasicOp (BinOp Add {} e1_a e1_b), cs) <- look v1,
    e1_b == e2 =
    Just (SubExp e1_a, cs)
  -- e2_a-(e2_a+e2_b) == 0-e2_b
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_a == e1 =
    Just (BinOp sub (intConst t 0) e2_b, cs)
  -- e2_b-(e2_a+e2_b) == 0-e2_a
  | Var v2 <- e2,
    Just (BasicOp (BinOp Add {} e2_a e2_b), cs) <- look v2,
    e2_b == e1 =
    Just (BinOp sub (intConst t 0) e2_a, cs)
simplifyBinOp _ _ (BinOp FSub {} e1 e2)
  | isCt0 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp Mul {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp FMul {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt0 e2 = subExpRes e2
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
simplifyBinOp look _ (BinOp (SMod t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- (x smod n) smod n == x smod n, which is e1 itself.
  | Var v1 <- e1,
    Just (BasicOp (BinOp SMod {} _ e4), v1_cs) <- look v1,
    e4 == e2 =
    Just (SubExp e1, v1_cs)
simplifyBinOp _ _ (BinOp SDiv {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp SDivUp {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp FDiv {} e1 e2)
  | isCt0 e1 = subExpRes e1
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (SRem t _) e1 e2)
  | isCt1 e2 = constRes $ IntValue $ intValue t (0 :: Int)
  -- The remainder of a value divided by itself is zero (as for SMod
  -- above), not one.
  | e1 == e2 = constRes $ IntValue $ intValue t (0 :: Int)
simplifyBinOp _ _ (BinOp SQuot {} e1 e2)
  | isCt1 e2 = subExpRes e1
  | isCt0 e2 = Nothing
simplifyBinOp _ _ (BinOp (FPow t) e1 e2)
  | isCt0 e2 = subExpRes $ floatConst t 1
  | isCt0 e1 || isCt1 e1 || isCt1 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (Shl t) e1 e2)
  | isCt0 e2 = subExpRes e1
  | isCt0 e1 = subExpRes $ intConst t 0
simplifyBinOp _ _ (BinOp AShr {} e1 e2)
  | isCt0 e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (And t) e1 e2)
  | isCt0 e1 = subExpRes $ intConst t 0
  | isCt0 e2 = subExpRes $ intConst t 0
  | e1 == e2 = subExpRes e1
simplifyBinOp _ _ (BinOp Or {} e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes e1
simplifyBinOp _ _ (BinOp (Xor t) e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | e1 == e2 = subExpRes $ intConst t 0
simplifyBinOp look _ (BinOp LogAnd e1 e2)
  | isCt0 e1 = constRes $ BoolValue False
  | isCt0 e2 = constRes $ BoolValue False
  | isCt1 e1 = subExpRes e2
  | isCt1 e2 = subExpRes e1
  -- (not x) && x == false
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- look v,
    e1' == e2 =
    Just (SubExp $ Constant $ BoolValue False, v_cs)
  -- x && (not x) == false
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- look v,
    e2' == e1 =
    Just (SubExp $ Constant $ BoolValue False, v_cs)
simplifyBinOp look _ (BinOp LogOr e1 e2)
  | isCt0 e1 = subExpRes e2
  | isCt0 e2 = subExpRes e1
  | isCt1 e1 = constRes $ BoolValue True
  | isCt1 e2 = constRes $ BoolValue True
  -- (not x) || x == true
  | Var v <- e1,
    Just (BasicOp (UnOp Not e1'), v_cs) <- look v,
    e1' == e2 =
    Just (SubExp $ Constant $ BoolValue True, v_cs)
  -- x || (not x) == true
  | Var v <- e2,
    Just (BasicOp (UnOp Not e2'), v_cs) <- look v,
    e2' == e1 =
    Just (SubExp $ Constant $ BoolValue True, v_cs)
simplifyBinOp look _ (BinOp (SMax it) e1 e2)
  | e1 == e2 =
    subExpRes e1
  -- smax(smax(a, b), a) == smax(b, a)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- look v1,
    e1_1 == e2 =
    Just (BinOp (SMax it) e1_2 e2, v1_cs)
  -- smax(smax(a, b), b) == smax(a, b)
  | Var v1 <- e1,
    Just (BasicOp (BinOp (SMax _) e1_1 e1_2), v1_cs) <- look v1,
    e1_2 == e2 =
    Just (BinOp (SMax it) e1_1 e2, v1_cs)
  -- smax(a, smax(a, b)) == smax(b, a)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- look v2,
    e2_1 == e1 =
    Just (BinOp (SMax it) e2_2 e1, v2_cs)
  -- smax(b, smax(a, b)) == smax(a, b)
  | Var v2 <- e2,
    Just (BasicOp (BinOp (SMax _) e2_1 e2_2), v2_cs) <- look v2,
    e2_2 == e1 =
    Just (BinOp (SMax it) e2_1 e1, v2_cs)
simplifyBinOp _ _ _ = Nothing

-- | A rule result that replaces the expression with the given constant
-- and attaches no certificates.
constRes :: PrimValue -> Maybe (BasicOp, Certificates)
constRes v = Just (SubExp (Constant v), mempty)

-- | A rule result that replaces the expression with the given
-- subexpression and attaches no certificates.
subExpRes :: SubExp -> Maybe (BasicOp, Certificates)
subExpRes se = Just (SubExp se, mempty)

-- | Simplify unary operators: fold them on constants with 'doUnOp', and
-- cancel a double negation @not (not x)@ to @x@, keeping the
-- certificates of the inner negation's definition.
simplifyUnOp :: SimpleRule lore
simplifyUnOp _ _ (UnOp op (Constant v)) =
  doUnOp op v >>= constRes
simplifyUnOp look _ (UnOp Not (Var v))
  | Just (BasicOp (UnOp Not inner), cs) <- look v =
    Just (SubExp inner, cs)
simplifyUnOp _ _ _ = Nothing

-- | Simplify conversion operators:
--
-- * conversions of constants are folded with 'doConvOp';
--
-- * a conversion whose source and target types coincide is dropped;
--
-- * a conversion applied to the result of an extension (or, for
--   'FPConv', another float conversion) whose intermediate type compares
--   @>=@ the extension's source type is fused into one conversion
--   directly from the original value.  The certificates of the inner
--   definition are kept.
simplifyConvOp :: SimpleRule lore
simplifyConvOp _ _ (ConvOp op (Constant v)) =
  doConvOp op v >>= constRes
simplifyConvOp _ _ (ConvOp op se)
  | uncurry (==) (convOpType op) =
    subExpRes se
simplifyConvOp look _ (ConvOp (SExt mid to) (Var v))
  | Just (BasicOp (ConvOp (SExt from _) se), cs) <- look v,
    mid >= from =
    Just (ConvOp (SExt from to) se, cs)
simplifyConvOp look _ (ConvOp (ZExt mid to) (Var v))
  | Just (BasicOp (ConvOp (ZExt from _) se), cs) <- look v,
    mid >= from =
    Just (ConvOp (ZExt from to) se, cs)
simplifyConvOp look _ (ConvOp (SIToFP mid to) (Var v))
  | Just (BasicOp (ConvOp (SExt from _) se), cs) <- look v,
    mid >= from =
    Just (ConvOp (SIToFP from to) se, cs)
simplifyConvOp look _ (ConvOp (UIToFP mid to) (Var v))
  | Just (BasicOp (ConvOp (ZExt from _) se), cs) <- look v,
    mid >= from =
    Just (ConvOp (UIToFP from to) se, cs)
simplifyConvOp look _ (ConvOp (FPConv mid to) (Var v))
  | Just (BasicOp (ConvOp (FPConv from _) se), cs) <- look v,
    mid >= from =
    Just (ConvOp (FPConv from to) se, cs)
simplifyConvOp _ _ _ = Nothing

-- | An assertion whose condition is the constant @true@ is replaced by
-- the unit value.
simplifyAssert :: SimpleRule lore
simplifyAssert _ _ (Assert (Constant (BoolValue True)) _ _) =
  constRes UnitValue
simplifyAssert _ _ _ = Nothing

-- | A reshape whose new dimensions equal the array's current dimensions
-- is a no-op, so it is replaced by the array itself.
simplifyIdentityReshape :: SimpleRule lore
simplifyIdentityReshape _ tyOf (Reshape newshape v)
  | Just t <- tyOf (Var v),
    newDims newshape == arrayDims t =
    subExpRes (Var v)
simplifyIdentityReshape _ _ _ = Nothing

-- | A reshape of a reshape becomes a single reshape of the original
-- array, with the two shape changes combined by 'fuseReshape'.
simplifyReshapeReshape :: SimpleRule lore
simplifyReshapeReshape look _ (Reshape newshape v)
  | Just (BasicOp (Reshape oldshape src), cs) <- look v =
    Just (Reshape (fuseReshape oldshape newshape) src, cs)
simplifyReshapeReshape _ _ _ = Nothing

-- | A reshape of a scratch array becomes a scratch array of the same
-- element type with the new dimensions.
simplifyReshapeScratch :: SimpleRule lore
simplifyReshapeScratch look _ (Reshape newshape v)
  | Just (BasicOp (Scratch elemType _), cs) <- look v =
    Just (Scratch elemType (newDims newshape), cs)
simplifyReshapeScratch _ _ _ = Nothing

-- | A reshape of a replicate whose new dimensions end with the shape of
-- the replicated value becomes a replicate of that value.  Its outer
-- shape is the leading (non-suffix) part of the new dimensions.
simplifyReshapeReplicate :: SimpleRule lore
simplifyReshapeReplicate look tyOf (Reshape newshape v)
  | Just (BasicOp (Replicate _ se), cs) <- look v,
    Just oldshape <- arrayShape <$> tyOf se,
    let target = newDims newshape,
    shapeDims oldshape `isSuffixOf` target =
    let outer = take (length newshape - shapeRank oldshape) target
     in Just (Replicate (Shape outer) se, cs)
simplifyReshapeReplicate _ _ _ = Nothing

-- | A reshape of an iota into a single dimension becomes an iota of the
-- new length, with the iota's remaining arguments unchanged.
simplifyReshapeIota :: SimpleRule lore
simplifyReshapeIota look _ (Reshape newshape v)
  | Just (BasicOp (Iota _ offset stride it), cs) <- look v,
    [len] <- newDims newshape =
    Just (Iota len offset stride it, cs)
simplifyReshapeIota _ _ _ = Nothing

-- | Replace the sizes of the 'DimSlice' components of a slice, in order,
-- with the given dimensions.  'DimFix' components are kept unchanged
-- and consume no dimension.  The result stops at the first component
-- that is neither a 'DimFix' nor a 'DimSlice' with a dimension left to
-- consume.
reshapeSlice :: [DimIndex d] -> [d] -> [DimIndex d]
reshapeSlice [] _ = []
reshapeSlice (idx : rest) dims =
  case (idx, dims) of
    (DimFix i, _) ->
      DimFix i : reshapeSlice rest dims
    (DimSlice start _ stride, d : dims') ->
      DimSlice start d stride : reshapeSlice rest dims'
    _ -> []

-- | If we are size-coercing a slice, then we might as well just use a
-- different slice instead.
--
-- Applies when the reshape's shape change yields dimensions through
-- 'shapeCoercion' and the reshaped array is defined by an 'Index'.  The
-- original array is then re-indexed with the 'DimSlice' lengths taken
-- from those dimensions (see 'reshapeSlice').  The certificates of the
-- 'Index' definition are carried over.  No rewrite is reported when the
-- recomputed slice equals the original one.
simplifyReshapeIndex :: SimpleRule lore
simplifyReshapeIndex defOf _ (Reshape newshape v)
  | Just ds <- shapeCoercion newshape,
    Just (BasicOp (Index v' slice), v_cs) <- defOf v,
    let slice' = reshapeSlice slice ds,
    -- Avoid reporting a rewrite that changes nothing.
    slice' /= slice =
    Just (Index v' slice', v_cs)
simplifyReshapeIndex _ _ _ = Nothing

-- | If we are updating a slice with the result of a size coercion, we
-- instead use the original array and update the slice dimensions.
--
-- Applies to @Update dest slice (Var v)@ when @v@ is defined by a
-- 'Reshape' of some @v'@ whose shape change passes 'shapeCoercion', and
-- the type of @v'@ is known to the 'TypeLookup'.  The 'DimSlice'
-- lengths of @slice@ are replaced by the dimensions of @v'@ (see
-- 'reshapeSlice'), and @v'@ is written directly.  The certificates of
-- the reshape's definition are carried over.  No rewrite is reported
-- when the recomputed slice equals the original one.
simplifyUpdateReshape :: SimpleRule lore
simplifyUpdateReshape defOf seType (Update dest slice (Var v))
  | Just (BasicOp (Reshape newshape v'), v_cs) <- defOf v,
    Just _ <- shapeCoercion newshape,
    Just ds <- arrayDims <$> seType (Var v'),
    let slice' = reshapeSlice slice ds,
    -- Avoid reporting a rewrite that changes nothing.
    slice' /= slice =
    Just (Update dest slice' $ Var v', v_cs)
simplifyUpdateReshape _ _ _ = Nothing

-- | Refine a 'Reshape' using the known dimensions of the array being
-- reshaped.  The shape change is passed through 'informReshape' together
-- with the array's dimensions.  If that changes the shape change, the
-- reshape is rebuilt with the new one and no extra certificates.
-- NOTE(review): the exact refinement performed depends on
-- 'informReshape', which is defined elsewhere; confirm there.
improveReshape :: SimpleRule lore
improveReshape _ seType (Reshape newshape v)
  | Just t <- seType $ Var v,
    let newshape' = informReshape (arrayDims t) newshape,
    -- Only report a rewrite if the shape change actually differs.
    newshape' /= newshape =
    Just (Reshape newshape' v, mempty)
improveReshape _ _ _ = Nothing

-- | If we are copying a scratch array (possibly indirectly), just turn
-- it into a scratch by itself.
--
-- "Indirectly" means the copied array is reached from a 'Scratch'
-- through any chain of 'Rearrange' and 'Reshape' definitions.  The new
-- 'Scratch' takes the element type and dimensions of the copied array's
-- own type and carries no certificates.  Fails (returns 'Nothing') if
-- the source's type is unknown or the source is not scratch-derived.
copyScratchToScratch :: SimpleRule lore
copyScratchToScratch defOf seType (Copy src) = do
  t <- seType $ Var src
  guard $ isActuallyScratch src
  Just (Scratch (elemType t) (arrayDims t), mempty)
  where
    -- Follow the definition of a variable through rearranges and
    -- reshapes, succeeding if it bottoms out in a 'Scratch'.
    isActuallyScratch :: VName -> Bool
    isActuallyScratch v =
      case asBasicOp . fst =<< defOf v of
        Just Scratch {} -> True
        Just (Rearrange _ v') -> isActuallyScratch v'
        Just (Reshape _ v') -> isActuallyScratch v'
        _ -> False
copyScratchToScratch _ _ _ =
  Nothing

-- | All simple rules, in the order in which 'applySimpleRules' tries
-- them.  Order matters: the first rule that returns 'Just' wins.
simpleRules :: [SimpleRule lore]
simpleRules =
  [ simplifyBinOp,
    simplifyCmpOp,
    simplifyUnOp,
    simplifyConvOp,
    simplifyAssert,
    copyScratchToScratch,
    simplifyIdentityReshape,
    simplifyReshapeReshape,
    simplifyReshapeScratch,
    simplifyReshapeReplicate,
    simplifyReshapeIota,
    simplifyReshapeIndex,
    simplifyUpdateReshape,
    improveReshape
  ]

-- | Try to simplify the given 'BasicOp', returning a new 'BasicOp'
-- and certificates that it must depend on.
--
-- The rules in @simpleRules@ are tried in list order, and the result of
-- the first one returning 'Just' is used ('msum' over 'Maybe' is
-- left-biased).  'Nothing' means that no rule applied.
{-# NOINLINE applySimpleRules #-}
applySimpleRules ::
  VarLookup lore ->
  TypeLookup ->
  BasicOp ->
  Maybe (BasicOp, Certificates)
applySimpleRules defOf seType op =
  msum [rule defOf seType op | rule <- simpleRules]