{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLocation (..),
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    comment,
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupMemory,

    -- * Building Blocks
    TV,
    mkTV,
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    copyElementWise,
    typeSize,
    isMapTransposeCopy,

    -- * Constructing code.
    dLParams,
    dFParams,
    dScope,
    dArray,
    dPrim,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    (<--),
    (<~~),
    function,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.List (find, genericLength, sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    bytes,
    elements,
    withElemType,
  )
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings

-- | A compiler for an t'Op': it receives a pattern and the operation,
-- and runs in 'ImpM' for its effects only.
type OpCompiler lore r op =
  Pattern lore ->
  Op lore ->
  ImpM lore r op ()

-- | A compiler for a sequence of 'Stms'.  Besides the statements it
-- receives a set of 'Names' and an 'ImpM' action (presumably the
-- continuation to run once the statements are in scope -- confirm
-- against 'defCompileStms').
type StmsCompiler lore r op =
  Names ->
  Stms lore ->
  ImpM lore r op () ->
  ImpM lore r op ()

-- | A compiler for an 'Exp': it receives a pattern and the
-- expression, and runs in 'ImpM' for its effects only.
type ExpCompiler lore r op =
  Pattern lore ->
  Exp lore ->
  ImpM lore r op ()

-- | How to compile a copy between two sliced memory locations holding
-- elements of the given 'PrimType'.
--
-- NOTE(review): the two @('MemLocation', 'Slice')@ pairs are
-- presumably destination first and source second -- confirm against
-- the callers before relying on the order.
type CopyCompiler lore r op =
  PrimType ->
  MemLocation -> Slice (Imp.TExp Int64) ->
  MemLocation -> Slice (Imp.TExp Int64) ->
  ImpM lore r op ()

-- | A custom strategy for compiling an allocation.  It is handed a
-- 'VName' and a size counted in bytes.
type AllocCompiler lore r op =
  VName ->
  Count Bytes (Imp.TExp Int64) ->
  ImpM lore r op ()

-- | The pluggable compilers that parameterise code generation.  A
-- value of this type is turned into an 'Env' by 'newEnv' and
-- 'subImpM'.
data Operations lore r op = Operations
  { -- | Compiler used for expressions.
    opsExpCompiler :: ExpCompiler lore r op,
    -- | Compiler used for t'Op's.
    opsOpCompiler :: OpCompiler lore r op,
    -- | Compiler used for sequences of statements.
    opsStmsCompiler :: StmsCompiler lore r op,
    -- | Compiler used for copies.
    opsCopyCompiler :: CopyCompiler lore r op,
    -- | Allocation compilers, keyed by memory 'Space'.
    opsAllocCompilers :: M.Map Space (AllocCompiler lore r op)
  }

-- | An operations set that uses the default compilers for everything
-- except t'Op's: expressions go through 'defCompileExp', statements
-- through 'defCompileStms', copies through 'defaultCopy', and no
-- space-specific allocation compilers are registered.  The t'Op'
-- compiler is the one supplied by the caller.
defaultOperations ::
  (Mem lore, FreeIn op) =>
  OpCompiler lore r op ->
  Operations lore r op
defaultOperations opc =
  Operations
    { opsExpCompiler = defCompileExp,
      opsOpCompiler = opc,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = defaultCopy,
      opsAllocCompilers = mempty
    }

-- | When an array is declared, this is where it is stored.
data MemLocation = MemLocation
  { -- | Presumably the name of the memory block holding the array
    -- (TODO confirm).
    memLocationName :: VName,
    -- | The dimension sizes associated with this location.
    memLocationShape :: [Imp.DimSize],
    -- | The index function for this location, with 'Imp.TExp' leaves.
    memLocationIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Eq, Show)

-- | What the symbol table records about an array: its location and
-- its element type.
data ArrayEntry = ArrayEntry
  { entryArrayLocation :: MemLocation,
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape of an array entry, i.e. the 'memLocationShape' of its
-- 'entryArrayLocation'.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocationShape . entryArrayLocation

-- | What the symbol table records about a memory block: the space it
-- is associated with.
newtype MemEntry = MemEntry {entryMemSpace :: Imp.Space}
  deriving (Show)

-- | What the symbol table records about a scalar: its primitive type.
newtype ScalarEntry = ScalarEntry
  { entryScalarType :: PrimType
  }
  deriving (Show)

-- | A symbol table entry.  Every variable known to the compiler is
-- either an array, a scalar, or a memory block.  Each constructor
-- also carries an optional expression (presumably the expression the
-- variable was bound to, when known -- TODO confirm).
data VarEntry lore
  = ArrayVar (Maybe (Exp lore)) ArrayEntry
  | ScalarVar (Maybe (Exp lore)) ScalarEntry
  | MemVar (Maybe (Exp lore)) MemEntry
  deriving (Show)

-- | When compiling an expression, this is a description of where the
-- result should end up.  The integer is a reference to the construct
-- that gave rise to this destination (for patterns, this will be the
-- tag of the first name in the pattern).  This can be used to make
-- the generated code easier to relate to the original code.
data Destination = Destination
  { destinationTag :: Maybe Int,
    valueDestinations :: [ValueDestination]
  }
  deriving (Show)

-- | Where a single result value should be written.
data ValueDestination
  = ScalarDestination VName
  | MemoryDestination VName
  | -- | The 'MemLocation' is 'Just' if a copy is
    -- required.  If it is 'Nothing', then a
    -- copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLocation)
  deriving (Show)

-- | The read-only environment of 'ImpM': the active pluggable
-- compilers plus contextual information about what is being compiled.
data Env lore r op = Env
  { envExpCompiler :: ExpCompiler lore r op,
    envStmsCompiler :: StmsCompiler lore r op,
    envOpCompiler :: OpCompiler lore r op,
    envCopyCompiler :: CopyCompiler lore r op,
    -- | Allocation compilers, keyed by memory 'Imp.Space'.
    envAllocCompilers :: M.Map Space (AllocCompiler lore r op),
    envDefaultSpace :: Imp.Space,
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs
  }

-- | Build the initial environment from a user environment, a set of
-- operations and the default memory space.  Volatility starts as
-- 'Imp.Nonvolatile', there is no enclosing function, and no
-- attributes are active.
newEnv :: r -> Operations lore r op -> Imp.Space -> Env lore r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- NOTE(review): 'opsAllocCompilers' of @ops@ is not consulted
      -- here, whereas 'subImpM' does install it.  Kept as written;
      -- confirm whether this is intentional.
      envAllocCompilers = mempty,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }

-- | The compiler's symbol table: a map from each known variable name
-- to its 'VarEntry'.
type VTable lore =
  M.Map VName (VarEntry lore)

