{-# LANGUAGE TypeFamilies #-}
{-# Language FlexibleInstances, FlexibleContexts #-}
module Futhark.Representation.AST.Attributes.Aliases
       ( vnameAliases
       , subExpAliases
       , primOpAliases
       , expAliases
       , patternAliases
       , Aliased (..)
       , AliasesOf (..)
         -- * Consumption
       , consumedInStm
       , consumedInExp
       , consumedByLambda
       -- * Extensibility
       , AliasedOp (..)
       , CanBeAliased (..)
       )
       where

import Control.Arrow (first)
import qualified Data.Kind

import Futhark.Representation.AST.Attributes (IsOp)
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes.Patterns
import Futhark.Representation.AST.Attributes.Types
import Futhark.Representation.AST.Attributes.Names

-- | A lore whose bodies can report alias and consumption information,
-- and whose operations and let-bound pattern annotations carry aliases.
class ( Annotations lore
      , AliasedOp (Op lore)
      , AliasesOf (LetAttr lore)
      ) => Aliased lore where
  -- | Alias sets for the results of a body (presumably one entry per
  -- result; the loop case of 'expAliases' splits this list by position).
  bodyAliases :: Body lore -> [Names]
  -- | The names consumed somewhere inside a body.
  consumedInBody :: Body lore -> Names

-- | The aliases of a variable: just the variable itself.
vnameAliases :: VName -> Names
vnameAliases = oneName

-- | The aliases of a subexpression.  A constant aliases nothing, and a
-- variable aliases itself.
subExpAliases :: SubExp -> Names
subExpAliases Constant{} = mempty
subExpAliases (Var v)    = vnameAliases v

-- | The aliases of the result of a 'BasicOp'.  Every equation below
-- returns a singleton list.
--
-- Each constructor is matched explicitly rather than with a catch-all,
-- so with -Wincomplete-patterns a newly added operation is reported here.
primOpAliases :: BasicOp -> [Names]
-- Operations whose result aliases one of their operands.
primOpAliases (SubExp se)     = [subExpAliases se]
primOpAliases (Opaque se)     = [subExpAliases se]
primOpAliases (Index ident _) = [vnameAliases ident]
primOpAliases (Repeat _ _ v)  = [vnameAliases v]
primOpAliases (Reshape _ e)   = [vnameAliases e]
primOpAliases (Rearrange _ e) = [vnameAliases e]
primOpAliases (Rotate _ e)    = [vnameAliases e]
-- Operations whose result aliases nothing.
primOpAliases ArrayLit{}      = [mempty]
primOpAliases BinOp{}         = [mempty]
primOpAliases ConvOp{}        = [mempty]
primOpAliases CmpOp{}         = [mempty]
primOpAliases UnOp{}          = [mempty]
primOpAliases Update{}        = [mempty]
primOpAliases Iota{}          = [mempty]
primOpAliases Replicate{}     = [mempty]
primOpAliases Scratch{}       = [mempty]
primOpAliases Concat{}        = [mempty]
primOpAliases Copy{}          = [mempty]
primOpAliases Manifest{}      = [mempty]
primOpAliases Assert{}        = [mempty]

-- | Combine the result aliases of the two branches of an @if@.
--
-- Each argument pairs a branch's per-result aliases with the names that
-- branch consumes.  Corresponding result aliases are unioned, then every
-- name consumed in either branch is removed.  As with 'zipWith', the
-- result is as long as the shorter alias list.
ifAliases :: ([Names], Names) -> ([Names], Names) -> [Names]
ifAliases (als1, cons1) (als2, cons2) =
  map (`namesSubtract` consumed) $ zipWith (<>) als1 als2
  where consumed = cons1 <> cons2

-- | The aliases of each result of a function call.
--
-- Takes the actual arguments paired with their diets, and the declared
-- return types.  Each argument is replaced by its aliases before the
-- work is delegated to 'returnAliases'.
funcallAliases :: [(SubExp, Diet)] -> [TypeBase shape Uniqueness] -> [Names]
funcallAliases args t =
  returnAliases t $ map (first subExpAliases) args

-- | The aliases of the results of an expression.
expAliases :: (Aliased lore) => Exp lore -> [Names]
expAliases (If _ tb fb attr) =
  -- Only the trailing results described by the branch return types are
  -- kept; any extra leading results are dropped.
  drop (length all_aliases - length (ifReturns attr)) all_aliases
  where all_aliases =
          ifAliases
            (bodyAliases tb, consumedInBody tb)
            (bodyAliases fb, consumedInBody fb)
expAliases (BasicOp op) = primOpAliases op
expAliases (DoLoop ctxmerge valmerge _ loopbody) =
  -- Skip the results matching the context merge parameters, then remove
  -- every merge-parameter name from the remaining alias sets.
  map (`namesSubtract` merge_names) val_aliases
  where val_aliases = drop (length ctxmerge) $ bodyAliases loopbody
        merge_names =
          namesFromList $ map (paramName . fst) $ ctxmerge ++ valmerge
expAliases (Apply _ args t _) =
  funcallAliases args $ retTypeValues t
expAliases (Op op) = opAliases op

-- | Compute the aliases of each result of a call.
--
-- Takes the declared return types and the (aliases, diet) of each
-- argument.  A non-unique array result aliases the union of every
-- argument's aliases as filtered by 'maskAliases'.  Unique arrays and
-- primitives alias nothing.  A memory-typed return type is rejected
-- with 'error'.
returnAliases :: [TypeBase shape Uniqueness] -> [(Names, Diet)] -> [Names]
returnAliases rts args = map returnType' rts
  where -- Independent of the return type, so computed once and shared.
        observed = mconcat $ map (uncurry maskAliases) args

        returnType' (Array _ _ Nonunique) = observed
        returnType' (Array _ _ Unique)    = mempty
        returnType' (Prim _)              = mempty
        returnType' Mem{}                 = error "returnAliases Mem"

-- | Filter an argument's aliases by its diet.  Only arguments with diet
-- 'Observe' keep their aliases; 'Consume' and 'ObservePrim' contribute
-- nothing.
maskAliases :: Names -> Diet -> Names
maskAliases als Observe     = als
maskAliases _   Consume     = mempty
maskAliases _   ObservePrim = mempty

-- | The names consumed by a statement, i.e. by its expression.
consumedInStm :: Aliased lore => Stm lore -> Names
consumedInStm = consumedInExp . stmExp

-- | The names consumed by an expression.
consumedInExp :: (Aliased lore) => Exp lore -> Names
consumedInExp (Apply _ args _ _) =
  -- Arguments passed with diet 'Consume' are consumed.
  mconcat $ map (consumeArg . first subExpAliases) args
  where consumeArg (als, Consume) = als
        consumeArg _              = mempty
consumedInExp (If _ tb fb _) =
  consumedInBody tb <> consumedInBody fb
consumedInExp (DoLoop _ merge _ _) =
  -- The initial values of the (value) merge parameters whose declared
  -- type is unique are consumed.
  mconcat $ map (subExpAliases . snd) $
    filter (unique . paramDeclType . fst) merge
consumedInExp (BasicOp (Update src _ _)) = oneName src
consumedInExp (Op op) = consumedInOp op
consumedInExp _ = mempty

-- | The names consumed by a lambda, i.e. by its body.
consumedByLambda :: Aliased lore => Lambda lore -> Names
consumedByLambda = consumedInBody . lambdaBody

-- | The aliases recorded in the annotation of each pattern element, in
-- the order given by 'patternElements'.
patternAliases :: AliasesOf attr => PatternT attr -> [Names]
patternAliases = map (aliasesOf . patElemAttr) . patternElements

-- | Types from which alias information can be extracted.
class AliasesOf a where
  -- | The alias set carried by the argument.
  aliasesOf :: a -> Names

-- | A 'Names' value is its own alias set.
instance AliasesOf Names where
  aliasesOf = id

-- | A pattern element has the aliases of its annotation.
instance AliasesOf attr => AliasesOf (PatElemT attr) where
  aliasesOf = aliasesOf . patElemAttr

-- | Operations that can report the aliases of their results and the
-- names they consume.
class IsOp op => AliasedOp op where
  -- | Alias sets for the results of the operation.
  opAliases :: op -> [Names]
  -- | The names consumed by the operation.
  consumedInOp :: op -> Names

-- | The empty operation has no results and consumes nothing.
instance AliasedOp () where
  opAliases () = []
  consumedInOp () = mempty

-- | Operations that have a counterpart annotated with alias
-- information, together with conversions in both directions.
class AliasedOp (OpWithAliases op) => CanBeAliased op where
  -- | The alias-annotated version of @op@.
  type OpWithAliases op :: Data.Kind.Type
  -- | Drop the alias annotations.
  removeOpAliases :: OpWithAliases op -> op
  -- | Add alias annotations.
  addOpAliases :: op -> OpWithAliases op

-- | The empty operation is its own alias-annotated form, so both
-- conversions are the identity.
instance CanBeAliased () where
  type OpWithAliases () = ()
  removeOpAliases = id
  addOpAliases = id