{-# LANGUAGE FlexibleContexts #-}

-- | Defines simplification functions for 'PrimExp's.
module Futhark.Analysis.PrimExp.Simplify (simplifyPrimExp, simplifyExtPrimExp) where

import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.Optimise.Simplify.Engine as Engine

-- | Simplify a 'PrimExp', including copy propagation.  Each 'LeafExp'
-- name is run through the simplifier; if it simplifies to another
-- variable, the leaf is renamed, and if it simplifies to a 'Constant',
-- the leaf turns into a 'ValueExp'.  Interior nodes are rebuilt
-- unchanged around the simplified leaves.
simplifyPrimExp ::
  SimplifiableRep rep =>
  PrimExp VName ->
  SimpleM rep (PrimExp VName)
simplifyPrimExp = simplifyAnyPrimExp onLeaf
  where
    -- No local type signature: it would mention 'rep', which is only
    -- in scope here with ScopedTypeVariables and an explicit forall.
    onLeaf v pt = do
      se <- simplify $ Var v
      pure $ case se of
        Var v' -> LeafExp v' pt
        Constant pv -> ValueExp pv

-- | Like 'simplifyPrimExp', but where leaves may be 'Ext's.
--
-- A 'Free' leaf is simplified exactly as in 'simplifyPrimExp': it is
-- renamed if it simplifies to a variable (staying 'Free'), or replaced
-- by a 'ValueExp' if it simplifies to a 'Constant'.  An 'Ext' leaf is
-- kept as-is, since there is no name to look up.
simplifyExtPrimExp ::
  SimplifiableRep rep =>
  PrimExp (Ext VName) ->
  SimpleM rep (PrimExp (Ext VName))
simplifyExtPrimExp = simplifyAnyPrimExp onLeaf
  where
    onLeaf (Free v) pt = do
      se <- simplify $ Var v
      pure $ case se of
        Var v' -> LeafExp (Free v') pt
        Constant pv -> ValueExp pv
    onLeaf (Ext i) pt = pure $ LeafExp (Ext i) pt

-- | Generic traversal underlying 'simplifyPrimExp' and
-- 'simplifyExtPrimExp'.
--
-- The supplied function is applied to every 'LeafExp' (its name and
-- type), and its result replaces that leaf.  'ValueExp' nodes are
-- returned unchanged.  All other nodes ('BinOpExp', 'CmpOpExp',
-- 'UnOpExp', 'ConvOpExp', 'FunExp') are rebuilt with the same
-- operator, function name and result type around their recursively
-- processed children.  Children are visited left to right.  No
-- constant folding is performed here.
simplifyAnyPrimExp ::
  SimplifiableRep rep =>
  (a -> PrimType -> SimpleM rep (PrimExp a)) ->
  PrimExp a ->
  SimpleM rep (PrimExp a)
simplifyAnyPrimExp f (LeafExp v pt) = f v pt
simplifyAnyPrimExp _ (ValueExp pv) =
  pure $ ValueExp pv
simplifyAnyPrimExp f (BinOpExp bop e1 e2) =
  BinOpExp bop <$> simplifyAnyPrimExp f e1 <*> simplifyAnyPrimExp f e2
simplifyAnyPrimExp f (CmpOpExp cmp e1 e2) =
  CmpOpExp cmp <$> simplifyAnyPrimExp f e1 <*> simplifyAnyPrimExp f e2
simplifyAnyPrimExp f (UnOpExp op e) =
  UnOpExp op <$> simplifyAnyPrimExp f e
simplifyAnyPrimExp f (ConvOpExp conv e) =
  ConvOpExp conv <$> simplifyAnyPrimExp f e
simplifyAnyPrimExp f (FunExp h args t) =
  FunExp h <$> mapM (simplifyAnyPrimExp f) args <*> pure t