-- | The mutable state threaded through 'ImpM'.
data ImpState lore r op = ImpState
  { -- | The current symbol table.
    stateVTable :: VTable lore,
    -- | Functions recorded so far.
    stateFunctions :: Imp.Functions op,
    -- | Code emitted so far (see 'emit' and 'collect'').
    stateCode :: Imp.Code op,
    -- | Warnings accumulated so far.
    stateWarnings :: Warnings,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }

-- | The initial state: empty symbol table, no functions, no code and
-- no warnings, with the given name source.
newState :: VNameSource -> ImpState lore r op
newState = ImpState mempty mempty mempty mempty

-- | The code generation monad: a reader over 'Env' on top of a state
-- monad over 'ImpState'.  The type parameters are the lore, the
-- user-extensible environment @r@, and the operation type @op@ of
-- the generated code.
newtype ImpM lore r op a
  = ImpM (ReaderT (Env lore r op) (State (ImpState lore r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState lore r op),
      MonadReader (Env lore r op)
    )

-- | Fresh names are drawn from the 'stateNameSource' field of the
-- state.
instance MonadFreshNames (ImpM lore r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- Cannot be a KernelsMem scope because the index functions have
-- the wrong leaves (VName instead of Imp.Exp).
--
-- The scope is derived from the symbol table: every entry becomes a
-- 'LetName' carrying the type reconstructed from that entry.
instance HasScope SOACS (ImpM lore r op) where
  askScope = gets $ M.map (LetName . entryType) . stateVTable
    where
      -- Rebuild a core 'Type' from what the symbol table records.
      entryType (MemVar _ memEntry) =
        Mem (entryMemSpace memEntry)
      entryType (ArrayVar _ arrayEntry) =
        Array
          (entryArrayElemType arrayEntry)
          (Shape $ entryArrayShape arrayEntry)
          NoUniqueness
      entryType (ScalarVar _ scalarEntry) =
        Prim $ entryScalarType scalarEntry

-- | Run an 'ImpM' action from the given initial state, in the
-- environment that 'newEnv' builds from the user environment, the
-- operations and the default space.  Returns the result together
-- with the final state.
runImpM ::
  ImpM lore r op a ->
  r ->
  Operations lore r op ->
  Imp.Space ->
  ImpState lore r op ->
  (a, ImpState lore r op)
runImpM (ImpM m) r ops space = runState (runReaderT m $ newEnv r ops space)

-- | Like 'subImpM', but discards the result of the action and
-- returns only the code it generated.
subImpM_ ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (Imp.Code op')
subImpM_ r ops m = snd <$> subImpM r ops m

-- | Run an 'ImpM' action that uses a different user environment and
-- operation type inside the current one, returning its result and
-- the code it generated (which is /not/ emitted into the outer code).
--
-- The inner action sees the outer environment with the compilers and
-- user environment replaced by the supplied ones.  It starts from
-- the outer symbol table and name source, with empty code, functions
-- and warnings.  Afterwards only the name source and the warnings
-- are propagated back to the outer state.  The inner symbol table
-- and functions are dropped.
subImpM ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  env <- ask
  s <- get

  let env' =
        env
          { envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envAllocCompilers = opsAllocCompilers ops,
            envEnv = r
          }
      s' =
        ImpState
          { stateVTable = stateVTable s,
            stateFunctions = mempty,
            stateCode = mempty,
            stateNameSource = stateNameSource s,
            stateWarnings = mempty
          }
      (x, s'') = runState (runReaderT m env') s'

  -- Keep names generated by the inner action from being reused.
  putNameSource $ stateNameSource s''
  warnings $ stateWarnings s''
  return (x, stateCode s'')

-- | Execute a code generation action, returning the code that was
-- emitted.  The code is not added to the surrounding code.
collect :: ImpM lore r op () -> ImpM lore r op (Imp.Code op)
collect = fmap snd . collect'

-- | Execute a code generation action, returning both its result and
-- the code it emitted.  The previously emitted code is set aside
-- while the action runs and restored afterwards, so the collected
-- code does not become part of the surrounding code.
collect' :: ImpM lore r op a -> ImpM lore r op (a, Imp.Code op)
collect' m = do
  prev_code <- gets stateCode
  modify $ \s -> s {stateCode = mempty}
  x <- m
  new_code <- gets stateCode
  modify $ \s -> s {stateCode = prev_code}
  return (x, new_code)

-- | Execute a code generation action, wrapping the generated code
-- within a 'Imp.Comment' with the given description.
comment :: String -> ImpM lore r op () -> ImpM lore r op ()
comment desc m = do
  code <- collect m
  emit $ Imp.Comment desc code

-- | Emit some generated imperative code, appending it to the code
-- emitted so far.
emit :: Imp.Code op -> ImpM lore r op ()
emit code = modify $ \s -> s {stateCode = stateCode s <> code}

-- | Record a collection of warnings in the compiler state.
--
-- The new warnings are combined in front of the ones already held in
-- the 'stateWarnings' field (@ws <> old@).
warnings :: Warnings -> ImpM lore r op ()
warnings ws = modify $ \s -> s {stateWarnings = ws <> stateWarnings s}

-- | Emit a warning about something the user should be aware of.
--
-- The primary location @loc@ and every element of @locs@ are
-- converted with 'srclocOf' and passed, together with the message
-- @problem@, to 'singleWarning''; the result is recorded via
-- 'warnings'.
warn :: Located loc => loc -> [loc] -> String -> ImpM lore r op ()
warn loc locs problem =
  warnings $ singleWarning' (srclocOf loc) (map srclocOf locs) problem

-- | Emit a function in the generated code.
--
-- The @(name, function)@ pair is consed onto the front of the list
-- held in 'stateFunctions'.  No check is made here for an existing
-- function of the same name; use 'hasFunction' for that.
emitFunction :: Name -> Imp.Function op -> ImpM lore r op ()
emitFunction fname fun = do
  Imp.Functions fs <- gets stateFunctions
  modify $ \s -> s {stateFunctions = Imp.Functions $ (fname, fun) : fs}

-- | Check if a function of a given name exists.
--
-- This inspects the association list stored in 'stateFunctions' and
-- succeeds if any entry is keyed by @fname@.
hasFunction :: Name -> ImpM lore r op Bool
hasFunction fname = gets $ \s ->
  let Imp.Functions fs = stateFunctions s
   in isJust $ lookup fname fs

-- | Build a variable table from a sequence of statements.
--
-- Every pattern element of every statement contributes one entry,
-- keyed by the element's name and constructed with
-- 'memBoundToVarEntry' from the statement's expression and the
-- element's memory annotation.  The per-element tables are merged
-- with the 'Monoid' instance of the table ('foldMap').
constsVTable :: Mem lore => Stms lore -> VTable lore
constsVTable = foldMap stmVtable
  where
    -- Table for all pattern elements bound by a single statement.
    stmVtable (Let pat _ e) =
      foldMap (peVtable e) $ patternElements pat
    -- Singleton table for one pattern element; the defining
    -- expression is remembered alongside the memory annotation.
    peVtable e (PatElem name dec) =
      M.singleton name $ memBoundToVarEntry (Just e) dec

-- | Compile a whole program to imperative definitions.
--
-- Arguments, in order: the user environment @r@, the pluggable
-- 'Operations', the default 'Imp.Space', and the program itself.
-- The result pairs the accumulated 'Warnings' with the generated
-- 'Imp.Definitions' (constants plus functions).
--
-- Each function definition is compiled in its own 'runImpM' run,
-- evaluated via @parMap rpar@.  All of those runs start from the
-- same name source and from a variable table describing the
-- program constants.  Their final states are merged by
-- @combineStates@, and the constants are then compiled in that
-- merged state, using the names free in the generated functions to
-- decide which constant declarations must be exposed.
compileProg ::
  (Mem lore, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations lore r op ->
  Imp.Space ->
  Prog lore ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog consts funs) =
  modifyNameSource $ \src ->
    let (_, ss) =
          unzip $ parMap rpar (compileFunDef' src) funs
        -- Names referenced by any generated function; passed to
        -- 'compileConsts' to select the constants that are needed.
        free_in_funs =
          freeIn $ mconcat $ map stateFunctions ss
        (consts', s') =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates ss
     in ( ( stateWarnings s',
            Imp.Definitions consts' (stateFunctions s')
          ),
          stateNameSource s'
        )
  where
    -- Compile one function in a fresh state seeded with the given
    -- name source and the variable table of the constants.
    compileFunDef' src fdef =
      runImpM
        (compileFunDef fdef)
        r
        ops
        space
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the final states of the per-function runs: name sources
    -- and warnings are combined monoidally, and the functions are
    -- round-tripped through a 'M.Map' so that each name occurs only
    -- once in the result.
    combineStates ss =
      let Imp.Functions funs' = mconcat $ map stateFunctions ss
          src = mconcat (map stateNameSource ss)
       in (newState src)
            { stateFunctions =
                Imp.Functions $ M.toList $ M.fromList funs',
              stateWarnings =
                mconcat $ map stateWarnings ss
            }

-- | Compile the statements that initialise the program constants.
--
-- @used_consts@ is the set of names that must be visible outside the
-- initialisation code.  The statements are compiled with
-- 'compileStms' and the resulting code is collected.  Top-level
-- declarations of names in @used_consts@ are then removed from that
-- code and returned as parameters of the 'Imp.Constants' value
-- instead.
compileConsts :: Names -> Stms lore -> ImpM lore r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect $ compileStms used_consts stms $ pure ()
  pure $ uncurry Imp.Constants $ first DL.toList $ extract code
  where
    -- Fish out those top-level declarations in the constant
    -- initialisation code that are free in the functions.
    --
    -- Only the spine of sequenced code ('Imp.:>>:') is traversed;
    -- any other construct is returned untouched.  A matching
    -- declaration is replaced by 'mempty' in the code and becomes a
    -- parameter; a declaration whose name is not in @used_consts@
    -- falls through to the final clause and stays in the code.
    extract (x Imp.:>>: y) =
      extract x <> extract y
    extract (Imp.DeclareMem name space)
      | name `nameIn` used_consts =
        ( DL.singleton $ Imp.MemParam name space,
          mempty
        )
    extract (Imp.DeclareScalar name _ t)
      | name `nameIn` used_consts =
        ( DL.singleton $ Imp.ScalarParam name t,
          mempty
        )
    extract s =
      (mempty, s)

-- | Translate a single function parameter.
--
-- The result depends on the parameter's memory annotation:
--
-- * 'MemPrim' yields @Left@ an 'Imp.ScalarParam'.
--
-- * 'MemMem' yields @Left@ an 'Imp.MemParam' in the given space.
--
-- * 'MemArray' yields @Right@ an 'ArrayDecl' recording the memory
--   block, the dimensions of the shape, and the index function with
--   its variable leaves rewritten to 'Imp.ScalarVar'.
compileInParam ::
  Mem lore =>
  FParam lore ->
  ImpM lore r op (Either Imp.Param ArrayDecl)
compileInParam fparam = case paramDec fparam of
  MemPrim bt ->
    return $ Left $ Imp.ScalarParam name bt
  MemMem space ->
    return $ Left $ Imp.MemParam name space
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    return $
      Right $
        ArrayDecl name bt $
          MemLocation mem (shapeDims shape) $ fmap (fmap Imp.ScalarVar) ixfun
  where
    name = paramName fparam

-- | Description of an array-typed parameter: the parameter's name,
-- the 'PrimType' taken from its array annotation, and the
-- 'MemLocation' (memory block, dimensions, index function) at which
-- it is stored.
data ArrayDecl = ArrayDecl VName PrimType MemLocation

-- | The set of variables occurring in the array dimensions of a
-- parameter's type, as reported by 'subExpVars' on 'arrayDims' of
-- the 'paramType'.
fparamSizes :: Typed dec => Param dec -> S.Set VName
fparamSizes = S.fromList . subExpVars . arrayDims . paramType

-- | Translate the parameters of a function.
--
-- Returns the scalar and memory parameters, the declarations of the
-- array parameters, and the external value descriptions for the
-- entry point.
--
-- The parameter list is split into a leading context part and a
-- trailing value part; the length of the value part is the sum of
-- 'entryPointSize' over @orig_epts@.  External values are built from
-- the value part only, guided by @orig_epts@.
compileInParams ::
  Mem lore =>
  [FParam lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  let (ctx_params, val_params) =
        splitAt (length params - sum (map entryPointSize orig_epts)) params
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam (ctx_params ++ val_params)
  let -- Array declaration for a parameter name, if it is an array.
      findArray x = find (isArrayDecl x) arrayds
      -- All variables used as an array dimension by some parameter.
      sizes = mconcat $ map fparamSizes $ ctx_params ++ val_params

      -- Memory-block parameters mapped to their space.
      summaries = M.fromList $ mapMaybe memSummary params
        where
          memSummary param
            | MemMem space <- paramDec param =
              Just (paramName param, space)
            | otherwise =
              Nothing

      findMemInfo :: VName -> Maybe Space
      findMemInfo = flip M.lookup summaries

      -- Describe one value parameter.  An array needs its memory
      -- block to be among the parameters; a primitive that is used
      -- as an array size gets no description; anything else is also
      -- skipped.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLocation mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt)
            | paramName fparam `S.member` sizes ->
              Nothing
            | otherwise ->
              Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- Walk the entry point types and the value parameters in
      -- lockstep.  An opaque type consumes @n@ parameters; the other
      -- two consume one each.  The walk stops as soon as either list
      -- is exhausted.
      mkExts (TypeOpaque desc n : epts) fparams =
        let (fparams', rest) = splitAt n fparams
         in Imp.OpaqueValue
              desc
              (mapMaybe (`mkValueDesc` Imp.TypeDirect) fparams') :
            mkExts epts rest
      mkExts (TypeUnsigned : epts) (fparam : fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeUnsigned)
          ++ mkExts epts fparams
      mkExts (TypeDirect : epts) (fparam : fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeDirect)
          ++ mkExts epts fparams
      mkExts _ _ = []

  return (inparams, arrayds, mkExts orig_epts val_params)
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

-- | Construct the out-parameters of a function from its return types.
--
-- The result is a triple of:
--
--   * the external value descriptions of the results, grouped
--     according to the given entry point types;
--
--   * the out-parameters that must be added to the function; and
--
--   * a 'Destination' listing first the context destinations (ordered
--     by their existential index) and then one destination per return
--     value.
--
-- Calls 'error' if a return type is a memory block.
compileOutParams ::
  Mem lore =>
  [RetType lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams orig_rts orig_epts = do
  -- The writer accumulates the out-parameters together with the
  -- context destinations (keyed by existential index); the state
  -- remembers which existential sizes already have an out-parameter.
  ((extvs, dests), (outparams, ctx_dests)) <-
    runWriterT $ evalStateT (mkExts orig_epts orig_rts) M.empty
  -- 'M.elems' yields the values in ascending key order, so the
  -- context destinations come out ordered by existential index.
  let ctx_dests' = M.elems ctx_dests
  return (extvs, outparams, Destination Nothing $ ctx_dests' <> dests)
  where
    -- Run an 'ImpM' action beneath the StateT/WriterT layers.
    imp = lift . lift

    -- Walk the entry point types and the return types in lockstep.
    -- An opaque entry point type consumes @n@ return types; the others
    -- consume exactly one.  Traversal stops as soon as no clause
    -- applies (e.g. when the return types run out).
    mkExts (TypeOpaque desc n : epts) rts = do
      let (rts', rest) = splitAt n rts
      (evs, dests) <- unzip <$> mapM (`mkParam` Imp.TypeDirect) rts'
      (more_values, more_dests) <- mkExts epts rest
      return
        ( Imp.OpaqueValue desc evs : more_values,
          dests ++ more_dests
        )
    mkExts (TypeUnsigned : epts) (rt : rts) =
      mkTransparent Imp.TypeUnsigned rt epts rts
    mkExts (TypeDirect : epts) (rt : rts) =
      mkTransparent Imp.TypeDirect rt epts rts
    mkExts _ _ = return ([], [])

    -- A single transparent value with the given signedness, followed
    -- by whatever the remaining types produce.
    mkTransparent signedness rt epts rts = do
      (ev, dest) <- mkParam rt signedness
      (more_values, more_dests) <- mkExts epts rts
      return
        ( Imp.TransparentValue ev : more_values,
          dest : more_dests
        )

    -- Produce the value description and destination for one return
    -- type, emitting any out-parameters it needs.
    mkParam MemMem {} _ =
      error "Functions may not explicitly return memory blocks."
    mkParam (MemPrim t) ept = do
      out <- imp $ newVName "scalar_out"
      tell ([Imp.ScalarParam out t], mempty)
      return (Imp.ScalarValue t ept out, ScalarDestination out)
    mkParam (MemArray t shape _ dec) ept = do
      space <- asks envDefaultSpace
      memout <- case dec of
        -- A freshly returned block gets its own memory out-parameter,
        -- registered as the context destination for index @x@.
        ReturnsNewBlock _ x _ixfun -> do
          memout <- imp $ newVName "out_mem"
          tell
            ( [Imp.MemParam memout space],
              M.singleton x $ MemoryDestination memout
            )
          return memout
        -- The array lives in an already known block; no new parameter.
        ReturnsInBlock memout _ ->
          return memout
      resultshape <- mapM inspectExtSize $ shapeDims shape
      return
        ( Imp.ArrayValue memout space t ept resultshape,
          ArrayDestination Nothing
        )

    -- Turn a possibly existential dimension into a concrete size.  The
    -- first occurrence of an existential index allocates an int64
    -- out-parameter; later occurrences reuse it.
    inspectExtSize (Ext x) = do
      seen <- get
      case M.lookup x seen of
        Nothing -> do
          out <- imp $ newVName "out_arrsize"
          tell
            ( [Imp.ScalarParam out int64],
              M.singleton x $ ScalarDestination out
            )
          put $ M.insert x out seen
          return $ Var out
        Just out ->
          return $ Var out
    inspectExtSize (Free se) =
      return se

-- | Compile a single function definition and emit the resulting
-- imperative function under the function's own name.
--
-- While the function is being compiled, 'envFunction' is set to its
-- name.  The emitted function is flagged as an entry point exactly
-- when the definition carries entry point information; when it does
-- not, every parameter and return value is treated as 'TypeDirect'.
compileFunDef ::
  Mem lore =>
  FunDef lore ->
  ImpM lore r op ()
compileFunDef (FunDef entry _ fname rettype params body) =
  local (\env -> env {envFunction = Just fname}) $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    emitFunction fname $
      Imp.Function (isJust entry) outparams inparams body' results args
  where
    -- Entry point types of the parameters and of the results.
    params_entry = maybe (replicate (length params) TypeDirect) fst entry
    ret_entry = maybe (replicate (length rettype) TypeDirect) snd entry

    compile = do
      (inparams, arrayds, args) <- compileInParams params params_entry
      (results, outparams, Destination _ dests) <-
        compileOutParams rettype ret_entry
      addFParams params
      addArrays arrayds

      -- Compile the body statements, then copy each body result to the
      -- destination created for it by 'compileOutParams'.
      let Body _ stms ses = body
      compileStms (freeIn ses) stms $
        forM_ (zip dests ses) $ \(d, se) -> copyDWIMDest d [] se []

      return (outparams, inparams, results, args)

-- | Compile a body, writing its results to the destinations derived
-- from the given pattern.  The copies are emitted in the continuation
-- passed to 'compileStms'.  Results and destinations are paired
-- positionally with 'zip'.
compileBody :: (Mem lore) => Pattern lore -> Body lore -> ImpM lore r op ()
compileBody pat (Body _ bnds ses) = do
  Destination _ dests <- destinationFromPattern pat
  compileStms (freeIn ses) bnds $
    forM_ (zip dests ses) $ \(d, se) -> copyDWIMDest d [] se []

-- | Like 'compileBody', but the results are copied directly to the
-- given parameters (by name) instead of to destinations derived from
-- a pattern.  Results and parameters are paired positionally with
-- 'zip'.
compileBody' :: [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' params (Body _ bnds ses) =
  compileStms (freeIn ses) bnds $
    forM_ (zip params ses) $ \(param, se) ->
      copyDWIM (paramName param) [] se []

-- | Compile a loop body whose results are to be written back to the
-- given merge parameters.
--
-- Only primitive-typed parameters, and memory-typed parameters whose
-- result is a variable, are written; for any other parameter no code
-- is emitted here.
compileLoopBody :: Typed dec => [Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody mergeparams (Body _ bnds ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is efficient, because the
  -- operations are all scalar operations.
  tmpnames <- mapM (newVName . (++ "_tmp") . baseString . paramName) mergeparams
  compileStms (freeIn ses) bnds $ do
    -- First pass: declare and fill every temporary, collecting the
    -- deferred action that later moves it into its merge parameter.
    copy_to_merge_params <- forM (zip3 mergeparams tmpnames ses) $ \(p, tmp, se) ->
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          return $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space | Var v <- se -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          return $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> return $ return ()
    -- Second pass: now that all temporaries hold their values, write
    -- them to the merge parameters.
    sequence_ copy_to_merge_params

-- | Compile a sequence of statements followed by a continuation.  All
-- three arguments are passed unchanged to the statement compiler
-- stored in the environment ('envStmsCompiler'), which does the
-- actual work.
compileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms alive_after_stms all_stms m = do
  cb <- asks envStmsCompiler
  cb alive_after_stms all_stms m

-- | The default statements compiler.
--
-- Compiles each statement of @all_stms@ in order and then runs @m@.
-- @alive_after_stms@ is added to the set of names treated as live
-- once @m@ has run.
defCompileStms ::
  (Mem lore, FreeIn op) =>
  Names ->
  Stms lore ->
  ImpM lore r op () ->
  ImpM lore r op ()
defCompileStms alive_after_stms all_stms m =
  -- We keep track of any memory blocks produced by the statements,
  -- and after the last time that memory block is used, we insert a
  -- Free.  This is very conservative, but can cut down on lifetimes
  -- in some cases.
  void $ compileStms' mempty $ stmsToList all_stms
  where
    -- @allocs@ holds the memory blocks bound by the statements
    -- already processed.  The result is the set of names that are
    -- free in the code generated from this point onwards (plus
    -- @alive_after_stms@).
    compileStms' allocs (Let pat aux e : bs) = do
      dVars (Just e) (patternElements pat)

      -- Collect the code for this statement without emitting it yet,
      -- so that we can inspect its free names below.
      e_code <-
        localAttrs (stmAuxAttrs aux) $
          collect $ compileExp pat e
      -- Likewise for the remaining statements; memory blocks bound
      -- by this pattern become candidates for freeing further down.
      (live_after, bs_code) <-
        collect' $ compileStms' (patternAllocs pat <> allocs) bs
      let -- A block dies here if this statement mentions it and
          -- nothing after this statement does.
          dies_here v =
            not (v `nameIn` live_after)
              && v `nameIn` freeIn e_code
          to_free = S.filter (dies_here . fst) allocs

      emit e_code
      mapM_ (emit . uncurry Imp.Free) to_free
      emit bs_code

      return $ freeIn e_code <> live_after
    compileStms' _ [] = do
      -- No statements left: run the continuation and report what it
      -- uses, together with whatever the caller said must stay alive.
      code <- collect m
      emit code
      return $ freeIn code <> alive_after_stms

    -- The memory-typed elements of a pattern, paired with their space.
    patternAllocs = S.fromList . mapMaybe isMemPatElem . patternElements

    isMemPatElem pe = case patElemType pe of
      Mem space -> Just (patElemName pe, space)
      _ -> Nothing

-- | Compile an expression whose results are bound by the given
-- pattern.
--
-- This is only a dispatcher: it fetches the expression compiler
-- stored in the environment ('envExpCompiler') and applies it to the
-- pattern and the expression.
compileExp :: Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e

-- | The default expression compiler.
--
-- Handles each expression form as follows:
--
-- * @If@: an 'sIf' whose branches are compiled with 'compileBody'.
--
-- * @Apply@: an 'Imp.Call' writing to the targets derived from the
--   pattern.
--
-- * @BasicOp@: delegated to 'defCompileBasicOp'.
--
-- * @DoLoop@: an 'sFor'' or 'sWhile' loop over the merge parameters,
--   whose final values are then copied to the pattern.
--
-- * @Op@: delegated to the operation compiler stored in the
--   environment ('envOpCompiler').
defCompileExp ::
  (Mem lore) =>
  Pattern lore ->
  Exp lore ->
  ImpM lore r op ()
defCompileExp pat (If cond tbranch fbranch _) =
  sIf (toBoolExp cond) (compileBody pat tbranch) (compileBody pat fbranch)
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPattern pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where
    -- Primitive arguments are passed as expressions and memory
    -- variables as memory arguments; anything else is dropped from
    -- the argument list.
    compileArg (se, _) = do
      t <- subExpType se
      case (se, t) of
        (_, Prim pt) -> return $ Just $ Imp.ExpArg $ toExp' pt se
        (Var v, Mem {}) -> return $ Just $ Imp.MemArg v
        _ -> return Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
defCompileExp pat (DoLoop ctx val form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.
  dFParams mergepat
  -- Only rank-0 merge parameters are initialised by copying here.
  forM_ merge $ \(p, se) ->
    when ((== 0) $ arrayRank $ paramType p) $
      copyDWIM (paramName p) [] se []

  let doBody = compileLoopBody mergepat body

  case form of
    ForLoop i _ bound loopvars -> do
      -- Primitive-typed loop parameters are set to element @i@ of
      -- their array on each iteration; other parameters need no copy.
      let setLoopParam (p, a)
            | Prim _ <- paramType p =
              copyDWIM (paramName p) [] (Var a) [DimFix $ Imp.vi64 i]
            | otherwise =
              return ()

      bound' <- toExp bound

      dLParams $ map fst loopvars
      sFor' i bound' $
        mapM_ setLoopParam loopvars >> doBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) doBody

  -- After the loop, copy the merge parameters to the pattern.
  Destination _ pat_dests <- destinationFromPattern pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(d, r) ->
    copyDWIMDest d [] r []
  where
    merge = ctx ++ val
    mergepat = map fst merge
defCompileExp pat (Op op) = do
  opc <- asks envOpCompiler
  opc pat op

defCompileBasicOp ::
  Mem lore =>
  Pattern lore ->
  BasicOp ->
  ImpM lore r op ()
defCompileBasicOp :: Pattern lore -> BasicOp -> ImpM lore r op ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (SubExp DimSize
se) =
  VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] DimSize
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Opaque DimSize
se) =
  VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] DimSize
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (UnOp UnOp
op DimSize
e) = do
  Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
  PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (ConvOp ConvOp
conv DimSize
e) = do
  Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
  PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (BinOp BinOp
bop DimSize
x DimSize
y) = do
  Exp
x' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
x
  Exp
y' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
y
  PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (CmpOp CmpOp
bop DimSize
x DimSize
y) = do
  Exp
x' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
x
  Exp
y' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
y
  PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp Pattern lore
_ (Assert DimSize
e ErrorMsg DimSize
msg (SrcLoc, [SrcLoc])
loc) = do
  Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
  ErrorMsg Exp
msg' <- (DimSize -> ImpM lore r op Exp)
-> ErrorMsg DimSize -> ImpM lore r op (ErrorMsg Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ErrorMsg DimSize
msg
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code op
forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc

  Attrs
attrs <- ImpM lore r op Attrs
forall lore r op. ImpM lore r op Attrs
askAttrs
  Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    (SrcLoc -> [SrcLoc] -> String -> ImpM lore r op ())
-> (SrcLoc, [SrcLoc]) -> String -> ImpM lore r op ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry SrcLoc -> [SrcLoc] -> String -> ImpM lore r op ()
forall loc lore r op.
Located loc =>
loc -> [loc] -> String -> ImpM lore r op ()
warn (SrcLoc, [SrcLoc])
loc String
"Safety check required at run-time."
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Index VName
src Slice DimSize
slice)
  | Just [DimSize]
idxs <- Slice DimSize -> Maybe [DimSize]
forall d. Slice d -> Maybe [d]
sliceIndices Slice DimSize
slice =
    VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] (VName -> DimSize
Var VName
src) ([DimIndex (TExp Int64)] -> ImpM lore r op ())
-> [DimIndex (TExp Int64)] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ (DimSize -> DimIndex (TExp Int64))
-> [DimSize] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (DimSize -> TExp Int64) -> DimSize -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [DimSize]
idxs
defCompileBasicOp Pattern lore
_ Index {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Update VName
_ Slice DimSize
slice DimSize
se) =
  VName -> [DimIndex (TExp Int64)] -> DimSize -> ImpM lore r op ()
forall lore r op.
VName -> [DimIndex (TExp Int64)] -> DimSize -> ImpM lore r op ()
sUpdate (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) ((DimIndex DimSize -> DimIndex (TExp Int64))
-> Slice DimSize -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map ((DimSize -> TExp Int64)
-> DimIndex DimSize -> DimIndex (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) Slice DimSize
slice) DimSize
se
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Replicate (Shape [DimSize]
ds) DimSize
se) = do
  [Exp]
ds' <- (DimSize -> ImpM lore r op Exp)
-> [DimSize] -> ImpM lore r op [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [DimSize]
ds
  [VName]
is <- Int -> ImpM lore r op VName -> ImpM lore r op [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
ds) (String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"i")
  Code op
copy_elem <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) ((VName -> DimIndex (TExp Int64))
-> [VName] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (VName -> TExp Int64) -> VName -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> TExp Int64
Imp.vi64) [VName]
is) DimSize
se []
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ ((Code op -> Code op)
 -> (Code op -> Code op) -> Code op -> Code op)
-> (Code op -> Code op)
-> [Code op -> Code op]
-> Code op
-> Code op
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (Code op -> Code op) -> (Code op -> Code op) -> Code op -> Code op
forall b c a. (b -> c) -> (a -> b) -> a -> c
(.) Code op -> Code op
forall a. a -> a
id ((VName -> Exp -> Code op -> Code op)
-> [VName] -> [Exp] -> [Code op -> Code op]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Exp -> Code op -> Code op
forall a. VName -> Exp -> Code a -> Code a
Imp.For [VName]
is [Exp]
ds') Code op
copy_elem
defCompileBasicOp Pattern lore
_ Scratch {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [] [PatElemT (LetDec lore)
pe]) (Iota DimSize
n DimSize
e DimSize
s IntType
it) = do
  Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
  Exp
s' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
s
  String
-> TExp Int64
-> (TExp Int64 -> ImpM lore r op ())
-> ImpM lore r op ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" (DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp DimSize
n) ((TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ())
-> (TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
    let i' :: Exp
i' = IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
it (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
i
    TV Any
x <-
      String -> TExp Any -> ImpM lore r op (TV Any)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"x" (TExp Any -> ImpM lore r op (TV Any))
-> TExp Any -> ImpM lore r op (TV Any)
forall a b. (a -> b) -> a -> b
$
        Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$
          BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) Exp
e' (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
            BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) Exp
i' Exp
s'
    VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
i] (VName -> DimSize
Var (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
x)) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Copy VName
src) =
  VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] (VName -> DimSize
Var VName
src) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Manifest [Int]
_ VName
src) =
  VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] (VName -> DimSize
Var VName
src) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Concat Int
i VName
x [VName]
ys DimSize
_) = do
  TV Int64
offs_glb <- String -> TExp Int64 -> ImpM lore r op (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"tmp_offs" TExp Int64
0

  [VName] -> (VName -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
ys) ((VName -> ImpM lore r op ()) -> ImpM lore r op ())
-> (VName -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \VName
y -> do
    [DimSize]
y_dims <- Type -> [DimSize]
forall u. TypeBase (ShapeBase DimSize) u -> [DimSize]
arrayDims (Type -> [DimSize])
-> ImpM lore r op Type -> ImpM lore r op [DimSize]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
y
    let rows :: TExp Int64
rows = case Int -> [DimSize] -> [DimSize]
forall a. Int -> [a] -> [a]
drop Int
i [DimSize]
y_dims of
          [] -> String -> TExp Int64
forall a. HasCallStack => String -> a
error (String -> TExp Int64) -> String -> TExp Int64
forall a b. (a -> b) -> a -> b
$ String
"defCompileBasicOp Concat: empty array shape for " String -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> String
forall a. Pretty a => a -> String
pretty VName
y
          DimSize
r : [DimSize]
_ -> DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp DimSize
r
        skip_dims :: [DimSize]
skip_dims = Int -> [DimSize] -> [DimSize]
forall a. Int -> [a] -> [a]
take Int
i [DimSize]
y_dims
        sliceAllDim :: d -> DimIndex d
sliceAllDim d
d = d -> d -> d -> DimIndex d
forall d. d -> d -> d -> DimIndex d
DimSlice d
0 d
d d
1
        skip_slices :: [DimIndex (TExp Int64)]
skip_slices = (DimSize -> DimIndex (TExp Int64))
-> [DimSize] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. Num d => d -> DimIndex d
sliceAllDim (TExp Int64 -> DimIndex (TExp Int64))
-> (DimSize -> TExp Int64) -> DimSize -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [DimSize]
skip_dims
        destslice :: [DimIndex (TExp Int64)]
destslice = [DimIndex (TExp Int64)]
skip_slices [DimIndex (TExp Int64)]
-> [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)]
forall a. [a] -> [a] -> [a]
++ [TExp Int64 -> TExp Int64 -> TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> d -> d -> DimIndex d
DimSlice (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
offs_glb) TExp Int64
rows TExp Int64
1]
    VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [DimIndex (TExp Int64)]
destslice (VName -> DimSize
Var VName
y) []
    TV Int64
offs_glb TV Int64 -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
offs_glb TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
rows
defCompileBasicOp (Pattern [] [PatElemT (LetDec lore)
pe]) (ArrayLit [DimSize]
es Type
_)
  | Just vs :: [PrimValue]
vs@(PrimValue
v : [PrimValue]
_) <- (DimSize -> Maybe PrimValue) -> [DimSize] -> Maybe [PrimValue]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimSize -> Maybe PrimValue
isLiteral [DimSize]
es = do
    MemLocation
dest_mem <- ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> MemLocation)
-> ImpM lore r op ArrayEntry -> ImpM lore r op MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe)
    Space
dest_space <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory (MemLocation -> VName
memLocationName MemLocation
dest_mem)
    let t :: PrimType
t = PrimValue -> PrimType
primValueType PrimValue
v
    VName
static_array <- String -> ImpM lore r op VName
forall lore r op. String -> ImpM lore r op VName
newVNameForFun String
"static_array"
    Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> PrimType -> ArrayContents -> Code op
forall a. VName -> Space -> PrimType -> ArrayContents -> Code a
Imp.DeclareArray VName
static_array Space
dest_space PrimType
t (ArrayContents -> Code op) -> ArrayContents -> Code op
forall a b. (a -> b) -> a -> b
$ [PrimValue] -> ArrayContents
Imp.ArrayValues [PrimValue]
vs
    let static_src :: MemLocation
static_src =
          VName -> [DimSize] -> IxFun (TExp Int64) -> MemLocation
MemLocation VName
static_array [IntType -> Integer -> DimSize
intConst IntType
Int64 (Integer -> DimSize) -> Integer -> DimSize
forall a b. (a -> b) -> a -> b
$ Int -> Integer
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Integer) -> Int -> Integer
forall a b. (a -> b) -> a -> b
$ [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
es] (IxFun (TExp Int64) -> MemLocation)
-> IxFun (TExp Int64) -> MemLocation
forall a b. (a -> b) -> a -> b
$
            Shape (TExp Int64) -> IxFun (TExp Int64)
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [Int -> TExp Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> TExp Int64) -> Int -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
es]
        entry :: VarEntry lore
