{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Analysis.SymbolTable
  ( SymbolTable (bindings, loopDepth, availableAtClosestLoop, simplifyMemory)
  , empty
  , fromScope
  , toScope

    -- * Entries
  , Entry
  , deepen
  , entryDepth
  , entryLetBoundDec
  , entryIsSize

    -- * Lookup
  , elem
  , lookup
  , lookupStm
  , lookupExp
  , lookupBasicOp
  , lookupType
  , lookupSubExp
  , lookupAliases
  , lookupLoopVar
  , available
  , consume
  , index
  , index'
  , Indexed(..)
  , indexedAddCerts
  , IndexOp(..)

    -- * Insertion
  , insertStm
  , insertStms
  , insertFParams
  , insertLParam
  , insertLoopVar

    -- * Misc
  , hideCertified
  )
  where

import Control.Arrow ((&&&))
import Control.Monad
import Data.Ord
import Data.Maybe
import Data.List (foldl', elemIndex)
import qualified Data.Map.Strict as M

import Prelude hiding (elem, lookup)

import Futhark.Analysis.PrimExp.Convert
import Futhark.IR hiding (FParam, lookupType)
import qualified Futhark.IR as AST

import qualified Futhark.IR.Prop.Aliases as Aliases

-- | A symbol table: every name currently in scope, together with
-- information about how it was bound, plus some context about where
-- in the program we are.
data SymbolTable lore = SymbolTable
  { loopDepth :: Int
    -- ^ Number of enclosing loops; incremented by 'deepen'.
  , bindings :: M.Map VName (Entry lore)
    -- ^ The entry for every name in scope.
  , availableAtClosestLoop :: Names
    -- ^ Which names are available just before the most enclosing
    -- loop?
  , simplifyMemory :: Bool
    -- ^ We are in a situation where we should
    -- simplify/hoist/un-existentialise memory as much as possible -
    -- typically, inside a kernel.
  }

-- | Merge two symbol tables.  A name bound in both keeps the entry
-- from the left operand (map union is left-biased); the loop depth is
-- the maximum of the two, the names available at the closest loop
-- are combined, and memory simplification is enabled if either
-- operand enables it.
instance Semigroup (SymbolTable lore) where
  lhs <> rhs =
    SymbolTable
      { loopDepth = loopDepth lhs `max` loopDepth rhs
      , bindings = M.union (bindings lhs) (bindings rhs)
      , availableAtClosestLoop =
          availableAtClosestLoop lhs <> availableAtClosestLoop rhs
      , simplifyMemory = simplifyMemory lhs || simplifyMemory rhs
      }

-- | 'mempty' is the table with no bindings ('empty').
instance Monoid (SymbolTable lore) where
  mempty = empty

-- | A symbol table without any bindings, at loop depth zero, with no
-- names available at the closest loop and memory simplification
-- switched off.
empty :: SymbolTable lore
empty =
  SymbolTable
    { loopDepth = 0
    , bindings = M.empty
    , availableAtClosestLoop = mempty
    , simplifyMemory = False
    }

-- | Build a symbol table from a scope, inserting every name of the
-- scope as a free variable (via 'insertFreeVar'), starting from
-- 'empty'.
fromScope :: ASTLore lore => Scope lore -> SymbolTable lore
fromScope scope = M.foldlWithKey' step empty scope
  where step vtable name info = insertFreeVar name info vtable

-- | Turn a symbol table back into a scope, keeping only the
-- 'NameInfo' of each binding.
toScope :: SymbolTable lore -> Scope lore
toScope vtable = entryInfo <$> bindings vtable

-- | Go one loop level deeper: the loop depth is incremented, and every
-- name bound so far becomes the set of names available just before
-- the (new) most enclosing loop.
deepen :: SymbolTable lore -> SymbolTable lore
deepen vtable =
  vtable
    { loopDepth = 1 + loopDepth vtable
    , availableAtClosestLoop = namesFromList boundNames
    }
  where boundNames = M.keys (bindings vtable)

-- | The result of indexing a delayed array.  Both alternatives carry
-- the certificates the result depends on.
data Indexed
  = -- | A PrimExp based on the indexes (that is, without
    -- accessing any actual array).
    Indexed Certificates (PrimExp VName)
  | -- | The indexing corresponds to another (perhaps more
    -- advantageous) array, indexed with the given expressions.
    IndexedArray Certificates VName [PrimExp VName]

-- | Prepend extra certificates to an 'Indexed' result, whichever of
-- the two forms it has.
indexedAddCerts :: Certificates -> Indexed -> Indexed
indexedAddCerts extra idx =
  case idx of
    Indexed cs e -> Indexed (extra <> cs) e
    IndexedArray cs arr es -> IndexedArray (extra <> cs) arr es

-- | The free variables of an 'Indexed' are those of its certificates
-- and index expression(s), plus the target array for 'IndexedArray'.
instance FreeIn Indexed where
  freeIn' idx =
    case idx of
      Indexed cs e -> freeIn' cs <> freeIn' e
      IndexedArray cs arr es -> freeIn' cs <> freeIn' arr <> freeIn' es

-- | Indexing a delayed array if possible: given the indexes, produce
-- an 'Indexed' result, or 'Nothing' when no symbolic answer is known.
type IndexArray =
  [PrimExp VName] -> Maybe Indexed

-- | Everything the symbol table records about a single name.
data Entry lore = Entry
  { entryConsumed :: Bool
    -- ^ True if consumed.
  , entryDepth :: Int
    -- ^ Loop depth recorded for this entry (presumably the depth at
    -- which the name was bound - confirm against the insertion
    -- functions).
  , entryIsSize :: Bool
    -- ^ True if this name has been used as an array size,
    -- implying that it is non-negative.
  , entryType :: EntryType lore
    -- ^ How the name was bound, with binding-specific information.
  }

-- | The different ways a name can have been bound.
data EntryType lore
  = LoopVar (LoopVarEntry lore)
  | LetBound (LetBoundEntry lore)
  | FParam (FParamEntry lore)
  | LParam (LParamEntry lore)
  | FreeVar (FreeVarEntry lore)

-- | Information about a loop index variable.
data LoopVarEntry lore = LoopVarEntry
  { loopVarType :: IntType
    -- ^ The integer type of the index variable.
  , loopVarBound :: SubExp
    -- ^ The bound of the loop (what 'lookupLoopVar' returns).
  }

-- | Information about a name bound by a @let@ statement.
data LetBoundEntry lore = LetBoundEntry
  { letBoundDec :: LetDec lore
    -- ^ The decoration of the pattern element binding the name.
  , letBoundAliases :: Names
    -- ^ The names this one aliases.
  , letBoundStm :: Stm lore
    -- ^ The statement that bound the name.
  , letBoundIndex :: Int -> IndexArray
    -- ^ Index a delayed array, if possible.  The 'Int' is the
    -- position of the name among the pattern's value names.
  }

-- | Information about a function (or loop merge) parameter.
data FParamEntry lore = FParamEntry
  { fparamDec :: FParamInfo lore
    -- ^ The parameter's decoration.
  , fparamAliases :: Names
    -- ^ The names this parameter aliases.
  }

-- | Information about a lambda parameter.
data LParamEntry lore = LParamEntry
  { lparamDec :: LParamInfo lore
    -- ^ The parameter's decoration.
  , lparamIndex :: IndexArray
    -- ^ Index a delayed array, if possible.
  }

-- | Information about a free variable (a name not bound by any of the
-- other forms of entry).
data FreeVarEntry lore = FreeVarEntry
  { freeVarDec :: NameInfo lore
    -- ^ The name information of the variable.
  , freeVarIndex :: VName -> IndexArray
    -- ^ Index a delayed array, if possible.  Receives the name being
    -- indexed.
  }

-- | The type of an entry is the type of its 'NameInfo'.
instance ASTLore lore => Typed (Entry lore) where
  typeOf entry = typeOf (entryInfo entry)

-- | Convert an entry to the 'NameInfo' representation used in scopes.
-- Free variables already store a 'NameInfo'; the other kinds wrap
-- their decoration in the matching constructor.
entryInfo :: Entry lore -> NameInfo lore
entryInfo entry = toInfo (entryType entry)
  where
    toInfo (FreeVar fv) = freeVarDec fv
    toInfo (LParam lp) = LParamName (lparamDec lp)
    toInfo (FParam fp) = FParamName (fparamDec fp)
    toInfo (LoopVar lv) = IndexName (loopVarType lv)
    toInfo (LetBound lb) = LetName (letBoundDec lb)

-- | The let-binding information of an entry, if it was let-bound.
isLetBound :: Entry lore -> Maybe (LetBoundEntry lore)
isLetBound entry
  | LetBound lbe <- entryType entry = Just lbe
  | otherwise = Nothing

-- | The statement that bound the entry, if it was let-bound.
entryStm :: Entry lore -> Maybe (Stm lore)
entryStm entry = letBoundStm <$> isLetBound entry

-- | The pattern-element decoration of the entry, if it was let-bound.
entryLetBoundDec :: Entry lore -> Maybe (LetDec lore)
entryLetBoundDec entry = letBoundDec <$> isLetBound entry

-- | Is the name bound in the table (whether consumed or not)?
elem :: VName -> SymbolTable lore -> Bool
elem name vtable = isJust $ lookup name vtable

-- | The entry for a name, if it is bound in the table.
lookup :: VName -> SymbolTable lore -> Maybe (Entry lore)
lookup name vtable = M.lookup name (bindings vtable)

-- | The statement that let-bound the given name, if any.
lookupStm :: VName -> SymbolTable lore -> Maybe (Stm lore)
lookupStm name vtable = lookup name vtable >>= entryStm

-- | The expression of the statement that let-bound the given name,
-- paired with that statement's certificates.
lookupExp :: VName -> SymbolTable lore -> Maybe (Exp lore, Certificates)
lookupExp name vtable = expAndCerts <$> lookupStm name vtable
  where expAndCerts = stmExp &&& stmCerts

-- | Like 'lookupExp', but only succeeds when the binding expression is
-- a 'BasicOp'.
lookupBasicOp :: VName -> SymbolTable lore -> Maybe (BasicOp, Certificates)
lookupBasicOp name vtable = do
  (BasicOp op, cs) <- lookupExp name vtable
  return (op, cs)

-- | The type of a bound name, if it is in the table.
lookupType :: ASTLore lore => VName -> SymbolTable lore -> Maybe Type
lookupType name vtable = fmap typeOf (lookup name vtable)

-- | The type of a subexpression.  Constants always have a known
-- primitive type; variables are looked up in the table.
lookupSubExpType :: ASTLore lore => SubExp -> SymbolTable lore -> Maybe Type
lookupSubExpType se vtable =
  case se of
    Var name -> lookupType name vtable
    Constant val -> Just $ Prim $ primValueType val

-- | If the name is bound to a plain subexpression (@BasicOp (SubExp
-- se)@), return that subexpression and the binding's certificates.
lookupSubExp :: VName -> SymbolTable lore -> Maybe (SubExp, Certificates)
lookupSubExp name vtable =
  case lookupExp name vtable of
    Just (BasicOp (SubExp se), cs) -> Just (se, cs)
    _ -> Nothing

-- | The recorded aliases of a name.  Only let-bound names and function
-- parameters carry aliases; every other case (including unbound
-- names) yields no aliases.
lookupAliases :: VName -> SymbolTable lore -> Names
lookupAliases name vtable =
  maybe mempty (aliasesOf . entryType) (lookup name vtable)
  where
    aliasesOf (LetBound lbe) = letBoundAliases lbe
    aliasesOf (FParam fpe) = fparamAliases fpe
    aliasesOf _ = mempty

-- | If the given variable name is the name of a 'ForLoop' parameter,
-- then return the bound of that loop.  Any other (or unbound) name
-- gives 'Nothing'.
lookupLoopVar :: VName -> SymbolTable lore -> Maybe SubExp
lookupLoopVar name vtable =
  case entryType <$> lookup name vtable of
    Just (LoopVar lve) -> Just (loopVarBound lve)
    _ -> Nothing

-- | In symbol table and not consumed.
available :: VName -> SymbolTable lore -> Bool
available name vtable =
  case lookup name vtable of
    Just entry -> not (entryConsumed entry)
    Nothing -> False

-- | Symbolically index the named array with subexpression indexes.
-- Every index must have a primitive type known to the table
-- (constants always do); otherwise the result is 'Nothing'.  The
-- actual indexing is done by 'index''.
index :: ASTLore lore => VName -> [SubExp] -> SymbolTable lore
      -> Maybe Indexed
index name is table = do
  is' <- traverse toPrimExp is
  index' name is' table
  where
    toPrimExp i = do
      Prim t <- lookupSubExpType i table
      Just (primExpFromSubExp t i)

-- | Symbolically index the named array with 'PrimExp' indexes, using
-- the indexing function stored in the name's entry.  Let-bound names
-- pass their position among the statement's pattern value names;
-- free variables receive their own name.  Loop variables and
-- function parameters cannot be indexed.
index' :: VName -> [PrimExp VName] -> SymbolTable lore
       -> Maybe Indexed
index' name is vtable =
  lookup name vtable >>= indexEntry . entryType
  where
    indexEntry (LetBound lbe) = do
      k <- elemIndex name $ patternValueNames $ stmPattern $ letBoundStm lbe
      letBoundIndex lbe k is
    indexEntry (FreeVar fve) = freeVarIndex fve name is
    indexEntry (LParam lpe) = lparamIndex lpe is
    indexEntry _ = Nothing

-- | Operation types that can sometimes answer an indexing query
-- symbolically.  The 'Int' is the position of the indexed result
-- among the binding pattern's values.  The default method always
-- gives up with 'Nothing'.
class IndexOp op where
  indexOp ::
    (ASTLore lore, IndexOp (Op lore)) =>
    SymbolTable lore ->
    Int ->
    op ->
    [PrimExp VName] ->
    Maybe Indexed
  indexOp _vtable _k _op _is = Nothing

-- | The unit op type uses the default 'indexOp', which always yields
-- 'Nothing'.
instance IndexOp ()

-- | Try to symbolically index the @k@th result of an expression with
-- the given indexes.  Handled cases:
--
-- * 'Op': delegated to 'indexOp'.
-- * 'Iota': element @i@ is @sext(i) * stride + start@.
-- * 'Replicate' of a primitive value, indexed in every replicated
--   dimension: the replicated value itself.
-- * 'Replicate' of an array variable: drop the indexes of the
--   replicated dimensions and index the underlying array, provided it
--   has not been consumed.
-- * 'Reshape': translate the indexes into the original shape.
-- * 'Index': compose the slice with the indexes and index the sliced
--   array, provided it has not been consumed.
--
-- Anything else yields 'Nothing'.
indexExp :: (IndexOp (Op lore), ASTLore lore) =>
            SymbolTable lore -> Exp lore -> Int -> IndexArray

indexExp vtable (Op op) k is =
  indexOp vtable k op is

indexExp _ (BasicOp (Iota _ x s to_it)) _ [i]
  | IntType from_it <- primExpType i =
      Just $ Indexed mempty $
        ConvOpExp (SExt from_it to_it) i
        * primExpFromSubExp (IntType to_it) s
        + primExpFromSubExp (IntType to_it) x

indexExp table (BasicOp (Replicate (Shape ds) v)) _ is
  | length ds == length is,
    Just (Prim t) <- lookupSubExpType v table =
      Just $ Indexed mempty $ primExpFromSubExp t v

indexExp table (BasicOp (Replicate (Shape ds) (Var v))) _ is
  | not (null ds),
    length is >= length ds = do
      -- A replicate does not alias its operand, so 'v' may have been
      -- consumed while the replicated array is still live; we must not
      -- introduce a fresh reference to a consumed array.
      guard $ v `available` table
      -- The replicated dimensions contribute nothing; only the
      -- remaining indexes address 'v'.
      index' v (drop (length ds) is) table

indexExp table (BasicOp (Reshape newshape v)) _ is
  | Just oldshape <- arrayDims <$> lookupType v table =
      let is' =
            reshapeIndex (map (primExpFromSubExp int32) oldshape)
                         (map (primExpFromSubExp int32) $ newDims newshape)
                         is
      in index' v is' table

indexExp table (BasicOp (Index v slice)) _ is = do
  -- Never produce a new reference to an array that has been consumed.
  guard $ v `available` table
  index' v (adjust slice is) table
  where adjust :: Slice SubExp -> [PrimExp VName] -> [PrimExp VName]
        -- A fixed dimension contributes its index without consuming
        -- one of ours.
        adjust (DimFix j:js') is' =
          pe j : adjust js' is'
        -- A sliced dimension maps our index @i@ to @start + i * stride@.
        adjust (DimSlice j _ s:js') (i:is') =
          let i_t_s = i * pe s
              j_p_i_t_s = pe j + i_t_s
          in j_p_i_t_s : adjust js' is'
        adjust _ _ = []

        pe :: SubExp -> PrimExp VName
        pe = primExpFromSubExp (IntType Int32)

indexExp _ _ _ _ = Nothing

-- | Build the let-bound entry for one element of a statement's
-- pattern.  Indexing the entry tries 'indexExp' on the statement's
-- expression and prepends the statement's certificates to any
-- result.
defBndEntry :: (ASTLore lore, IndexOp (Op lore)) =>
               SymbolTable lore
            -> PatElem lore
            -> Names
            -> Stm lore
            -> LetBoundEntry lore
defBndEntry vtable patElem als bnd =
  LetBoundEntry
    { letBoundDec = patElemDec patElem
    , letBoundAliases = als
    , letBoundStm = bnd
    , letBoundIndex = indexer
    }
  where
    certs = stmAuxCerts (stmAux bnd)
    indexer k is = indexedAddCerts certs <$> indexExp vtable (stmExp bnd) k is

-- | Build one 'LetBoundEntry' per element of the statement's pattern,
-- in pattern order.  Each entry records the aliases of its own pattern
-- element.
bindingEntries :: (ASTLore lore, Aliases.Aliased lore, IndexOp (Op lore)) =>
                  Stm lore -> SymbolTable lore
               -> [LetBoundEntry lore]
bindingEntries bnd vtable =
  map mkEntry $ patternElements $ stmPattern bnd
  where
    mkEntry pat_elem =
      defBndEntry vtable pat_elem (Aliases.aliasesOf pat_elem) bnd

-- | Apply a function to the values stored at several keys of a map.
--
-- Keys are processed from left to right with a strict fold.  A key that
-- is not present in the map is ignored ('M.adjust' semantics).  A key
-- occurring @n@ times in the list has the function applied @n@ times.
adjustSeveral :: Ord k => (v -> v) -> [k] -> M.Map k v -> M.Map k v
adjustSeveral f keys m = foldl' (\acc k -> M.adjust f k acc) m keys

-- | Insert a fresh, unconsumed binding for @name@ at the table's current
-- loop depth.  Afterwards, every variable occurring in the array
-- dimensions of the new entry's type that is present in the table is
-- flagged with 'entryIsSize'.
insertEntry :: ASTLore lore =>
               VName -> EntryType lore -> SymbolTable lore
            -> SymbolTable lore
insertEntry name entry vtable =
  vtable { bindings = markSizes $ M.insert name newEntry $ bindings vtable }
  where
    newEntry =
      Entry { entryConsumed = False
            , entryDepth = loopDepth vtable
            , entryIsSize = False
            , entryType = entry
            }
    -- Variables used as dimensions of the new binding's type.
    sizeVars = mapMaybe subExpVar $ arrayDims $ typeOf newEntry
    markSizes = adjustSeveral (\e -> e { entryIsSize = True }) sizeVars

-- | Insert several entries from left to right using 'insertEntry'.
-- Each entry is inserted into the table produced by the previous ones.
insertEntries :: ASTLore lore =>
                 [(VName, EntryType lore)] -> SymbolTable lore
              -> SymbolTable lore
insertEntries entries vtable =
  foldl' (\tbl (name, entry) -> insertEntry name entry tbl) vtable entries

-- | Insert the bindings of a statement into the symbol table.
--
-- This proceeds in three steps:
--
-- 1. Every name bound by the pattern receives a 'LetBound' entry.  The
--    entries are built against the table as it was before the statement.
--
-- 2. For each pattern element, every name in its alias set (extended by
--    'expandAliases') gets the element's own name added to its recorded
--    aliases.  Only 'LetBound' and 'FParam' entries are updated this way.
--
-- 3. Every name the statement consumes (extended by 'expandAliases'
--    against the original table) is marked as consumed via 'consume'.
insertStm :: (ASTLore lore, IndexOp (Op lore), Aliases.Aliased lore) =>
             Stm lore
          -> SymbolTable lore
          -> SymbolTable lore
insertStm stm vtable =
  foldl' (flip consume) with_rev_aliases $ namesToList stm_consumed
  where
    pat = stmPattern stm

    with_entries =
      insertEntries
        (zip (patternNames pat) $ map LetBound $ bindingEntries stm vtable)
        vtable

    with_rev_aliases = foldl' addRevAliases with_entries $ patternElements pat

    stm_consumed = expandAliases (Aliases.consumedInStm stm) vtable

    -- Record the reverse aliasing edges for a single pattern element.
    addRevAliases tbl pe =
      tbl { bindings = adjustSeveral addAliasTo inedges $ bindings tbl }
      where
        inedges = namesToList $ expandAliases (Aliases.aliasesOf pe) tbl
        this_name = oneName $ patElemName pe
        addAliasTo e = e { entryType = extend $ entryType e }
        extend (LetBound entry) =
          LetBound (entry { letBoundAliases = this_name <> letBoundAliases entry })
        extend (FParam entry) =
          FParam (entry { fparamAliases = this_name <> fparamAliases entry })
        extend other = other

-- | Insert a sequence of statements, in order, using 'insertStm'.
insertStms :: (ASTLore lore, IndexOp (Op lore), Aliases.Aliased lore) =>
              Stms lore -> SymbolTable lore -> SymbolTable lore
insertStms stms vtable =
  foldl' (\tbl stm -> insertStm stm tbl) vtable (stmsToList stms)

-- | Extend a set of names with the aliases the table records for each of
-- them (as returned by 'lookupAliases').  Only one level of expansion is
-- done: the aliases of the newly added names are not looked up.
expandAliases :: Names -> SymbolTable lore -> Names
expandAliases names vtable =
  names <> foldMap (`lookupAliases` vtable) (namesToList names)

-- | Insert a function parameter.  The new entry starts with an empty set
-- of recorded aliases.
insertFParam :: ASTLore lore =>
                AST.FParam lore -> SymbolTable lore -> SymbolTable lore
insertFParam fparam =
  insertEntry (AST.paramName fparam) $
    FParam FParamEntry { fparamDec = AST.paramDec fparam
                       , fparamAliases = mempty
                       }

-- | Insert several function parameters from left to right using
-- 'insertFParam'.
insertFParams :: ASTLore lore =>
                 [AST.FParam lore] -> SymbolTable lore -> SymbolTable lore
insertFParams fparams symtable =
  foldl' (\tbl fparam -> insertFParam fparam tbl) symtable fparams

-- | Insert a lambda parameter.  The entry's index function always
-- returns 'Nothing'.
insertLParam :: ASTLore lore => LParam lore -> SymbolTable lore -> SymbolTable lore
insertLParam param =
  insertEntry (AST.paramName param) $
    LParam LParamEntry { lparamDec = AST.paramDec param
                       , lparamIndex = const Nothing
                       }

-- | Insert a loop variable of integer type @it@, storing @bound@ as its
-- 'loopVarBound'.
insertLoopVar :: ASTLore lore => VName -> IntType -> SubExp -> SymbolTable lore -> SymbolTable lore
insertLoopVar name it bound =
  insertEntry name $
    LoopVar LoopVarEntry { loopVarType = it, loopVarBound = bound }

-- | Insert a free variable with the given 'NameInfo'.  The entry's index
-- function always returns 'Nothing'.
insertFreeVar :: ASTLore lore => VName -> NameInfo lore -> SymbolTable lore -> SymbolTable lore
insertFreeVar name dec =
  insertEntry name $
    FreeVar FreeVarEntry { freeVarDec = dec
                         , freeVarIndex = \_ _ -> Nothing
                         }

-- | Mark a name, and the names it is recorded as aliasing (one level of
-- 'expandAliases' against this table), as consumed.  Names that have no
-- entry in the table are ignored.
--
-- Uses the shared 'adjustSeveral' helper rather than a hand-written fold
-- of 'M.adjust'.  This updates the table record once, not once per name.
consume :: VName -> SymbolTable lore -> SymbolTable lore
consume consumee vtable =
  vtable { bindings = adjustSeveral markConsumed consumed $ bindings vtable }
  where
    consumed = namesToList $ expandAliases (oneName consumee) vtable
    markConsumed e = e { entryConsumed = True }

-- | Hide definitions of those entries that satisfy some predicate.
--
-- A hidden entry is turned into a 'FreeVar' entry.  It keeps the
-- entry's 'entryInfo', but its index function always returns 'Nothing'.
-- The entry's other fields are left unchanged, and entries that do not
-- satisfy the predicate are kept as they are.
hideIf :: (Entry lore -> Bool) -> SymbolTable lore -> SymbolTable lore
hideIf hide vtable = vtable { bindings = M.map maybeHide $ bindings vtable }
  where
    maybeHide entry
      | not (hide entry) = entry
      | otherwise = entry { entryType = FreeVar forgotten }
      where
        forgotten = FreeVarEntry { freeVarDec = entryInfo entry
                                 , freeVarIndex = \_ _ -> Nothing
                                 }

-- | Hide these definitions, if they are protected by certificates in
-- the set of names.
--
-- Only entries with an associated statement ('entryStm') are candidates.
-- Such an entry is hidden (see 'hideIf') when any certificate on its
-- statement is a member of @to_hide@.
hideCertified :: Names -> SymbolTable lore -> SymbolTable lore
hideCertified to_hide = hideIf certifiedByHidden
  where
    certifiedByHidden entry =
      case entryStm entry of
        Nothing -> False
        Just stm -> any (`nameIn` to_hide) $ unCertificates $ stmCerts stm