{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Converting back and forth between 'PrimExp's and Futhark
-- expressions.  Use the 'ToExp' instance to convert a 'PrimExp' to a
-- Futhark expression, and 'primExpFromExp' for the other direction.
module Futhark.Analysis.PrimExp.Convert
  ( primExpFromExp,
    primExpFromSubExp,
    pe32,
    le32,
    pe64,
    le64,
    f32pe,
    f32le,
    f64pe,
    f64le,
    primExpFromSubExpM,
    replaceInPrimExp,
    replaceInPrimExpM,
    substituteInPrimExp,
    primExpSlice,
    subExpSlice,

    -- * Module reexport
    module Futhark.Analysis.PrimExp,
  )
where

import qualified Control.Monad.Fail as Fail
import Control.Monad.Identity
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Analysis.PrimExp
import Futhark.Construct
import Futhark.IR

-- | Lower a 'PrimExp' to a Futhark expression.  The operands of every
-- operator, and the arguments of every function call, are first
-- converted to t'SubExp's with 'toSubExp' (left operand first).  The
-- converted operands are then wrapped in the corresponding 'BasicOp' or
-- 'Apply'.  Leaves are lowered through the 'ToExp' instance of the leaf
-- type.
instance ToExp v => ToExp (PrimExp v) where
  toExp (BinOpExp op x y) = do
    x_se <- toSubExp "binop_x" x
    y_se <- toSubExp "binop_y" y
    pure $ BasicOp $ BinOp op x_se y_se
  toExp (CmpOpExp op x y) = do
    x_se <- toSubExp "cmpop_x" x
    y_se <- toSubExp "cmpop_y" y
    pure $ BasicOp $ CmpOp op x_se y_se
  toExp (UnOpExp op x) = do
    x_se <- toSubExp "unop_x" x
    pure $ BasicOp $ UnOp op x_se
  toExp (ConvOpExp op x) = do
    x_se <- toSubExp "convop_x" x
    pure $ BasicOp $ ConvOp op x_se
  toExp (ValueExp v) =
    pure $ BasicOp $ SubExp $ Constant v
  toExp (FunExp h args t) = do
    arg_ses <- mapM (toSubExp "apply_arg") args
    -- Every argument is passed with the 'Observe' diet.
    let args_with_diet = zip arg_ses (repeat Observe)
    pure $
      Apply (nameFromString h) args_with_diet [primRetType t] (Safe, mempty, [])
  toExp (LeafExp v _) =
    toExp v

-- | A 'TPrimExp' is lowered exactly like the 'PrimExp' it wraps.  The
-- type index @t@ is discarded by 'untyped'.
instance ToExp v => ToExp (TPrimExp t v) where
  toExp e = toExp (untyped e)

-- | Convert an expression to a 'PrimExp'.
--
-- Only the following expressions are recognised:
--
-- * the scalar 'BasicOp's 'BinOp', 'CmpOp', 'UnOp' and 'ConvOp';
-- * a plain 'SubExp';
-- * a call to a built-in function whose return type is a single
--   primitive type.
--
-- Any other expression makes the conversion 'fail' with
-- @"Not a PrimExp"@.
--
-- The provided function converts the variable names that occur as
-- operands (see 'primExpFromSubExpM').  Constant operands become
-- 'ValueExp's directly and are never passed to the function.
primExpFromExp ::
  (Fail.MonadFail m, Decorations lore) =>
  (VName -> m (PrimExp v)) ->
  Exp lore ->
  m (PrimExp v)
primExpFromExp f (BasicOp (BinOp op x y)) =
  BinOpExp op <$> primExpFromSubExpM f x <*> primExpFromSubExpM f y
primExpFromExp f (BasicOp (CmpOp op x y)) =
  CmpOpExp op <$> primExpFromSubExpM f x <*> primExpFromSubExpM f y
primExpFromExp f (BasicOp (UnOp op x)) =
  UnOpExp op <$> primExpFromSubExpM f x
primExpFromExp f (BasicOp (ConvOp op x)) =
  ConvOpExp op <$> primExpFromSubExpM f x
primExpFromExp f (BasicOp (SubExp se)) =
  primExpFromSubExpM f se
primExpFromExp f (Apply fname args ts _)
  -- Only built-in functions with exactly one primitive return type
  -- can be represented as a 'FunExp'; diets of arguments are dropped.
  | isBuiltInFunction fname,
    [Prim t] <- map declExtTypeOf ts =
    FunExp (nameToString fname)
      <$> mapM (primExpFromSubExpM f . fst) args
      <*> pure t
primExpFromExp _ _ = fail "Not a PrimExp"

-- | Like 'primExpFromExp', but for a t'SubExp'.  A variable is handed
-- to the given function; a constant is returned as a 'ValueExp' without
-- consulting it.
primExpFromSubExpM :: Applicative m => (VName -> m (PrimExp v)) -> SubExp -> m (PrimExp v)
primExpFromSubExpM onVar se =
  case se of
    Var name -> onVar name
    Constant val -> pure (ValueExp val)

-- | Convert a t'SubExp' whose type is the given 'PrimType'.  A variable
-- becomes a 'LeafExp' annotated with that type.  A constant becomes a
-- 'ValueExp', and the type argument is then unused.
primExpFromSubExp :: PrimType -> SubExp -> PrimExp VName
primExpFromSubExp t se =
  case se of
    Var name -> LeafExp name t
    Constant val -> ValueExp val

-- | Shorthand for constructing a 'TPrimExp' of type 'Int32' from a
-- t'SubExp', which is interpreted at type 'int32'.
pe32 :: SubExp -> TPrimExp Int32 VName
pe32 se = isInt32 (primExpFromSubExp int32 se)

-- | Shorthand for constructing a 'TPrimExp' of type 'Int32' from a
-- leaf, which is wrapped in a 'LeafExp' of type 'int32'.
le32 :: a -> TPrimExp Int32 a
le32 leaf = isInt32 (LeafExp leaf int32)

-- | Shorthand for constructing a 'TPrimExp' of type 'Int64' from a
-- t'SubExp', which is interpreted at type 'int64'.
pe64 :: SubExp -> TPrimExp Int64 VName
pe64 se = isInt64 (primExpFromSubExp int64 se)

-- | Shorthand for constructing a 'TPrimExp' of type 'Int64' from a
-- leaf, which is wrapped in a 'LeafExp' of type 'int64'.
le64 :: a -> TPrimExp Int64 a
le64 leaf = isInt64 (LeafExp leaf int64)

-- | Shorthand for constructing a 'TPrimExp' of type 'Float32' (indexed
-- by 'Float') from a t'SubExp', which is interpreted at type 'float32'.
f32pe :: SubExp -> TPrimExp Float VName
f32pe se = isF32 (primExpFromSubExp float32 se)

-- | Shorthand for constructing a 'TPrimExp' of type 'Float32' (indexed
-- by 'Float') from a leaf, which is wrapped in a 'LeafExp' of type
-- 'float32'.
f32le :: a -> TPrimExp Float a
f32le leaf = isF32 (LeafExp leaf float32)

-- | Shorthand for constructing a 'TPrimExp' of type 'Float64' (indexed
-- by 'Double') from a t'SubExp', which is interpreted at type
-- 'float64'.
f64pe :: SubExp -> TPrimExp Double VName
f64pe se = isF64 (primExpFromSubExp float64 se)

-- | Shorthand for constructing a 'TPrimExp' of type 'Float64' (indexed
-- by 'Double') from a leaf, which is wrapped in a 'LeafExp' of type
-- 'float64'.
f64le :: a -> TPrimExp Double a
f64le leaf = isF64 (LeafExp leaf float64)

-- | Apply a monadic transformation to the leaves of a 'PrimExp'.
--
-- Each 'LeafExp' is replaced by the result of calling the function on
-- the leaf and its 'PrimType'.  Constants and all operator nodes are
-- rebuilt with the same shape, and children are processed left to
-- right.  Only rebuilt 'BinOpExp' nodes are additionally passed through
-- 'constFoldPrimExp'.
replaceInPrimExpM ::
  Monad m =>
  (a -> PrimType -> m (PrimExp b)) ->
  PrimExp a ->
  m (PrimExp b)
replaceInPrimExpM f = go
  where
    go (LeafExp leaf pt) = f leaf pt
    go (ValueExp val) = pure (ValueExp val)
    go (BinOpExp bop lhs rhs) = do
      lhs' <- go lhs
      rhs' <- go rhs
      pure $ constFoldPrimExp $ BinOpExp bop lhs' rhs'
    go (CmpOpExp cop lhs rhs) = do
      lhs' <- go lhs
      rhs' <- go rhs
      pure $ CmpOpExp cop lhs' rhs'
    go (UnOpExp uop operand) = UnOpExp uop <$> go operand
    go (ConvOpExp cop operand) = ConvOpExp cop <$> go operand
    go (FunExp h args t) = do
      args' <- mapM go args
      pure $ FunExp h args' t

-- | As 'replaceInPrimExpM', but with a pure leaf transformation.  The
-- transformation is run in the identity monad.
replaceInPrimExp ::
  (a -> PrimType -> PrimExp b) ->
  PrimExp a ->
  PrimExp b
replaceInPrimExp f =
  runIdentity . replaceInPrimExpM (\leaf pt -> Identity (f leaf pt))

-- | Substitute names in a 'PrimExp' with other 'PrimExp's, as given by
-- the map.  A leaf whose name is not a key of the map is kept as the
-- original 'LeafExp' (with its 'PrimType').
substituteInPrimExp ::
  Ord v =>
  M.Map v (PrimExp v) ->
  PrimExp v ->
  PrimExp v
substituteInPrimExp tab = replaceInPrimExp substLeaf
  where
    substLeaf name t =
      let unchanged = LeafExp name t
       in fromMaybe unchanged (M.lookup name tab)

-- | Convert a t'SubExp' slice to a 'PrimExp' slice.  Every index in
-- every dimension is converted with 'pe64', i.e. treated as 'Int64'.
primExpSlice :: Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice slice = [fmap pe64 dim | dim <- slice]

-- | Convert a 'PrimExp' slice to a t'SubExp' slice.  Dimensions are
-- processed in order, and each index expression is converted with
-- 'toSubExp' under the name hint @"slice"@.
subExpSlice :: MonadBinder m => Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice slice = traverse (traverse (toSubExp "slice")) slice