entry = Maybe (Exp lore) -> MemEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> MemEntry -> VarEntry lore
MemVar Maybe (Exp lore)
forall a. Maybe a
Nothing (MemEntry -> VarEntry lore) -> MemEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
dest_space
    VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
static_array VarEntry lore
entry
    let slice :: [DimIndex (TExp Int64)]
slice = [TExp Int64 -> TExp Int64 -> TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> d -> d -> DimIndex d
DimSlice TExp Int64
0 ([DimSize] -> TExp Int64
forall i a. Num i => [a] -> i
genericLength [DimSize]
es) TExp Int64
1]
    CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copy PrimType
t MemLocation
dest_mem [DimIndex (TExp Int64)]
slice MemLocation
static_src [DimIndex (TExp Int64)]
slice
  | Bool
otherwise =
    [(Integer, DimSize)]
-> ((Integer, DimSize) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Integer] -> [DimSize] -> [(Integer, DimSize)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Integer
0 ..] [DimSize]
es) (((Integer, DimSize) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((Integer, DimSize) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(Integer
i, DimSize
e) ->
      VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Integer -> TExp Int64
forall a. Num a => Integer -> a
fromInteger Integer
i] DimSize
e []
  where
    isLiteral :: DimSize -> Maybe PrimValue
isLiteral (Constant PrimValue
v) = PrimValue -> Maybe PrimValue
forall a. a -> Maybe a
Just PrimValue
v
    isLiteral DimSize
_ = Maybe PrimValue
forall a. Maybe a
Nothing
defCompileBasicOp Pattern lore
_ Rearrange {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp Pattern lore
_ Rotate {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp Pattern lore
_ Reshape {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp Pattern lore
pat BasicOp
e =
  String -> ImpM lore r op ()
forall a. HasCallStack => String -> a
error (String -> ImpM lore r op ()) -> String -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    String
"ImpGen.defCompileBasicOp: Invalid pattern\n  "
      String -> ShowS
forall a. [a] -> [a] -> [a]
++ PatternT (MemBound NoUniqueness) -> String
forall a. Pretty a => a -> String
pretty Pattern lore
PatternT (MemBound NoUniqueness)
pat
      String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"\nfor expression\n  "
      String -> ShowS
forall a. [a] -> [a] -> [a]
++ BasicOp -> String
forall a. Pretty a => a -> String
pretty BasicOp
e

-- | Register each of the given array declarations as an array
-- variable via 'addVar', without emitting any code.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM lore r op ()
addArrays = mapM_ addArray
  where
    -- The entry carries no defining expression ('Nothing').
    addArray (ArrayDecl name bt location) =
      addVar name $
        ArrayVar
          Nothing
          ArrayEntry
            { entryArrayLocation = location,
              entryArrayElemType = bt
            }

-- | Like 'dFParams', but does not create new declarations: each
-- parameter is only registered with 'addVar', and nothing is emitted.
--
-- Note: a hack to be used only for functions.
addFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams = mapM_ addFParam
  where
    -- Uniqueness information is dropped before building the entry.
    addFParam fparam =
      addVar (paramName fparam) $
        memBoundToVarEntry Nothing $
          noUniquenessReturns $ paramDec fparam

-- | Another hack: register a variable as a scalar of the given
-- integer type via 'addVar', without emitting a declaration for it.
addLoopVar :: VName -> IntType -> ImpM lore r op ()
addLoopVar i it =
  addVar i $ ScalarVar Nothing $ ScalarEntry $ IntType it

-- | Declare every pattern element by passing its scope to 'dScope',
-- with the given (optional) expression attached to each entry.
dVars ::
  Mem lore =>
  Maybe (Exp lore) ->
  [PatElem lore] ->
  ImpM lore r op ()
dVars e = mapM_ dVar
  where
    dVar = dScope e . scopeOfPatElem

-- | Declare the given function parameters by passing their scope to
-- 'dScope', with no defining expression attached.
dFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams = dScope Nothing . scopeOfFParams

-- | Declare the given lambda parameters by passing their scope to
-- 'dScope', with no defining expression attached.
dLParams :: Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams = dScope Nothing . scopeOfLParams

-- | Create a fresh name from the given base name, emit a /volatile/
-- scalar declaration of the given type for it, register it with
-- 'addVar', and assign it the given expression.  The result is the
-- typed variable for the new name.
--
-- The 'PrimType' is supplied by the caller and is not checked against
-- the type of the expression.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimVol name t e = do
  name' <- newVName name
  emit $ Imp.DeclareScalar name' Imp.Volatile t
  addVar name' $ ScalarVar Nothing $ ScalarEntry t
  name' <~~ untyped e
  return $ TV name' t

-- | Emit a non-volatile scalar declaration for the given name and
-- type, and register the name as a scalar variable with 'addVar'.
-- No value is assigned.
dPrim_ :: VName -> PrimType -> ImpM lore r op ()
dPrim_ name t = do
  emit $ Imp.DeclareScalar name Imp.Nonvolatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Declare a fresh scalar (named after the given base name) of the
-- given type with 'dPrim_' and return it as a typed variable.
--
-- The return type is polymorphic, so there is no guarantee it
-- actually matches the 'PrimType', but at least we have to use it
-- consistently.
dPrim :: String -> PrimType -> ImpM lore r op (TV t)
dPrim name t = do
  name' <- newVName name
  dPrim_ name' t
  return $ TV name' t

-- | Declare the given (already chosen) name as a scalar with 'dPrim_'
-- and assign it the given expression.  The declared type is the
-- primitive type of the expression itself.
dPrimV_ :: VName -> Imp.TExp t -> ImpM lore r op ()
dPrimV_ name e = do
  dPrim_ name t
  TV name t <-- e
  where
    t = primExpType $ untyped e

-- | Declare a fresh scalar (named after the given base name) whose
-- type is the primitive type of the given expression, assign the
-- expression to it, and return the typed variable.
dPrimV :: String -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimV name e = do
  name' <- dPrim name $ primExpType $ untyped e
  name' <-- e
  return name'

-- | Like 'dPrimV', but return the new variable as a typed expression
-- (via 'tvExp') rather than as a 'TV'.
dPrimVE :: String -> Imp.TExp t -> ImpM lore r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e

-- | Build the 'VarEntry' corresponding to a memory annotation,
-- attaching the given (optional) expression to it.
--
-- Primitive annotations become 'ScalarVar', memory blocks become
-- 'MemVar', and arrays become 'ArrayVar'.
memBoundToVarEntry ::
  Maybe (Exp lore) ->
  MemBound NoUniqueness ->
  VarEntry lore
memBoundToVarEntry e (MemPrim bt) =
  ScalarVar e ScalarEntry {entryScalarType = bt}
memBoundToVarEntry e (MemMem space) =
  MemVar e $ MemEntry space
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem ixfun)) =
  -- The index function's leaves are plain names; map them to
  -- 'Imp.ScalarVar' leaves to obtain an 'Imp.TExp'-based index function.
  let location =
        MemLocation mem (shapeDims shape) $
          fmap (fmap Imp.ScalarVar) ixfun
   in ArrayVar
        e
        ArrayEntry
          { entryArrayLocation = location,
            entryArrayElemType = bt
          }

-- | Extract the memory annotation of a name from its scope
-- information.  Function parameters have their uniqueness stripped,
-- and index names are treated as primitive integers of their type.
infoDec ::
  Mem lore =>
  NameInfo lore ->
  MemInfo SubExp NoUniqueness MemBind
infoDec (LetName dec) = dec
infoDec (FParamName dec) = noUniquenessReturns dec
infoDec (LParamName dec) = dec
infoDec (IndexName it) = MemPrim $ IntType it

-- | Declare a single name from its scope information.
--
-- A declaration is emitted for memory blocks ('Imp.DeclareMem') and
-- scalars (non-volatile 'Imp.DeclareScalar'); arrays emit nothing.
-- In every case the name is then registered with 'addVar'.
dInfo ::
  Mem lore =>
  Maybe (Exp lore) ->
  VName ->
  NameInfo lore ->
  ImpM lore r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e $ infoDec info
  case entry of
    MemVar _ entry' ->
      emit $ Imp.DeclareMem name $ entryMemSpace entry'
    ScalarVar _ entry' ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType entry'
    ArrayVar _ _ ->
      return ()
  addVar name entry

-- | Declare every name in the given scope with 'dInfo', attaching the
-- given (optional) expression to each entry.  Names are processed in
-- the order produced by 'M.toList'.
dScope ::
  Mem lore =>
  Maybe (Exp lore) ->
  Scope lore ->
  ImpM lore r op ()
dScope e = mapM_ (uncurry $ dInfo e) . M.toList

-- | Register an array variable with the given element type, shape and
-- memory binding via 'addVar'.  No code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op ()
dArray name bt shape membind =
  addVar name $
    memBoundToVarEntry Nothing $
      MemArray bt shape NoUniqueness membind

-- | Run the given action in an environment whose 'envVolatility' is
-- set to 'Imp.Volatile'.
everythingVolatile :: ImpM lore r op a -> ImpM lore r op a
everythingVolatile = local $ \env -> env {envVolatility = Imp.Volatile}

-- | Remove the array targets: return the names of the scalar and
-- memory destinations, in order, and drop array destinations.
funcallTargets :: Destination -> ImpM lore r op [VName]
funcallTargets (Destination _ dests) =
  return $ concatMap funcallTarget dests
  where
    funcallTarget (ScalarDestination name) = [name]
    funcallTarget (ArrayDestination _) = []
    funcallTarget (MemoryDestination name) = [name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
--
-- The type parameter @t@ is phantom: it does not occur in the
-- constructor's fields, which hold only the variable name and its
-- dynamic 'PrimType'.
data TV t = TV VName PrimType

-- | Create a typed variable from a name and a dynamic type.  Note
-- that there is no guarantee that the dynamic type corresponds to the
-- inferred static type, but the latter will at least have to be used
-- consistently.
mkTV :: VName -> PrimType -> TV t
mkTV = TV

-- | Convert a typed variable to a size (a 'SubExp' that is a 'Var'
-- referring to the variable's name).
tvSize :: TV t -> Imp.DimSize
tvSize = Var . tvVar

-- | Convert a typed variable to a similarly typed expression: a
-- variable reference carrying the stored dynamic 'PrimType'.
tvExp :: TV t -> Imp.TExp t
tvExp (TV v t) = Imp.TPrimExp $ Imp.var v t

-- | Extract the underlying variable name from a typed variable.
tvVar :: TV t -> VName
tvVar (TV v _) = v

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM lore r op Imp.Exp

  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

  -- | Compile with 'toExp'' at type 'int32' and tag the result as a
  -- 32-bit integer expression.
  toInt32Exp :: a -> Imp.TExp Int32
  toInt32Exp = TPrimExp . toExp' int32

  -- | Compile with 'toExp'' at type 'int64' and tag the result as a
  -- 64-bit integer expression.
  toInt64Exp :: a -> Imp.TExp Int64
  toInt64Exp = TPrimExp . toExp' int64

  -- | Compile with 'toExp'' at type 'Bool' and tag the result as a
  -- boolean expression.
  toBoolExp :: a -> Imp.TExp Bool
  toBoolExp = TPrimExp . toExp' Bool

-- | Constants compile to value expressions.  Variables are looked up
-- in the symbol table ('toExp') or given the caller-supplied type
-- ('toExp'').
instance ToExp SubExp where
  toExp (Constant v) =
    return $ Imp.ValueExp v
  toExp (Var v) =
    lookupVar v >>= \case
      ScalarVar _ (ScalarEntry pt) ->
        return $ Imp.var v pt
      -- Arrays and memory blocks cannot be turned into a scalar
      -- expression.
      _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty v

  -- The requested type is ignored for constants: the value is used
  -- as it is.
  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' t (Var v) = Imp.var v t

-- | A 'PrimExp' over names is compiled by wrapping every leaf name in
-- 'Imp.ScalarVar'.  No symbol table lookup is performed, and the type
-- passed to 'toExp'' is ignored.
instance ToExp (PrimExp VName) where
  toExp = pure . fmap Imp.ScalarVar
  toExp' _ = fmap Imp.ScalarVar

-- | Bind a name to an entry in the symbol table held in the monad's
-- state.  An existing binding for the same name is replaced.
addVar :: VName -> VarEntry lore -> ImpM lore r op ()
addVar name entry =
  modify $ \s -> s {stateVTable = M.insert name entry $ stateVTable s}

-- | Run an action with the environment's default space
-- (@envDefaultSpace@) set to the given space.
localDefaultSpace :: Imp.Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace space = local (\env -> env {envDefaultSpace = space})

-- | Retrieve the function name recorded in the environment
-- (@envFunction@), if there is one.
askFunction :: ImpM lore r op (Maybe Name)
askFunction = asks envFunction

-- | Generate a fresh 'VName' whose base name is the given string,
-- prefixed with the result of 'askFunction' and a dot if a function
-- name is available.
newVNameForFun :: String -> ImpM lore r op VName
newVNameForFun s = do
  fname <- fmap nameToString <$> askFunction
  newVName $ maybe "" (++ ".") fname ++ s

-- | Generate a 'Name' from the given string, prefixed with the result
-- of 'askFunction' and a dot if a function name is available.
nameForFun :: String -> ImpM lore r op Name
nameForFun s = do
  fname <- askFunction
  return $ maybe "" (<> ".") fname <> nameFromString s

-- | Retrieve the @r@-typed component of the environment (@envEnv@).
askEnv :: ImpM lore r op r
askEnv = asks envEnv

-- | Run an action with the @r@-typed component of the environment
-- (@envEnv@) transformed by the given function.
localEnv :: (r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv f = local $ \env -> env {envEnv = f $ envEnv env}

-- | The active attributes, including those for the statement
-- currently being compiled.
askAttrs :: ImpM lore r op Attrs
askAttrs = asks envAttrs

-- | Add more attributes to what is returned by 'askAttrs' for the
-- duration of the given action.  The new attributes are combined
-- (with '<>') on the left of the existing ones.
localAttrs :: Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs attrs = local $ \env -> env {envAttrs = attrs <> envAttrs env}

-- | Run an action with the expression, statements, copy, operation
-- and allocation compilers of the environment replaced by those of
-- the given 'Operations'.
localOps :: Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps ops = local $ \env ->
  env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envAllocCompilers = opsAllocCompilers ops
    }

-- | Get the current symbol table from the monad's state.
getVTable :: ImpM lore r op (VTable lore)
getVTable = gets stateVTable

-- | Replace the symbol table in the monad's state with the given one.
putVTable :: VTable lore -> ImpM lore r op ()
putVTable vtable = modify $ \s -> s {stateVTable = vtable}

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!  That
-- includes bindings the action itself adds, because the table as it
-- was before the modification is written back afterwards.
localVTable :: (VTable lore -> VTable lore) -> ImpM lore r op a -> ImpM lore r op a
localVTable f m = do
  old_vtable <- getVTable
  putVTable $ f old_vtable
  a <- m
  putVTable old_vtable
  return a

-- | Look up a name in the symbol table.  Calls 'error' if the name is
-- not bound.
lookupVar :: VName -> ImpM lore r op (VarEntry lore)
lookupVar name = do
  res <- gets $ M.lookup name . stateVTable
  case res of
    Just entry -> return entry
    _ -> error $ "Unknown variable: " ++ pretty name

-- | Look up a name that must be bound to an array.  Calls 'error' if
-- the name is unbound (via 'lookupVar') or is bound to something that
-- is not an 'ArrayVar'.
lookupArray :: VName -> ImpM lore r op ArrayEntry
lookupArray name = do
  res <- lookupVar name
  case res of
    ArrayVar _ entry -> return entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name

-- | Look up a name that must be bound to a memory block.  Calls
-- 'error' if the name is unbound (via 'lookupVar') or is bound to
-- something that is not a 'MemVar'.
lookupMemory :: VName -> ImpM lore r op MemEntry
lookupMemory name = do
  res <- lookupVar name
  case res of
    MemVar _ entry -> return entry
    _ -> error $ "Unknown memory block: " ++ pretty name

-- | Construct the 'Destination' corresponding to a pattern.  Every
-- name bound by the pattern is looked up in the symbol table, so all
-- of them must already be declared.  The 'Destination' also carries
-- the 'baseTag' of the name that 'maybeHead' picks from the pattern's
-- names, if any.
destinationFromPattern :: Mem lore => Pattern lore -> ImpM lore r op Destination
destinationFromPattern pat =
  fmap (Destination (baseTag <$> maybeHead (patternNames pat))) . mapM inspect $
    patternElements pat
  where
    -- Classify one pattern element by the kind of its symbol table
    -- entry.
    inspect patElem = do
      let name = patElemName patElem
      entry <- lookupVar name
      case entry of
        -- Arrays get no explicit memory location here.
        ArrayVar _ (ArrayEntry MemLocation {} _) ->
          return $ ArrayDestination Nothing
        MemVar {} ->
          return $ MemoryDestination name
        ScalarVar {} ->
          return $ ScalarDestination name

-- | Look up the named array and index it with 'fullyIndexArray'',
-- using the memory location stored in its symbol table entry.  The
-- result is the memory block, its space, and the element offset.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices = do
  arr <- lookupArray name
  fullyIndexArray' (entryArrayLocation arr) indices

-- | Apply the index function of a memory location to the given
-- indices.  The result is the memory block, the space of that block
-- (from the symbol table), and the resulting offset counted in
-- elements.
fullyIndexArray' ::
  MemLocation ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLocation mem _ ixfun) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  -- In a scalar space, only the part of the index list that
  -- 'splitFromEnd' associates with the space's own dimensions is
  -- kept; the other indices are replaced by zero.
  -- NOTE(review): presumably 'splitFromEnd n' splits off the last n
  -- elements, so it is the leading indices that are zeroed -- confirm.
  let indices' = case space of
        ScalarSpace ds _ ->
          let (zero_is, is) = splitFromEnd (length ds) indices
           in map (const 0) zero_is ++ is
        _ -> indices
  return
    ( mem,
      space,
      elements $ IxFun.index ixfun indices'
    )

-- More complicated read/write operations that use index functions.

-- | Copy by delegating to the copy compiler currently installed in
-- the environment (@envCopyCompiler@), passing all arguments through
-- unchanged.
copy :: CopyCompiler lore r op
copy bt dest destslice src srcslice = do
  cc <- asks envCopyCompiler
  cc bt dest destslice src srcslice

-- | Is this copy really a mapping with transpose?
--
-- Both index functions are first sliced with their respective
-- slices.  The copy qualifies when one side is a rearrangement with
-- an offset whose permutation 'isMapTranspose' accepts, and the other
-- side is linear with an offset.  The destination is tried as the
-- rearranged side first, then the source.
--
-- On success the result is
-- @(dest_offset, src_offset, num_arrays, size_x, size_y)@.  The
-- offsets are computed relative to the byte size of the element
-- type.
isMapTransposeCopy ::
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy
  bt
  (MemLocation _ _ destIxFun)
  destslice
  (MemLocation _ _ srcIxFun)
  srcslice
    -- Destination is rearranged, source is linear.  The two
    -- transposed dimension groups are swapped in this case.
    | Just (dest_offset, perm_and_destshape) <- IxFun.rearrangeWithOffset destIxFun' bt_size,
      (perm, destshape) <- unzip perm_and_destshape,
      Just src_offset <- IxFun.linearWithOffset srcIxFun' bt_size,
      Just (r1, r2, _) <- isMapTranspose perm =
      isOk destshape swap r1 r2 dest_offset src_offset
    -- Destination is linear, source is rearranged.
    | Just dest_offset <- IxFun.linearWithOffset destIxFun' bt_size,
      Just (src_offset, perm_and_srcshape) <- IxFun.rearrangeWithOffset srcIxFun' bt_size,
      (perm, srcshape) <- unzip perm_and_srcshape,
      Just (r1, r2, _) <- isMapTranspose perm =
      isOk srcshape id r1 r2 dest_offset src_offset
    | otherwise =
      Nothing
    where
      bt_size = primByteSize bt
      swap (x, y) = (y, x)

      destIxFun' = IxFun.slice destIxFun destslice
      srcIxFun' = IxFun.slice srcIxFun srcslice

      -- Package the offsets together with the sizes computed by
      -- 'getSizes'.
      isOk shape f r1 r2 dest_offset src_offset = do
        let (num_arrays, size_x, size_y) = getSizes shape f r1 r2
        return
          ( dest_offset,
            src_offset,
            num_arrays,
            size_x,
            size_y
          )

      -- The first r1 dimensions are the mapped ones; their product
      -- is the number of arrays.  The remaining dimensions are split
      -- after r2 of them, the two groups are passed through f, and
      -- the product of each group gives the other two sizes.
      getSizes shape f r1 r2 =
        let (mapped, notmapped) = splitAt r1 shape
            (pretrans, posttrans) = f $ splitAt r2 notmapped
         in (product mapped, product pretrans, product posttrans)

-- | The base name of the map-transpose function for an element type:
-- @map_transpose_@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = "map_transpose_" ++ pretty bt

-- | Return the name of the builtin map-transpose function for the
-- given element type.  The function is emitted first unless
-- 'hasFunction' reports that one with that name already exists.
mapTransposeForType :: PrimType -> ImpM lore r op Name
mapTransposeForType bt = do
  let fname = nameFromString $ "builtin#" <> mapTransposeName bt

  exists <- hasFunction fname
  unless exists $ emitFunction fname $ mapTransposeFunction fname bt

  return fname

-- | Use an 'Imp.Copy' if possible, otherwise 'copyElementWise'.
defaultCopy :: CopyCompiler lore r op
defaultCopy :: CopyCompiler lore r op
defaultCopy PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice
  | Just
      ( TExp Int64
destoffset,
        TExp Int64
srcoffset,
        TExp Int64
num_arrays,
        TExp Int64
size_x,
        TExp Int64
size_y
        ) <-
      PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> Maybe
     (TExp Int64, TExp Int64, TExp Int64, TExp Int64, TExp Int64)
isMapTransposeCopy PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice = do
    Name
fname <- PrimType -> ImpM lore r op Name
forall lore r op. PrimType -> ImpM lore r op Name
mapTransposeForType PrimType
pt
    Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
      [VName] -> Name -> [Arg] -> Code op
forall a. [VName] -> Name -> [Arg] -> Code a
Imp.Call
        []
        Name
fname
        ([Arg] -> Code op) -> [Arg] -> Code op
forall a b. (a -> b) -> a -> b
$ PrimType
-> VName
-> Count Bytes (TExp Int64)
-> VName
-> Count Bytes (TExp Int64)
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> [Arg]
transposeArgs
          PrimType
pt
          VName
destmem
          (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
destoffset)
          VName
srcmem
          (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
srcoffset)
          TExp Int64
num_arrays
          TExp Int64
size_x
          TExp Int64
size_y
  | Just TExp Int64
destoffset <-
      IxFun (TExp Int64) -> TExp Int64 -> Maybe (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
IxFun.linearWithOffset (IxFun (TExp Int64) -> [DimIndex (TExp Int64)] -> IxFun (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TExp Int64)
dest_ixfun [DimIndex (TExp Int64)]
destslice) TExp Int64
pt_size,
    Just TExp Int64
srcoffset <-
      IxFun (TExp Int64) -> TExp Int64 -> Maybe (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
IxFun.linearWithOffset (IxFun (TExp Int64) -> [DimIndex (TExp Int64)] -> IxFun (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TExp Int64)
src_ixfun [DimIndex (TExp Int64)]
srcslice) TExp Int64
pt_size = do
    Space
srcspace <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory VName
srcmem
    Space
destspace <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory VName
destmem
    if Space -> Bool
isScalarSpace Space
srcspace Bool -> Bool -> Bool
|| Space -> Bool
isScalarSpace Space
destspace
      then CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copyElementWise PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice
      else
        Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
          VName
-> Count Bytes (TExp Int64)
-> Space
-> VName
-> Count Bytes (TExp Int64)
-> Space
-> Count Bytes (TExp Int64)
-> Code op
forall a.
VName
-> Count Bytes (TExp Int64)
-> Space
-> VName
-> Count Bytes (TExp Int64)
-> Space
-> Count Bytes (TExp Int64)
-> Code a
Imp.Copy
            VName
destmem
            (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
destoffset)
            Space
destspace
            VName
srcmem
            (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
srcoffset)
            Space
srcspace
            (Count Bytes (TExp Int64) -> Code op)
-> Count Bytes (TExp Int64) -> Code op
forall a b. (a -> b) -> a -> b
$ Count Elements (TExp Int64)
num_elems Count Elements (TExp Int64) -> PrimType -> Count Bytes (TExp Int64)
`withElemType` PrimType
pt
  | Bool
otherwise =
    CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copyElementWise PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice
  where
    pt_size :: TExp Int64
pt_size = PrimType -> TExp Int64
forall a. Num a => PrimType -> a
primByteSize PrimType
pt
    num_elems :: Count Elements (TExp Int64)
num_elems = TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
Imp.elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Shape (TExp Int64) -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product (Shape (TExp Int64) -> TExp Int64)
-> Shape (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64)
forall d. Slice d -> [d]
sliceDims [DimIndex (TExp Int64)]
srcslice

    MemLocation VName
destmem [DimSize]
_ IxFun (TExp Int64)
dest_ixfun = MemLocation
dest
    MemLocation VName
srcmem [DimSize]
_ IxFun (TExp Int64)
src_ixfun = MemLocation
src

    isScalarSpace :: Space -> Bool
isScalarSpace ScalarSpace {} = Bool
True
    isScalarSpace Space
_ = Bool
False

-- | Copy one element at a time.  A fresh loop variable is created for
-- every dimension of the source slice, and a nest of 'Imp.For' loops
-- (first dimension outermost) is emitted whose body writes a single
-- element of the destination from the corresponding element of the
-- source.  Both slices are turned into full indexes with 'fixSlice'
-- using the same loop variables.
copyElementWise :: CopyCompiler lore r op
copyElementWise bt dest destslice src srcslice = do
  -- The iteration space is given by the dimensions of the source slice.
  let bounds = sliceDims srcslice
  is <- replicateM (length bounds) (newVName "i")
  let ivars = map Imp.vi64 is
  (destmem, destspace, destidx) <-
    fullyIndexArray' dest $ fixSlice destslice ivars
  (srcmem, srcspace, srcidx) <-
    fullyIndexArray' src $ fixSlice srcslice ivars
  -- The read and the write both use the volatility of the environment.
  vol <- asks envVolatility
  emit $
    -- Composing the loop constructors left-to-right wraps the write in
    -- one loop per dimension.
    foldr (.) id (zipWith Imp.For is $ map untyped bounds) $
      Imp.Write destmem destidx bt destspace vol $
        Imp.index srcmem srcidx bt srcspace vol

-- | Copy from here to there; both destination and source may be
-- indexed.  The generated code is /returned/ rather than emitted.
--
-- If both slices consist solely of 'DimFix' indexes and fix every
-- dimension of their respective arrays, the result is a single
-- 'Imp.Write' of one element.  Otherwise both slices are padded to
-- full slices with 'fullSliceNum' and the copy is delegated to 'copy'.
-- In the latter case it is an 'error' if the ranks of the two sliced
-- arrays differ, and no code is produced when destination and source
-- denote the very same location and slice.
copyArrayDWIM ::
  PrimType ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op (Imp.Code op)
copyArrayDWIM
  bt
  destlocation@(MemLocation _ destshape _)
  destslice
  srclocation@(MemLocation _ srcshape _)
  srcslice
    -- Single-element case: every dimension on both sides is fixed.
    | Just destis <- mapM dimFix destslice,
      Just srcis <- mapM dimFix srcslice,
      length srcis == length srcshape,
      length destis == length destshape = do
      (targetmem, destspace, targetoffset) <-
        fullyIndexArray' destlocation destis
      (srcmem, srcspace, srcoffset) <-
        fullyIndexArray' srclocation srcis
      vol <- asks envVolatility
      return $
        Imp.Write targetmem targetoffset bt destspace vol $
          Imp.index srcmem srcoffset bt srcspace vol
    | otherwise = do
      let destslice' =
            fullSliceNum (map toInt64Exp destshape) destslice
          srcslice' =
            fullSliceNum (map toInt64Exp srcshape) srcslice
          destrank = length $ sliceDims destslice'
          srcrank = length $ sliceDims srcslice'
      if destrank /= srcrank
        then
          error $
            "copyArrayDWIM: cannot copy to "
              ++ pretty (memLocationName destlocation)
              ++ " from "
              ++ pretty (memLocationName srclocation)
              ++ " because ranks do not match ("
              ++ pretty destrank
              ++ " vs "
              ++ pretty srcrank
              ++ ")"
        else
          if destlocation == srclocation && destslice' == srcslice'
            then return mempty -- Copy would be no-op.
            else collect $ copy bt destlocation destslice' srclocation srcslice'

-- | Like 'copyDWIM', but the target is a 'ValueDestination'
-- instead of a variable name.
--
-- The source is either a constant or a variable:
--
-- * A constant source cannot be indexed, and its destination slice
--   must consist solely of 'DimFix' indexes.  It can be written to a
--   scalar destination or to a single element of an array
--   destination, but not to a memory destination.
--
-- * A variable source is looked up with 'lookupVar', and the action
--   taken depends on the combination of destination and source kinds
--   (memory, scalar or array).  Unsupported combinations are reported
--   with 'error'.
--
-- An @ArrayDestination Nothing@ with a variable source emits no code.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIMDest _ _ (Constant v) (_ : _) =
  error $
    unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]
copyDWIMDest pat dest_slice (Constant v) [] =
  case mapM dimFix dest_slice of
    Nothing ->
      error $
        unwords ["copyDWIMDest: constant source", pretty v, "with slice destination."]
    Just dest_is ->
      case pat of
        ScalarDestination name ->
          emit $ Imp.SetScalar name $ Imp.ValueExp v
        MemoryDestination {} ->
          error $
            unwords ["copyDWIMDest: constant source", pretty v, "cannot be written to memory destination."]
        ArrayDestination (Just dest_loc) -> do
          (dest_mem, dest_space, dest_i) <-
            fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
        ArrayDestination Nothing ->
          error "copyDWIMDest: ArrayDestination Nothing"
  where
    -- The element type written is the type of the constant itself.
    bt = primValueType v
copyDWIMDest dest dest_slice (Var src) src_slice = do
  src_entry <- lookupVar src
  -- The order of the alternatives matters: the error cases guarding
  -- against sliced scalars must come before the cases that handle
  -- scalars proper.
  case (dest, src_entry) of
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space
    (MemoryDestination {}, _) ->
      error $
        unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]
    (_, MemVar {}) ->
      error $
        unwords ["copyDWIMDest: source", pretty src, "is a memory block."]
    (_, ScalarVar _ (ScalarEntry _))
      | not $ null src_slice ->
        error $
          unwords ["copyDWIMDest: prim-typed source", pretty src, "with slice", pretty src_slice]
    (ScalarDestination name, _)
      | not $ null dest_slice ->
        error $
          unwords ["copyDWIMDest: prim-typed target", pretty name, "with slice", pretty dest_slice]
    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt
    (ScalarDestination name, ArrayVar _ arr)
      -- Reading a scalar out of an array requires an index that fixes
      -- every dimension of the array.
      | Just src_is <- mapM dimFix src_slice,
        length src_slice == length (entryArrayShape arr) -> do
        let bt = entryArrayElemType arr
        (mem, space, i) <-
          fullyIndexArray' (entryArrayLocation arr) src_is
        vol <- asks envVolatility
        emit $ Imp.SetScalar name $ Imp.index mem i bt space vol
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: prim-typed target",
              pretty name,
              "and array-typed source",
              pretty src,
              "with slice",
              pretty src_slice
            ]
    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) -> do
      let src_loc = entryArrayLocation src_arr
          bt = entryArrayElemType src_arr
      emit =<< copyArrayDWIM bt dest_loc dest_slice src_loc src_slice
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt))
      | Just dest_is <- mapM dimFix dest_slice -> do
        (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
        vol <- asks envVolatility
        emit $ Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt)
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: array-typed target and prim-typed source",
              pretty src,
              "with slice",
              pretty dest_slice
            ]
    (ArrayDestination Nothing, _) ->
      return () -- Nothing to do; something else set some memory
      -- somewhere.

