{-# LANGUAGE FlexibleContexts #-}
-- | A legacy representation of scalar expressions used solely for
-- algebraic simplification.  Never use this.  Use
-- "Futhark.Analysis.PrimExp" instead.
module Futhark.Analysis.ScalExp
  ( RelOp0(..)
  , ScalExp(..)
  , scalExpType
  , scalExpSize
  , subExpToScalExp
  , toScalExp
  , expandScalExp
  , LookupVar
  , module Futhark.IR.Primitive
  )
where

import Data.List (find)
import Data.Maybe

import Futhark.IR.Primitive hiding (SQuot, SRem, SDiv, SMod, SSignum)
import Futhark.IR hiding (SQuot, SRem, SDiv, SMod, SSignum)
import qualified Futhark.IR as AST
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.Util.Pretty hiding (pretty)

-----------------------------------------------------------------
-- BINARY OPERATORS for Numbers                                --
-- Note that MOD, BAND, XOR, BOR, SHIFTR, SHIFTL not supported --
--   `a SHIFTL/SHIFTR p' can be translated if desired as       --
--   `a * 2^p' or `a / 2^p'                                    --
-----------------------------------------------------------------

-- | Relational operators.  A 'RelExp' always compares its operand
-- against zero, so only the two comparison directions are needed.
data RelOp0
  = LTH0 -- ^ Strictly less than zero.
  | LEQ0 -- ^ Less than or equal to zero.
  deriving (Eq, Ord, Enum, Bounded, Show)

-- | Representation of a scalar expression, which is:
--
--    (i) an algebraic expression, e.g., min(a+b, a*b),
--
--   (ii) a relational expression: a+b < 5,
--
--  (iii) a logical expression: e1 and (not (a+b>5)).
data ScalExp
  = Val PrimValue
    -- ^ A constant.
  | Id VName PrimType
    -- ^ A variable, annotated with its primitive type.
  | SNeg ScalExp
    -- ^ Arithmetic negation (printed as @-e@).
  | SNot ScalExp
    -- ^ Logical negation.
  | SAbs ScalExp
    -- ^ Absolute value.
  | SSignum ScalExp
    -- ^ Sign of the operand.
  | SPlus ScalExp ScalExp
    -- ^ Addition.
  | SMinus ScalExp ScalExp
    -- ^ Subtraction.
  | STimes ScalExp ScalExp
    -- ^ Multiplication.
  | SPow ScalExp ScalExp
    -- ^ Exponentiation (printed as @^@).
  | SDiv ScalExp ScalExp
    -- ^ Division (printed as @/@).
  | SMod ScalExp ScalExp
    -- ^ Modulo (printed as @%@).
  | SQuot ScalExp ScalExp
    -- ^ Quotient (printed as @//@).
  | SRem ScalExp ScalExp
    -- ^ Remainder (printed as @%%@).
  | MaxMin Bool [ScalExp]
    -- ^ @MaxMin True es@ is the minimum of @es@; @MaxMin False es@
    -- is the maximum.
  | RelExp RelOp0 ScalExp
    -- ^ @RelExp op e@ compares @e@ against zero: @e < 0@ or @e <= 0@.
  | SLogAnd ScalExp ScalExp
    -- ^ Logical conjunction.
  | SLogOr ScalExp ScalExp
    -- ^ Logical disjunction.
  deriving (Eq, Ord, Show)

-- | Arithmetic on 'ScalExp' builds expression trees while folding away
-- the trivial identities @0 + y@, @x + 0@, @x - 0@, @0 * y@, @x * 0@,
-- @1 * y@ and @x * 1@.  The checks use structural equality against
-- 'fromInteger' literals, so only an 'Int32' zero or one is recognised.
instance Num ScalExp where
  x + y
    | x == 0 = y
    | y == 0 = x
    | otherwise = SPlus x y

  x - y
    | y == 0 = x
    | otherwise = SMinus x y

  x * y
    | x == 0 || y == 0 = 0
    | x == 1 = y
    | y == 1 = x
    | otherwise = STimes x y

  negate = SNeg
  abs = SAbs
  signum = SSignum

  -- Integer literals always become 'Int32' constants (wrapping values
  -- outside that range), regardless of the type of the surrounding
  -- expression.  This is probably not OK.
  fromInteger n = Val $ IntValue $ Int32Value $ fromInteger n

-- | Pretty-print a 'ScalExp' in infix notation.  Binary operator
-- precedences, loosest first: @||@ (0), @&&@ (1), comparisons against
-- zero (2), @+@ and @-@ (4), @*@, @/@, @%@, @//@ and @%%@ (5), @^@ (6).
-- The right operand of @-@, @/@, @%@, @//@ and @%%@ is printed at
-- precedence 10, so any compound right operand is parenthesised.
-- Unary operators print their operand at precedence 9.
instance Pretty ScalExp where
  pprPrec prec scalExp =
    case scalExp of
      Val val -> ppr val
      Id v _ -> ppr v
      SNeg e -> text "-" <> pprPrec 9 e
      SNot e -> prefixOp "not" e
      SAbs e -> prefixOp "abs" e
      SSignum e -> prefixOp "signum" e
      SPlus x y -> infixOp "+" 4 4 x y
      SMinus x y -> infixOp "-" 4 10 x y
      SPow x y -> infixOp "^" 6 6 x y
      STimes x y -> infixOp "*" 5 5 x y
      SDiv x y -> infixOp "/" 5 10 x y
      SMod x y -> infixOp "%" 5 10 x y
      SQuot x y -> infixOp "//" 5 10 x y
      SRem x y -> infixOp "%%" 5 10 x y
      SLogOr x y -> infixOp "||" 0 0 x y
      SLogAnd x y -> infixOp "&&" 1 1 x y
      -- A relational expression is always rendered as a comparison
      -- of its operand against the literal zero.
      RelExp LTH0 e -> ppBinOp prec "<" 2 2 e (0 :: Int)
      RelExp LEQ0 e -> ppBinOp prec "<=" 2 2 e (0 :: Int)
      MaxMin True es -> call "min" es
      MaxMin False es -> call "max" es
    where
      infixOp :: String -> Int -> Int -> ScalExp -> ScalExp -> Doc
      infixOp = ppBinOp prec

      prefixOp :: String -> ScalExp -> Doc
      prefixOp name e = text name <+> pprPrec 9 e

      call :: String -> [ScalExp] -> Doc
      call name es = text name <> parens (commasep $ map ppr es)

-- | @ppBinOp ctx op opPrec rhsPrec x y@ renders @x op y@.  The left
-- operand is printed at precedence @opPrec@ and the right one at
-- @rhsPrec@; the whole application is parenthesised when the
-- surrounding precedence @ctx@ is greater than @opPrec@.
ppBinOp :: (Pretty a, Pretty b) => Int -> String -> Int -> Int -> a -> b -> Doc
ppBinOp ctx op opPrec rhsPrec lhs rhs =
  parensIf needsParens $ lhsDoc <+/> text op <+> rhsDoc
  where
    needsParens = ctx > opPrec
    lhsDoc = pprPrec opPrec lhs
    rhsDoc = pprPrec rhsPrec rhs

-- | Apply a variable renaming to every 'Id' in the expression.  Every
-- other node is rebuilt with the same constructor, so substitution
-- never changes which operators the expression uses.
instance Substitute ScalExp where
  substituteNames subst e =
    case e of
      Id v t -> Id (substituteNames subst v) t
      Val v -> Val v
      SNeg x -> SNeg $ sub x
      SNot x -> SNot $ sub x
      SAbs x -> SAbs $ sub x
      SSignum x -> SSignum $ sub x
      SPlus x y -> sub x `SPlus` sub y
      SMinus x y -> sub x `SMinus` sub y
      SPow x y -> sub x `SPow` sub y
      STimes x y -> sub x `STimes` sub y
      SDiv x y -> sub x `SDiv` sub y
      SMod x y -> sub x `SMod` sub y
      -- Must be rebuilt as 'SQuot': this case used to produce 'SDiv',
      -- silently replacing the operator during substitution.
      SQuot x y -> sub x `SQuot` sub y
      SRem x y -> sub x `SRem` sub y
      MaxMin m es -> MaxMin m $ map sub es
      RelExp r x -> RelExp r $ sub x
      SLogAnd x y -> sub x `SLogAnd` sub y
      SLogOr x y -> sub x `SLogOr` sub y
    where
      -- Recurse into a subexpression with the same substitution.
      sub :: ScalExp -> ScalExp
      sub = substituteNames subst

-- | 'ScalExp' contains no binding constructs, so renaming is delegated
-- to 'substituteRename', which works through the 'Substitute' instance.
instance Rename ScalExp where
  rename = substituteRename

-- | The type of a scalar expression.  Logical and relational forms are
-- 'Bool'; every other compound form takes the type of its first
-- operand.  An empty 'MaxMin' is given type 'Int32' (arbitrary and
-- probably wrong).
scalExpType :: ScalExp -> PrimType
scalExpType se =
  case se of
    Val v -> primValueType v
    Id _ t -> t
    -- Boolean-valued forms.
    SNot{} -> Bool
    SLogAnd{} -> Bool
    SLogOr{} -> Bool
    RelExp{} -> Bool
    -- Forms whose type follows their (first) operand.
    SNeg e -> scalExpType e
    SAbs e -> scalExpType e
    SSignum e -> scalExpType e
    SPlus e _ -> scalExpType e
    SMinus e _ -> scalExpType e
    STimes e _ -> scalExpType e
    SDiv e _ -> scalExpType e
    SMod e _ -> scalExpType e
    SPow e _ -> scalExpType e
    SQuot e _ -> scalExpType e
    SRem e _ -> scalExpType e
    MaxMin _ (e : _) -> scalExpType e
    MaxMin _ [] -> IntType Int32

-- | Size of the scalar expression, measured as the number of leaves
-- ('Val' and 'Id' nodes).  Operator nodes themselves add nothing, and
-- an empty 'MaxMin' has size 0.
scalExpSize :: ScalExp -> Int
scalExpSize se =
  case se of
    Val{} -> 1
    Id{} -> 1
    SNeg e -> scalExpSize e
    SNot e -> scalExpSize e
    SAbs e -> scalExpSize e
    SSignum e -> scalExpSize e
    RelExp _ e -> scalExpSize e
    SPlus x y -> both x y
    SMinus x y -> both x y
    STimes x y -> both x y
    SDiv x y -> both x y
    SMod x y -> both x y
    SPow x y -> both x y
    SQuot x y -> both x y
    SRem x y -> both x y
    SLogAnd x y -> both x y
    SLogOr x y -> both x y
    MaxMin _ es -> sum (map scalExpSize es)
  where
    both x y = scalExpSize x + scalExpSize y

-- | A lookup function that returns the scalar expression already known
-- for a variable name, or 'Nothing' if there is none.  'toScalExp'
-- consults it before falling back to representing the variable as an
-- 'Id'.
type LookupVar = VName -> Maybe ScalExp

-- | Non-recursively convert a subexpression to a 'ScalExp'.  The
-- (scalar) type of the subexpression must be given in advance; it is
-- only used for variables, since a constant carries its own type.
subExpToScalExp :: SubExp -> PrimType -> ScalExp
subExpToScalExp se t =
  case se of
    Var v -> Id v t
    Constant val -> Val val

-- | Recursively convert an expression to a scalar expression.
--
-- Recognised forms are:
--
-- * variables, via the 'LookupVar' function, or else as an 'Id' when
--   their type is accepted by 'typeIsOK';
-- * constants whose type is accepted by 'typeIsOK';
-- * 'CmpSlt' / 'CmpSle' comparisons;
-- * 'CmpEq' on a type accepted by 'typeIsOK';
-- * integer subtraction from a zero constant (turned into 'SNeg');
-- * logical 'AST.Not';
-- * any binary operator that 'binOpScalExp' translates.
--
-- Anything else yields 'Nothing'.
toScalExp :: (HasScope t f, Monad f) =>
             LookupVar -> Exp lore -> f (Maybe ScalExp)
toScalExp look expr =
  case expr of
    BasicOp (SubExp (Var v))
      | Just se <- look v -> pure $ Just se
      | otherwise -> fromVar v <$> lookupType v
    BasicOp (SubExp (Constant val))
      | typeIsOK $ primValueType val -> pure $ Just $ Val val
    BasicOp (CmpOp (CmpSlt _) x y) ->
      Just . RelExp LTH0 <$> combine sminus x y
    BasicOp (CmpOp (CmpSle _) x y) ->
      Just . RelExp LEQ0 <$> combine sminus x y
    BasicOp (CmpOp (CmpEq t) x y)
      | typeIsOK t -> Just <$> combine (equality t) x y
    -- If the guard fails, matching falls through to the generic
    -- 'BinOp' alternative below.
    BasicOp (BinOp (Sub t _) (Constant x) y)
      | typeIsOK $ IntType t, zeroIsh x -> Just . SNeg <$> operand y
    BasicOp (UnOp AST.Not e) ->
      Just . SNot <$> operand e
    BasicOp (BinOp bop x y)
      | Just op <- binOpScalExp bop -> Just <$> combine op x y
    _ -> pure Nothing
  where
    -- Convert an operand, consulting the lookup function first.
    operand se = subExpToScalExp' look se

    -- Convert both operands (left first) and combine them.
    combine op x y = op <$> operand x <*> operand y

    -- A variable with no known expression becomes an 'Id' if its
    -- type is supported.
    fromVar v (Prim bt) | typeIsOK bt = Just $ Id v bt
    fromVar _ _ = Nothing

    -- Equality encoded with the available connectives: for booleans,
    -- both true or both false; otherwise @x - y <= 0 && y - x <= 0@.
    equality :: PrimType -> ScalExp -> ScalExp -> ScalExp
    equality Bool x' y' =
      SLogAnd x' y' `SLogOr` SLogAnd (SNot x') (SNot y')
    equality _ x' y' =
      RelExp LEQ0 (x' `sminus` y') `SLogAnd` RelExp LEQ0 (y' `sminus` x')

-- | Whether a primitive type can be represented in a 'ScalExp': 'Bool'
-- or one of the integer types listed in 'allIntTypes'.
typeIsOK :: PrimType -> Bool
typeIsOK pt =
  case pt of
    Bool -> True
    IntType it -> it `elem` allIntTypes
    _ -> False

-- | Convert a 'SubExp' to a 'ScalExp'.
--
-- Constants become 'Val' nodes.  A variable is first resolved through
-- the given 'LookupVar'.  If that yields nothing, the variable's type
-- is fetched from the scope and the variable is converted with
-- 'subExpToScalExp'.
--
-- Calls 'error' when an unresolved variable has a non-primitive type.
subExpToScalExp' :: HasScope t f =>
                    LookupVar -> SubExp -> f ScalExp
subExpToScalExp' _ (Constant val) =
  pure (Val val)
subExpToScalExp' look (Var v) =
  -- Prefer the caller-supplied expansion; fall back to the scope.
  maybe (fromType <$> lookupType v) pure (look v)
  where fromType :: Type -> ScalExp
        fromType (Prim t) = subExpToScalExp (Var v) t
        fromType t =
          error $ "Cannot create ScalExp from variable " ++ pretty v ++
          " of type " ++ pretty t

-- | If you have a scalar expression that has been created with
-- incomplete symbol table information, you can use this function to
-- grow its 'Id' leaves.
--
-- Each 'Id' leaf for which the 'LookupVar' returns an expression is
-- replaced by that expression.  The replacement itself is not expanded
-- any further.  Every other node is rebuilt with the same constructor
-- around its expanded children.
expandScalExp :: LookupVar -> ScalExp -> ScalExp
expandScalExp look = go
  where
    go scal = case scal of
      Val v        -> Val v
      Id v t       -> fromMaybe (Id v t) (look v)
      SNeg se      -> SNeg (go se)
      SNot se      -> SNot (go se)
      SAbs se      -> SAbs (go se)
      SSignum se   -> SSignum (go se)
      MaxMin b ses -> MaxMin b (map go ses)
      SPlus x y    -> SPlus (go x) (go y)
      SMinus x y   -> SMinus (go x) (go y)
      STimes x y   -> STimes (go x) (go y)
      SDiv x y     -> SDiv (go x) (go y)
      SMod x y     -> SMod (go x) (go y)
      SQuot x y    -> SQuot (go x) (go y)
      SRem x y     -> SRem (go x) (go y)
      SPow x y     -> SPow (go x) (go y)
      SLogAnd x y  -> SLogAnd (go x) (go y)
      SLogOr x y   -> SLogOr (go x) (go y)
      RelExp op x  -> RelExp op (go x)

-- | "Smart constructor" for 'SMinus'.  If the subtrahend is a 'Val'
-- whose value satisfies 'zeroIsh', the first argument is returned
-- unchanged.  Otherwise an 'SMinus' node is built.
sminus :: ScalExp -> ScalExp -> ScalExp
sminus x y
  | Val v <- y, zeroIsh v = x
  | otherwise             = SMinus x y

 -- XXX: Only integers and booleans, OK?
-- | Find the 'ScalExp' builder corresponding to a binary operator.
--
-- Only two groups of operators are translated:
--
-- * for every type in 'allIntTypes': wrapping ('OverflowWrap') add,
--   sub and mul, plus signed division, power, signed max and signed min;
--
-- * logical 'LogAnd' and 'LogOr'.
--
-- Any other operator yields 'Nothing'.
binOpScalExp :: BinOp -> Maybe (ScalExp -> ScalExp -> ScalExp)
binOpScalExp bop = snd <$> find matches table
  where
    matches (op, _) = op == bop

    table = concatMap intOps allIntTypes ++
            [ (LogAnd, SLogAnd)
            , (LogOr, SLogOr)
            ]

    intOps t = [ (Add t OverflowWrap, SPlus)
               , (Sub t OverflowWrap, SMinus)
               , (Mul t OverflowWrap, STimes)
               , (AST.SDiv t, SDiv)
               , (AST.Pow t, SPow)
               , (AST.SMax t, smax)
               , (AST.SMin t, smin)
               ]

    -- 'AST.SMax' is encoded as @MaxMin False@ and 'AST.SMin' as
    -- @MaxMin True@.
    smax x y = MaxMin False [x, y]
    smin x y = MaxMin True [x, y]

-- | The free variables of a 'ScalExp' are exactly the names in its 'Id'
-- leaves.  Every other node contributes the union of its children's
-- free variables.
instance FreeIn ScalExp where
  freeIn' scal =
    case scal of
      Val _       -> mempty
      Id i _      -> fvName i
      SNeg e      -> freeIn' e
      SNot e      -> freeIn' e
      SAbs e      -> freeIn' e
      SSignum e   -> freeIn' e
      SPlus x y   -> both x y
      SMinus x y  -> both x y
      SPow x y    -> both x y
      STimes x y  -> both x y
      SDiv x y    -> both x y
      SMod x y    -> both x y
      SQuot x y   -> both x y
      SRem x y    -> both x y
      SLogOr x y  -> both x y
      SLogAnd x y -> both x y
      -- The relational operator carries no variables.
      RelExp _ e  -> freeIn' e
      MaxMin _ es -> freeIn' es
    where
      both x y = freeIn' x <> freeIn' y