{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Converting back and forth between 'PrimExp's.  Use the 'ToExp'
-- instance to convert to Futhark expressions.
module Futhark.Analysis.PrimExp.Convert
  ( primExpFromExp,
    primExpFromSubExp,
    pe32,
    le32,
    pe64,
    le64,
    f32pe,
    f32le,
    f64pe,
    f64le,
    primExpFromSubExpM,
    replaceInPrimExp,
    replaceInPrimExpM,
    substituteInPrimExp,
    primExpSlice,
    subExpSlice,

    -- * Module reexport
    module Futhark.Analysis.PrimExp,
  )
where

import Control.Monad.Fail qualified as Fail
import Control.Monad.Identity
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Analysis.PrimExp
import Futhark.Construct
import Futhark.IR

-- | Converting a 'PrimExp' into an expression first binds every operand
-- with 'toSubExp' (passing a descriptive string such as @\"binop_x\"@,
-- presumably used as a name hint) and then wraps the operator in a
-- 'BasicOp'.
instance ToExp v => ToExp (PrimExp v) where
  toExp (BinOpExp op x y) =
    BasicOp
      <$> (BinOp op <$> toSubExp "binop_x" x <*> toSubExp "binop_y" y)
  toExp (CmpOpExp op x y) =
    BasicOp
      <$> (CmpOp op <$> toSubExp "cmpop_x" x <*> toSubExp "cmpop_y" y)
  toExp (UnOpExp op x) =
    BasicOp <$> (UnOp op <$> toSubExp "unop_x" x)
  toExp (ConvOpExp op x) =
    BasicOp <$> (ConvOp op <$> toSubExp "convop_x" x)
  toExp (ValueExp v) =
    pure $ BasicOp $ SubExp $ Constant v
  toExp (FunExp h args t) =
    -- A function call becomes an 'Apply' of the function named @h@.
    -- Every argument is passed as 'Observe', the single return type is
    -- the primitive type recorded in the 'FunExp', and the call is
    -- marked 'Safe' with no source locations attached.
    Apply (nameFromString h)
      <$> args'
      <*> pure [primRetType t]
      <*> pure (Safe, mempty, [])
    where
      args' = map (\se -> (se, Observe)) <$> mapM (toSubExp "apply_arg") args
  toExp (LeafExp v _) =
    -- Leaves defer to the leaf type's own 'ToExp' instance; the
    -- recorded 'PrimType' is not consulted.
    toExp v

-- | A typed 'TPrimExp' is converted by dropping its type index with
-- 'untyped' and converting the underlying 'PrimExp'.
instance ToExp v => ToExp (TPrimExp t v) where
  toExp = toExp . untyped

-- | Convert an expression to a 'PrimExp'.
--
-- The provided function is used only to convert variable names that
-- occur as operands.  Constant operands become 'ValueExp's directly,
-- without consulting it.
--
-- The recognised forms are:
--
-- * the scalar 'BasicOp's 'BinOp', 'CmpOp', 'UnOp', 'ConvOp' and 'SubExp';
--
-- * an 'Apply' of a function for which 'isBuiltInFunction' holds and
--   whose return types are exactly one primitive type.  The argument
--   'Diet's and the call's safety information are discarded.
--
-- Any other expression fails in @m@ with the message @\"Not a PrimExp\"@.
primExpFromExp ::
  (Fail.MonadFail m, RepTypes rep) =>
  (VName -> m (PrimExp v)) ->
  Exp rep ->
  m (PrimExp v)
primExpFromExp f (BasicOp (BinOp op x y)) =
  BinOpExp op <$> primExpFromSubExpM f x <*> primExpFromSubExpM f y
primExpFromExp f (BasicOp (CmpOp op x y)) =
  CmpOpExp op <$> primExpFromSubExpM f x <*> primExpFromSubExpM f y
primExpFromExp f (BasicOp (UnOp op x)) =
  UnOpExp op <$> primExpFromSubExpM f x
primExpFromExp f (BasicOp (ConvOp op x)) =
  ConvOpExp op <$> primExpFromSubExpM f x
primExpFromExp f (BasicOp (SubExp se)) =
  primExpFromSubExpM f se
primExpFromExp f (Apply fname args ts _)
  | isBuiltInFunction fname,
    [Prim t] <- map declExtTypeOf ts =
      FunExp (nameToString fname)
        <$> mapM (primExpFromSubExpM f . fst) args
        <*> pure t
primExpFromExp _ _ = fail "Not a PrimExp"

-- | Like 'primExpFromExp', but for a t'SubExp'.  A variable is converted
-- with the provided function.  A constant becomes a 'ValueExp' without
-- consulting it.
primExpFromSubExpM :: Applicative m => (VName -> m (PrimExp v)) -> SubExp -> m (PrimExp v)
primExpFromSubExpM f (Var v) = f v
primExpFromSubExpM _ (Constant v) = pure $ ValueExp v

-- | Convert t'SubExp's of a given type.  A variable becomes a 'LeafExp'
-- annotated with that type.  A constant becomes a 'ValueExp', and the
-- supplied type is not used.
primExpFromSubExp :: PrimType -> SubExp -> PrimExp VName
primExpFromSubExp t (Var v) = LeafExp v t
primExpFromSubExp _ (Constant v) = ValueExp v

-- | Shorthand for constructing a 'TPrimExp' of type v'Int32' from a
-- t'SubExp' (see 'primExpFromSubExp').
pe32 :: SubExp -> TPrimExp Int32 VName
pe32 = isInt32 . primExpFromSubExp int32

-- | Shorthand for constructing a 'TPrimExp' of type v'Int32', from a leaf.
le32 :: a -> TPrimExp Int32 a
le32 = isInt32 . flip LeafExp int32

-- | Shorthand for constructing a 'TPrimExp' of type v'Int64' from a
-- t'SubExp' (see 'primExpFromSubExp').
pe64 :: SubExp -> TPrimExp Int64 VName
pe64 = isInt64 . primExpFromSubExp int64

-- | Shorthand for constructing a 'TPrimExp' of type v'Int64', from a leaf.
le64 :: a -> TPrimExp Int64 a
le64 = isInt64 . flip LeafExp int64

-- | Shorthand for constructing a 'TPrimExp' of type v'Float32' from a
-- t'SubExp' (see 'primExpFromSubExp').
f32pe :: SubExp -> TPrimExp Float VName
f32pe = isF32 . primExpFromSubExp float32

-- | Shorthand for constructing a 'TPrimExp' of type v'Float32', from a leaf.
f32le :: a -> TPrimExp Float a
f32le = isF32 . flip LeafExp float32

-- | Shorthand for constructing a 'TPrimExp' of type v'Float64' from a
-- t'SubExp' (see 'primExpFromSubExp').
f64pe :: SubExp -> TPrimExp Double VName
f64pe = isF64 . primExpFromSubExp float64

-- | Shorthand for constructing a 'TPrimExp' of type v'Float64', from a leaf.
f64le :: a -> TPrimExp Double a
f64le = isF64 . flip LeafExp float64

-- | Apply a monadic transformation to the leaves of a 'PrimExp'.
--
-- The function receives each leaf together with its 'PrimType' and
-- returns the 'PrimExp' that replaces it.  The returned expression is
-- not traversed again.  'ValueExp' nodes are kept unchanged, and the
-- operator structure is rebuilt around the replaced leaves.
--
-- Only rebuilt 'BinOpExp' nodes are additionally simplified with
-- 'constFoldPrimExp'.  The other operator nodes are rebuilt as-is.
-- Effects run left to right over the leaves.
replaceInPrimExpM ::
  Monad m =>
  (a -> PrimType -> m (PrimExp b)) ->
  PrimExp a ->
  m (PrimExp b)
replaceInPrimExpM f (LeafExp v pt) =
  f v pt
replaceInPrimExpM _ (ValueExp v) =
  pure $ ValueExp v
replaceInPrimExpM f (BinOpExp bop pe1 pe2) =
  constFoldPrimExp
    <$> (BinOpExp bop <$> replaceInPrimExpM f pe1 <*> replaceInPrimExpM f pe2)
replaceInPrimExpM f (CmpOpExp cop pe1 pe2) =
  CmpOpExp cop <$> replaceInPrimExpM f pe1 <*> replaceInPrimExpM f pe2
replaceInPrimExpM f (UnOpExp uop pe) =
  UnOpExp uop <$> replaceInPrimExpM f pe
replaceInPrimExpM f (ConvOpExp cop pe) =
  ConvOpExp cop <$> replaceInPrimExpM f pe
replaceInPrimExpM f (FunExp h args t) =
  FunExp h <$> mapM (replaceInPrimExpM f) args <*> pure t

-- | As 'replaceInPrimExpM', but in the identity monad: the leaf
-- transformation is a pure function.
replaceInPrimExp ::
  (a -> PrimType -> PrimExp b) ->
  PrimExp a ->
  PrimExp b
replaceInPrimExp f e = runIdentity $ replaceInPrimExpM f' e
  where
    f' x t = pure $ f x t

-- | Substitute names in a 'PrimExp' with other 'PrimExp's.
--
-- Every 'LeafExp' whose name is a key of the map is replaced by the
-- associated expression.  Leaves with unmapped names are kept, along
-- with their type.  Substituted expressions are not themselves
-- substituted again.  Because this goes through 'replaceInPrimExp',
-- rebuilt 'BinOpExp' nodes are constant-folded.
substituteInPrimExp ::
  Ord v =>
  M.Map v (PrimExp v) ->
  PrimExp v ->
  PrimExp v
substituteInPrimExp tab = replaceInPrimExp $ \v t ->
  M.findWithDefault (LeafExp v t) v tab

-- | Convert a t'SubExp' slice to a 'PrimExp' slice.  Every t'SubExp' in
-- the slice is treated as a v'Int64' via 'pe64'.
primExpSlice :: Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice = fmap pe64

-- | Convert a 'PrimExp' slice to a t'SubExp' slice.  Each expression in
-- the slice is bound with 'toSubExp', passing the descriptive string
-- @\"slice\"@.
subExpSlice :: MonadBuilder m => Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice = traverse $ toSubExp "slice"