-- | Copy from here to there; both destination and source may be
-- indexed.  If so, they better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination variable is looked up with 'lookupVar' and turned
-- into the matching 'ValueDestination' (scalar, array or memory),
-- after which the work is delegated to 'copyDWIMDest'.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIM dest dest_slice src src_slice = do
  dest_entry <- lookupVar dest
  let dest_target =
        case dest_entry of
          ScalarVar _ _ ->
            ScalarDestination dest
          ArrayVar _ (ArrayEntry (MemLocation mem shape ixfun) _) ->
            ArrayDestination $ Just $ MemLocation mem shape ixfun
          MemVar _ _ ->
            MemoryDestination dest
  copyDWIMDest dest_target dest_slice src src_slice

-- | As 'copyDWIM', but implicitly 'DimFix'es the indexes: every
-- index expression in both lists is wrapped in 'DimFix' before
-- being passed on to 'copyDWIM'.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM lore r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (map DimFix dest_is) src (map DimFix src_is)

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@ and binds the result to the single memory block named by
-- @pat@.
--
-- The pattern must have no context elements and exactly one value
-- element; any other pattern is a compiler bug and aborts with an
-- error that shows the offending pattern.
--
-- The allocation itself is delegated to 'sAlloc_', which consults the
-- per-space allocation compilers in the environment and otherwise
-- emits a plain 'Imp.Allocate'.
compileAlloc ::
  Mem lore =>
  Pattern lore ->
  SubExp ->
  Space ->
  ImpM lore r op ()
