{-# LANGUAGE UndecidableInstances #-}

-- | This module contains facilities for replacing variable names in
-- syntactic constructs.
module Futhark.Transform.Substitute
  ( Substitutions,
    Substitute (..),
    Substitutable,
  )
where

import Control.Monad.Identity
import Data.Map.Strict qualified as M
import Futhark.Analysis.PrimExp
import Futhark.IR.Prop.Names
import Futhark.IR.Prop.Scope
import Futhark.IR.Syntax
import Futhark.IR.Traversals

-- | The substitutions to be made, given as a mapping from each name that
-- should be replaced to the name that replaces it.  Names that are not
-- keys of the map are left unchanged by 'substituteNames'.
type Substitutions = M.Map VName VName

-- | Types whose contained variable names can be rewritten according to a
-- 'Substitutions' map.
class Substitute a where
  -- | @substituteNames m e@ rewrites the variable names in @e@ to new
  -- names, as dictated by the mapping @m@.  All names in @e@ are assumed
  -- to be unique, i.e. there is no shadowing.
  substituteNames :: Substitutions -> a -> a

-- | Substitutes in every element of the list.
instance Substitute a => Substitute [a] where
  substituteNames substs xs = [substituteNames substs x | x <- xs]

-- | Substitutes in every statement of the sequence.
instance Substitute (Stm rep) => Substitute (Stms rep) where
  substituteNames substs stms = substituteNames substs <$> stms

-- | Substitutes in both components of the pair.
instance (Substitute a, Substitute b) => Substitute (a, b) where
  substituteNames substs (x, y) = (sub x, sub y)
    where
      -- Explicit signature so the helper stays polymorphic across components.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes in all three components of the triple.
instance (Substitute a, Substitute b, Substitute c) => Substitute (a, b, c) where
  substituteNames substs (x, y, z) = (sub x, sub y, sub z)
    where
      -- Explicit signature so the helper stays polymorphic across components.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes in all four components of the tuple.
instance (Substitute a, Substitute b, Substitute c, Substitute d) => Substitute (a, b, c, d) where
  substituteNames substs (x, y, z, u) = (sub x, sub y, sub z, sub u)
    where
      -- Explicit signature so the helper stays polymorphic across components.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes inside a 'Just'; 'Nothing' is returned unchanged.
instance Substitute a => Substitute (Maybe a) where
  substituteNames _ Nothing = Nothing
  substituteNames substs (Just x) = Just (substituteNames substs x)

-- | Booleans contain no names and are returned unchanged.
instance Substitute Bool where
  substituteNames _ b = b

-- | A name is replaced by its image in the map, or kept if it is not a key.
instance Substitute VName where
  substituteNames substs k =
    case M.lookup k substs of
      Just k' -> k'
      Nothing -> k

-- | Variables are renamed; constants contain no names and are kept as-is.
instance Substitute SubExp where
  substituteNames substs se =
    case se of
      Var v -> Var (substituteNames substs v)
      Constant c -> Constant c

-- | Expressions are traversed with 'mapExp', using the 'Mapper' produced by
-- 'replace' to substitute in every component the traversal visits.
instance Substitutable rep => Substitute (Exp rep) where
  substituteNames = mapExp . replace

-- | Substitutes in both the bound name and its annotation.
instance Substitute dec => Substitute (PatElem dec) where
  substituteNames substs (PatElem name dec) = PatElem name' dec'
    where
      name' = substituteNames substs name
      dec' = substituteNames substs dec

-- | Attributes are returned unchanged; no substitution is performed on them.
instance Substitute Attrs where
  substituteNames = const id

-- | Substitutes in the certificates, attributes, and expression annotation.
instance Substitute dec => Substitute (StmAux dec) where
  substituteNames substs (StmAux cs attrs dec) =
    StmAux (sub cs) (sub attrs) (sub dec)
    where
      -- Explicit signature so the helper stays polymorphic across fields.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes in the attributes, the parameter name, and its annotation.
instance Substitute dec => Substitute (Param dec) where
  substituteNames substs (Param attrs name dec) = Param attrs' name' dec'
    where
      attrs' = substituteNames substs attrs
      name' = substituteNames substs name
      dec' = substituteNames substs dec

-- | Substitutes in the certificates and the returned subexpression.
instance Substitute SubExpRes where
  substituteNames substs (SubExpRes cs se) = SubExpRes (sub cs) (sub se)
    where
      -- Explicit signature so the helper stays polymorphic across fields.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes in every element of the pattern.
instance Substitute dec => Substitute (Pat dec) where
  substituteNames substs (Pat elems) = Pat (map (substituteNames substs) elems)

-- | Substitutes in every certificate name.
instance Substitute Certs where
  substituteNames substs (Certs cs) = Certs (map (substituteNames substs) cs)

-- | Substitutes in the pattern, the auxiliary information, and the expression.
instance Substitutable rep => Substitute (Stm rep) where
  substituteNames substs (Let pat aux e) = Let (sub pat) (sub aux) (sub e)
    where
      -- Explicit signature so the helper stays polymorphic across fields.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes in the body annotation, its statements, and its result.
instance Substitutable rep => Substitute (Body rep) where
  substituteNames substs (Body dec stms res) = Body dec' stms' res'
    where
      dec' = substituteNames substs dec
      stms' = substituteNames substs stms
      res' = substituteNames substs res

-- | A 'Mapper' that applies 'substituteNames' with the given substitutions
-- to each component 'mapExp' visits.  These components are names,
-- subexpressions, bodies, return and branch types, parameters, and
-- operations.  The scope argument passed to 'mapOnBody' is ignored.
replace :: Substitutable rep => M.Map VName VName -> Mapper rep rep Identity
replace substs =
  Mapper
    { mapOnVName = substPure,
      mapOnSubExp = substPure,
      mapOnBody = \_scope -> substPure,
      mapOnRetType = substPure,
      mapOnBranchType = substPure,
      mapOnFParam = substPure,
      mapOnLParam = substPure,
      mapOnOp = substPure
    }
  where
    -- Substitute purely, then lift the result into the mapper's monad.
    substPure :: Substitute t => t -> Identity t
    substPure = pure . substituteNames substs

-- | Ranks contain no names and are returned unchanged.
instance Substitute Rank where
  substituteNames _ rank = rank

-- | The unit value contains no names.
instance Substitute () where
  substituteNames _ unit = unit

-- | 'NoOp' contains no names and is returned unchanged.
instance Substitute (NoOp rep) where
  substituteNames _ op = op

-- | Substitutes in every dimension of the shape.
instance Substitute d => Substitute (ShapeBase d) where
  substituteNames substs (Shape dims) =
    Shape [substituteNames substs dim | dim <- dims]

-- | Free sizes are substituted; existential indices are kept as-is.
instance Substitute d => Substitute (Ext d) where
  substituteNames substs ext =
    case ext of
      Free x -> Free (substituteNames substs x)
      Ext i -> Ext i

-- | Applies the substitution to every name in the set via 'mapNames'.
instance Substitute Names where
  substituteNames substs names = mapNames (substituteNames substs) names

-- | Primitive types contain no names and are returned unchanged.
instance Substitute PrimType where
  substituteNames = const id

-- | Substitutes in array shapes and in accumulator names, index spaces, and
-- element types.  Primitive and memory types are returned unchanged.
instance Substitute shape => Substitute (TypeBase shape u) where
  substituteNames substs ty =
    case ty of
      Prim pt -> Prim pt
      Acc acc ispace ts u -> Acc (sub acc) (sub ispace) (map sub ts) u
      Array pt shp u -> Array (sub pt) (sub shp) u
      Mem space -> Mem space
    where
      -- Explicit signature so the helper stays polymorphic across fields.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes in the parameters, the body, and the return types.
instance Substitutable rep => Substitute (Lambda rep) where
  substituteNames substs (Lambda params body rettype) =
    Lambda (sub params) (sub body) (sub rettype)
    where
      -- Explicit signature so the helper stays polymorphic across fields.
      sub :: Substitute t => t -> t
      sub = substituteNames substs

-- | Substitutes in both the name and the type of the identifier.
instance Substitute Ident where
  substituteNames substs ident@Ident {identName = name, identType = ty} =
    ident
      { identName = substituteNames substs name,
        identType = substituteNames substs ty
      }

-- | Substitutes in every component of the index, via its 'Functor' instance.
instance Substitute d => Substitute (DimIndex d) where
  substituteNames substs idx = substituteNames substs <$> idx

-- | Substitutes in every component of the slice, via its 'Functor' instance.
instance Substitute d => Substitute (Slice d) where
  substituteNames substs slice = substituteNames substs <$> slice

-- | Substitutes in every component of the flat index, via its 'Functor' instance.
instance Substitute d => Substitute (FlatDimIndex d) where
  substituteNames substs idx = substituteNames substs <$> idx

-- | Substitutes in every component of the flat slice, via its 'Functor' instance.
instance Substitute d => Substitute (FlatSlice d) where
  substituteNames substs slice = substituteNames substs <$> slice

-- | Substitutes in every leaf of the expression, via its 'Functor' instance.
instance Substitute v => Substitute (PrimExp v) where
  substituteNames substs pe = substituteNames substs <$> pe

-- | Unwraps the typed expression, substitutes in the underlying 'PrimExp',
-- and rewraps it.
instance Substitute v => Substitute (TPrimExp t v) where
  substituteNames substs e = TPrimExp (substituteNames substs (untyped e))

-- | Substitutes in the annotation carried by let-bound names and function or
-- lambda parameters; index names carry only an integer type and are kept.
instance Substitutable rep => Substitute (NameInfo rep) where
  substituteNames subst info =
    case info of
      LetName dec -> LetName (substituteNames subst dec)
      FParamName dec -> FParamName (substituteNames subst dec)
      LParamName dec -> LParamName (substituteNames subst dec)
      IndexName it -> IndexName it

-- | Computes the names in the value with 'freeIn', substitutes in that set,
-- and rebuilds an 'FV' from the result with 'fvNames'.
instance Substitute FV where
  substituteNames subst fv = fvNames (substituteNames subst (freeIn fv))

-- | Constraint satisfied by representations whose annotation types (and
-- operation type) all support name substitution, in addition to 'RepTypes'.
type Substitutable rep =
  ( RepTypes rep,
    Substitute (Op rep),
    Substitute (RetType rep),
    Substitute (BranchType rep),
    Substitute (FParamInfo rep),
    Substitute (LParamInfo rep),
    Substitute (LetDec rep),
    Substitute (ExpDec rep),
    Substitute (BodyDec rep)
  )