compileAlloc (Pattern [] [mem]) e space =
  sAlloc_ (patElemName mem) (Imp.bytes $ toInt64Exp e) space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an 'Int64' expression.
--
-- Computed as the size of the element type (an 'Imp.SizeOf' leaf)
-- multiplied by the product of all array dimensions.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t =
  Imp.bytes $ elem_size * product (map toInt64Exp (arrayDims t))
  where
    elem_size = isInt64 (Imp.LeafExp (Imp.SizeOf $ elemType t) int64)

--- Building blocks for constructing code.

-- | Emit an 'Imp.For' loop over the already-named variable @i@ with
-- the given @bound@, whose body is the code generated by @body@.
--
-- The loop variable is registered (via 'addLoopVar') with the integer
-- type of @bound@; a bound that is not of integer type is a compiler
-- bug and aborts with an error.
sFor' :: VName -> Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' i bound body = do
  let it = case primExpType bound of
        IntType bound_t -> bound_t
        t -> error $ "sFor': bound " ++ pretty bound ++ " is of type " ++ pretty t
  -- The variable must be known before the body is compiled, since the
  -- body may refer to it.
  addLoopVar i it
  body' <- collect body
  emit $ Imp.For i bound body'

-- | Emit a for-loop with a fresh loop variable whose name is based on
-- @i@.  The @body@ receives the loop variable as a typed expression
-- with the same primitive type as @bound@.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor i bound body = do
  i' <- newVName i
  let bound' = untyped bound
  sFor' i' bound' $
    body $ TPrimExp $ Imp.var i' $ primExpType bound'

-- | Emit an 'Imp.While' loop with condition @cond@ whose body is the
-- code generated by @body@.
sWhile :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile cond body = do
  body' <- collect body
  emit $ Imp.While cond body'

-- | Wrap the code generated by @code@ in an 'Imp.Comment' carrying
-- the text @s@.
sComment :: String -> ImpM lore r op () -> ImpM lore r op ()
sComment s code = do
  code' <- collect code
  emit $ Imp.Comment s code'

-- | Emit an 'Imp.If' on @cond@.  The code generated by @tbranch@
-- becomes the true branch and the code generated by @fbranch@ the
-- false branch; the true branch is compiled first.
sIf :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf cond tbranch fbranch = do
  tbranch' <- collect tbranch
  fbranch' <- collect fbranch
  emit $ Imp.If cond tbranch' fbranch'

-- | Like 'sIf', but with an empty false branch.
sWhen :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen cond tbranch = sIf cond tbranch (return ())

-- | Like 'sIf', but with an empty true branch: the given action
-- becomes the false branch.
sUnless :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless cond = sIf cond (return ())

-- | Emit a single backend-specific operation, wrapped in 'Imp.Op'.
sOp :: op -> ImpM lore r op ()
sOp = emit . Imp.Op

-- | Declare (but do not allocate) a memory block in @space@.  A fresh
-- name based on @name@ is generated, an 'Imp.DeclareMem' is emitted
-- for it, and it is recorded in the symbol table as a 'MemVar'.
-- Returns the fresh name.
sDeclareMem :: String -> Space -> ImpM lore r op VName
sDeclareMem name space = do
  name' <- newVName name
  emit $ Imp.DeclareMem name' space
  addVar name' $ MemVar Nothing $ MemEntry space
  return name'

-- | Allocate @size'@ bytes in @space@ for the already-declared memory
-- block @name'@.
--
-- If the environment has an allocation compiler registered for
-- @space@ (see 'envAllocCompilers'), it is used; otherwise a plain
-- 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op ()
sAlloc_ name' size' space = do
  allocator <- asks $ M.lookup space . envAllocCompilers
  case allocator of
    Nothing -> emit $ Imp.Allocate name' size' space
    Just allocator' -> allocator' name' size'

-- | Declare a fresh memory block (see 'sDeclareMem') and allocate
-- @size@ bytes for it in @space@ (see 'sAlloc_').  Returns the name
-- of the new block.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op VName
sAlloc name size space = do
  name' <- sDeclareMem name space
  sAlloc_ name' size space
  return name'

-- | Declare an array with element type @bt@, the given @shape@ and
-- memory binding @membind@, under a fresh name based on @name@.
-- Returns the fresh name.  The declaration itself is done by 'dArray'.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray name bt shape membind = do
  name' <- newVName name
  dArray name' bt shape membind
  return name'

-- | Declare an array in row-major order in the given memory block.
-- The index function is 'IxFun.iota' over the dimensions of @shape@.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM lore r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape $ ArrayIn mem $ IxFun.iota dims
  where
    dims = map (isInt64 . primExpFromSubExp int64) $ shapeDims shape

-- | Like 'sAllocArray', but permute the in-memory representation of
-- the indices as specified.
--
-- A fresh memory block named @name ++ "_mem"@ is allocated, sized by
-- 'typeSize' of the array type.  The index function is an
-- 'IxFun.iota' over the dimensions rearranged by @perm@, permuted by
-- the inverse of @perm@, so the array is still declared with the
-- original @shape@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm name pt shape space perm = do
  let permuted_dims = rearrangeShape perm $ shapeDims shape
  mem <- sAlloc (name ++ "_mem") (typeSize (Array pt shape NoUniqueness)) space
  let iota_ixfun = IxFun.iota $ map (isInt64 . primExpFromSubExp int64) permuted_dims
  sArray name pt shape $
    ArrayIn mem $ IxFun.permute iota_ixfun $ rearrangeInverse perm

-- | Allocate memory for, and declare, an array of the given element
-- type and shape.  Uses linear/iota index function, obtained by
-- passing the identity permutation to 'sAllocArrayPerm'.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space [0 .. shapeRank shape - 1]

-- | Declare a one-dimensional array whose contents are given
-- statically by @vs@.  Uses linear/iota index function.
--
-- A memory block named after @name ++ "_mem"@ (made unique with
-- 'newVNameForFun') is declared through 'Imp.DeclareArray' and
-- recorded as a 'MemVar'; the array itself is then declared on top of
-- it with 'sArray'.  The number of elements is the length of the
-- value list for 'Imp.ArrayValues' and the stored count for
-- 'Imp.ArrayZeros'.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM lore r op VName
sStaticArray name space pt vs = do
  let num_elems = case vs of
        Imp.ArrayValues vs' -> length vs'
        Imp.ArrayZeros n -> n
      shape = Shape [intConst Int64 $ toInteger num_elems]
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem $ MemVar Nothing $ MemEntry space
  sArray name pt shape $ ArrayIn mem $ IxFun.iota [fromIntegral num_elems]

-- | Write the scalar expression @v@ to the element of array @arr@ at
-- index @is@.
--
-- The memory block, space and element offset come from
-- 'fullyIndexArray'; the element type is taken from @v@ itself, and
-- the volatility of the write from the environment
-- ('envVolatility').
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM lore r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  vol <- asks envVolatility
  emit $ Imp.Write mem offset (primExpType v) space vol v

-- | Copy the value @v@ into the part of array @arr@ selected by
-- @slice@.  Implemented as a 'copyDWIM' with an empty source slice.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM lore r op ()
sUpdate arr slice v = copyDWIM arr slice v []

-- | Emit one nested 'sFor' loop per dimension of the given shape,
-- outermost dimension first, and run the supplied action in the
-- innermost loop body.  The action receives the loop indices in
-- dimension order (outermost first).  For an empty shape the action
-- is run once with an empty index list and no loop is emitted.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM lore r op ()) ->
  ImpM lore r op ()
sLoopNest = sLoopNest' [] . shapeDims
  where
    -- Indices are accumulated innermost-first, hence the final
    -- 'reverse'.
    sLoopNest' is [] f = f $ reverse is
    sLoopNest' is (d : ds) f =
      sFor "nest_i" (toInt64Exp d) $ \i -> sLoopNest' (i : is) ds f

-- | Untyped assignment: emit an 'Imp.SetScalar' that sets the
-- variable @x@ to the expression @e@.
(<~~) :: VName -> Imp.Exp -> ImpM lore r op ()
x <~~ e = emit $ Imp.SetScalar x e

infixl 3 <~~

-- | Typed assignment: emit an 'Imp.SetScalar' that sets the variable
-- underlying the 'TV' to the (type-erased) expression @e@.
(<--) :: TV t -> Imp.TExp t -> ImpM lore r op ()
TV x _ <-- e = emit $ Imp.SetScalar x $ untyped e

infixl 3 <--

-- | Constructing an ad-hoc function that does not
-- correspond to any of the IR functions in the input program.
--
-- The body @m@ is compiled with 'envFunction' set to @fname@ and with
-- all of @outputs@ and @inputs@ added to the symbol table (memory
-- parameters as 'MemVar', scalar parameters as 'ScalarVar').  The
-- resulting code is registered under @fname@ via 'emitFunction'.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM lore r op () ->
  ImpM lore r op ()
function fname outputs inputs m = local newFunction $ do
  body <- collect $ do
    mapM_ addParam $ outputs ++ inputs
    m
  -- NOTE(review): the leading 'False' is presumably the entry-point
  -- flag of 'Imp.Function' -- confirm against its definition.  The
  -- two empty lists are the external-value descriptions.
  emitFunction fname $ Imp.Function False outputs inputs body [] []
  where
    addParam (Imp.MemParam name space) =
      addVar name $ MemVar Nothing $ MemEntry space
    addParam (Imp.ScalarParam name bt) =
      addVar name $ ScalarVar Nothing $ ScalarEntry bt
    newFunction env = env {envFunction = Just fname}