{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLoc (..),
    sliceMemLoc,
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    comment,
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupMemory,
    lookupAcc,

    -- * Building Blocks
    TV,
    mkTV,
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    copyElementWise,
    typeSize,
    inBounds,
    isMapTransposeCopy,

    -- * Constructing code.
    dLParams,
    dFParams,
    addLoopVar,
    dScope,
    dArray,
    dPrim,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    dIndexSpace,
    dIndexSpace',
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    (<--),
    (<~~),
    function,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.List (find)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Data.String
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    bytes,
    elements,
    withElemType,
  )
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings
import Prelude hiding (quot)

-- | How to compile an t'Op', given the pattern it is bound to.
type OpCompiler rep r op =
  Pat rep ->
  Op rep ->
  ImpM rep r op ()

-- | How to compile some 'Stms'.  Besides the statements, the compiler
-- receives a set of 'Names' and an additional 'ImpM' action.
type StmsCompiler rep r op =
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()

-- | How to compile an 'Exp', given its pattern and the expression.
type ExpCompiler rep r op =
  Pat rep ->
  Exp rep ->
  ImpM rep r op ()

-- | How to compile a copy of elements of the given primitive type
-- between two memory locations.
-- NOTE(review): presumably the first 'MemLoc' is the destination and
-- the second the source; confirm against 'defaultCopy'.
type CopyCompiler rep r op =
  PrimType ->
  MemLoc ->
  MemLoc ->
  ImpM rep r op ()

-- | An alternate way of compiling an allocation.  It receives a
-- 'VName' (presumably the memory block being allocated) and a size
-- in bytes.
type AllocCompiler rep r op =
  VName ->
  Count Bytes (Imp.TExp Int64) ->
  ImpM rep r op ()

-- | The pluggable compilers that determine how 'ImpM' compiles
-- expressions, t'Op's, statements, copies, and allocations.
data Operations rep r op = Operations
  { -- | Compiles expressions.
    opsExpCompiler :: ExpCompiler rep r op,
    -- | Compiles t'Op's.
    opsOpCompiler :: OpCompiler rep r op,
    -- | Compiles sequences of statements.
    opsStmsCompiler :: StmsCompiler rep r op,
    -- | Compiles copies between memory locations.
    opsCopyCompiler :: CopyCompiler rep r op,
    -- | Allocation compilers keyed by memory space.
    opsAllocCompilers :: M.Map Space (AllocCompiler rep r op)
  }

-- | An operations set whose expression compiler is always
-- 'defCompileExp'.  Statements are compiled with 'defCompileStms' and
-- copies with 'defaultCopy'.  No space-specific allocation compilers
-- are installed.  t'Op's are handled by the given 'OpCompiler'.
defaultOperations ::
  (Mem rep inner, FreeIn op) =>
  OpCompiler rep r op ->
  Operations rep r op
defaultOperations op_compiler =
  Operations
    { opsOpCompiler = op_compiler,
      opsExpCompiler = defCompileExp,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = defaultCopy,
      opsAllocCompilers = mempty
    }

-- | When an array is declared, this is where it is stored: a memory
-- block, a shape, and an index function.
data MemLoc = MemLoc
  { -- | Name of the memory block.
    memLocName :: VName,
    -- | Shape associated with the location.
    memLocShape :: [Imp.DimSize],
    -- | Index function used to address the memory block.
    memLocIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Show, Eq)

-- | Apply a 'Slice' to the index function of a memory location.  The
-- memory block name and the recorded shape are left unchanged.
sliceMemLoc :: MemLoc -> Slice (Imp.TExp Int64) -> MemLoc
sliceMemLoc loc slice =
  loc {memLocIxFun = IxFun.slice (memLocIxFun loc) slice}

-- | Apply a 'FlatSlice' to the index function of a memory location.
-- The memory block name and the recorded shape are left unchanged.
flatSliceMemLoc :: MemLoc -> FlatSlice (Imp.TExp Int64) -> MemLoc
flatSliceMemLoc loc slice =
  loc {memLocIxFun = IxFun.flatSlice (memLocIxFun loc) slice}

-- | Symbol-table information about an array: where it is stored and
-- the primitive type of its elements.
data ArrayEntry = ArrayEntry
  { -- | Where the array is stored.
    entryArrayLoc :: MemLoc,
    -- | Element type.
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape recorded in the memory location of an array entry.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocShape (entryArrayLoc entry)

-- | Symbol-table information about a memory block: the space it
-- resides in.
newtype MemEntry = MemEntry
  { entryMemSpace :: Imp.Space
  }
  deriving (Show)

-- | Symbol-table information about a scalar: its primitive type.
newtype ScalarEntry = ScalarEntry {entryScalarType :: PrimType}
  deriving (Show)

-- | An entry in the symbol table, describing how a variable
-- (array, scalar, memory block, or accumulator) is represented.  Each
-- constructor also carries an optional 'Exp'.
-- NOTE(review): presumably the expression that bound the variable,
-- when known; confirm at the construction sites.
data VarEntry rep
  = ArrayVar (Maybe (Exp rep)) ArrayEntry
  | ScalarVar (Maybe (Exp rep)) ScalarEntry
  | MemVar (Maybe (Exp rep)) MemEntry
  | -- | The tuple holds a name, a 'Shape', and a list of types, from
    -- which an 'Acc' type can be built.
    AccVar (Maybe (Exp rep)) (VName, Shape, [Type])
  deriving (Show)

-- | A destination for a value: a scalar variable, a memory block, or
-- an array.
data ValueDestination
  = ScalarDestination VName
  | MemoryDestination VName
  | -- | The 'MemLoc' is 'Just' if a copy is required.  If it is
    -- 'Nothing', then a copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLoc)
  deriving (Show)

-- | The read-only environment of 'ImpM': the pluggable compilers
-- currently in effect, plus contextual information.
data Env rep r op = Env
  { envExpCompiler :: ExpCompiler rep r op,
    envStmsCompiler :: StmsCompiler rep r op,
    envOpCompiler :: OpCompiler rep r op,
    envCopyCompiler :: CopyCompiler rep r op,
    -- | Allocation compilers keyed by memory space.
    envAllocCompilers :: M.Map Space (AllocCompiler rep r op),
    -- | The default memory space.
    envDefaultSpace :: Imp.Space,
    -- | Current volatility setting.
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs
  }

-- | Build the initial environment from the user-supplied environment,
-- the given operations, and the default memory space.
--
-- All pluggable compilers are taken from the operations, including the
-- space-specific allocation compilers.  Volatility starts out as
-- 'Imp.Nonvolatile', there is no enclosing function, and no attributes
-- are active.
newEnv :: r -> Operations rep r op -> Imp.Space -> Env rep r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- Honour the caller's allocation compilers instead of silently
      -- dropping them; 'subImpM' installs 'opsAllocCompilers' the
      -- same way.
      envAllocCompilers = opsAllocCompilers ops,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }

-- | The symbol table used during compilation, mapping each variable
-- name to its 'VarEntry'.
type VTable rep =
  M.Map VName (VarEntry rep)

-- | The state threaded through 'ImpM'.
data ImpState rep r op = ImpState
  { -- | The symbol table.
    stateVTable :: VTable rep,
    -- | Functions generated so far.
    stateFunctions :: Imp.Functions op,
    -- | Imperative code emitted so far.
    stateCode :: Imp.Code op,
    -- | Warnings accumulated so far.
    stateWarnings :: Warnings,
    -- | Maps the arrays backing each accumulator to their
    -- update function and neutral elements.  This works
    -- because an array name can only become part of a single
    -- accumulator throughout its lifetime.  If the arrays
    -- backing an accumulator are not in this mapping, the
    -- accumulator is scatter-like.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda rep, [SubExp])),
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }

-- | A fresh state that uses the given name source.  The symbol table,
-- functions, code, warnings, and accumulator map all start out empty.
newState :: VNameSource -> ImpState rep r op
newState src =
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateWarnings = mempty,
      stateAccs = mempty,
      stateNameSource = src
    }

-- | The code generation monad.  It reads an 'Env' (the compilers and
-- context currently in effect) and threads an 'ImpState' (symbol table,
-- emitted code, functions, warnings, accumulators, and name source).
newtype ImpM rep r op a
  = ImpM (ReaderT (Env rep r op) (State (ImpState rep r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (Env rep r op),
      MonadState (ImpState rep r op)
    )

-- The name source lives in the 'stateNameSource' field of the state.
instance MonadFreshNames (ImpM rep r op) where
  getNameSource = stateNameSource <$> get
  putNameSource new_src =
    modify (\st -> st {stateNameSource = new_src})

-- This cannot be a KernelsMem scope because the index functions have
-- the wrong leaves (VName instead of Imp.Exp), so the symbol table is
-- exposed as a SOACS scope of plain types instead.
instance HasScope SOACS (ImpM rep r op) where
  askScope = gets (M.map (LetName . entryType) . stateVTable)
    where
      entryType :: VarEntry rep' -> Type
      entryType = \case
        MemVar _ mem_entry ->
          Mem (entryMemSpace mem_entry)
        ArrayVar _ arr_entry ->
          Array
            (entryArrayElemType arr_entry)
            (Shape (entryArrayShape arr_entry))
            NoUniqueness
        ScalarVar _ scalar_entry ->
          Prim (entryScalarType scalar_entry)
        AccVar _ (acc, ispace, ts) ->
          Acc acc ispace ts NoUniqueness

-- | Run an 'ImpM' action starting from the given state.  The
-- environment is built by 'newEnv' from the user environment, the
-- operations, and the default memory space.  Returns the action's
-- result together with the final state.
runImpM ::
  ImpM rep r op a ->
  r ->
  Operations rep r op ->
  Imp.Space ->
  ImpState rep r op ->
  (a, ImpState rep r op)
runImpM (ImpM action) r ops space =
  runState . runReaderT action $ newEnv r ops space

-- | Like 'subImpM', but discards the action's result and returns only
-- the code it generated.
subImpM_ ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (Imp.Code op')
subImpM_ r ops m = fmap snd (subImpM r ops m)

-- | Run an action with a different user environment and set of
-- operations (and hence possibly a different op type).  Returns the
-- action's result together with the code it generated.
--
-- The inner action starts with the caller's symbol table, name source,
-- and accumulator map, but with no code, functions, or warnings.
-- Afterwards only the name source and the warnings are propagated back
-- to the caller.  Functions defined by the inner action, and any
-- changes it makes to the symbol table or accumulators, are discarded.
subImpM ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  outer_env <- ask
  outer_state <- get

  let inner_env =
        outer_env
          { envEnv = r,
            envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envAllocCompilers = opsAllocCompilers ops
          }
      inner_state =
        ImpState
          { stateVTable = stateVTable outer_state,
            stateFunctions = mempty,
            stateCode = mempty,
            stateWarnings = mempty,
            stateAccs = stateAccs outer_state,
            stateNameSource = stateNameSource outer_state
          }
      (result, final_state) = runState (runReaderT m inner_env) inner_state

  putNameSource (stateNameSource final_state)
  warnings (stateWarnings final_state)
  pure (result, stateCode final_state)

-- | Execute a code generation action and return the code it emitted,
-- instead of appending that code to the surrounding code.
collect :: ImpM rep r op () -> ImpM rep r op (Imp.Code op)
collect = (snd <$>) . collect'

-- | Execute a code generation action, returning both its result and
-- the code it emitted.  The code buffer is emptied before the action
-- runs and restored afterwards, so the emitted code does not end up in
-- the surrounding code.
collect' :: ImpM rep r op a -> ImpM rep r op (a, Imp.Code op)
collect' m = do
  saved <- state $ \st -> (stateCode st, st {stateCode = mempty})
  res <- m
  produced <- state $ \st -> (stateCode st, st {stateCode = saved})
  pure (res, produced)

-- | Execute a code generation action and emit the code it generated
-- wrapped in an 'Imp.Comment' carrying the given description.
comment :: String -> ImpM rep r op () -> ImpM rep r op ()
comment desc m = emit . Imp.Comment desc =<< collect m

-- | Emit some generated imperative code, appending it after the code
-- emitted so far.
emit :: Imp.Code op -> ImpM rep r op ()
emit code = modify appendCode
  where
    appendCode s = s {stateCode = stateCode s <> code}

-- | Record warnings in the code generator state.  The new warnings are
-- combined in front of the warnings recorded so far.
warnings :: Warnings -> ImpM rep r op ()
warnings new = modify addWarnings
  where
    addWarnings s = s {stateWarnings = new <> stateWarnings s}

-- | Emit a warning about something the user should be aware of.
--
-- The first argument supplies the location passed to 'singleWarning'';
-- the list supplies the additional locations, and the string is the
-- message.
warn :: Located loc => loc -> [loc] -> String -> ImpM rep r op ()
warn loc locs problem =
  warnings . singleWarning' (srclocOf loc) (map srclocOf locs) $ fromString problem

-- | Emit a function in the generated code.  The function is added in
-- front of the functions emitted so far.
emitFunction :: Name -> Imp.Function op -> ImpM rep r op ()
emitFunction fname fun =
  modify $ \s ->
    let Imp.Functions existing = stateFunctions s
     in s {stateFunctions = Imp.Functions $ (fname, fun) : existing}

-- | Check if a function of the given name has already been emitted
-- into the code generator state.
hasFunction :: Name -> ImpM rep r op Bool
hasFunction fname = do
  Imp.Functions fs <- gets stateFunctions
  pure $ any ((fname ==) . fst) fs

-- | Build a symbol table for the names bound by the given statements.
--
-- Every pattern element gets an entry made by 'memBoundToVarEntry' from
-- the binding expression and the memory information in its decoration.
-- The per-name maps are combined with 'Map's left-biased union, so if a
-- name were bound twice the first binding would win.
constsVTable :: Mem rep inner => Stms rep -> VTable rep
constsVTable = foldMap fromStm
  where
    fromStm (Let pat _ e) = foldMap (entryFor e) (patElems pat)
    entryFor e (PatElem name dec) =
      M.singleton name . memBoundToVarEntry (Just e) $ letDecMem dec

-- | Compile a whole program to imperative code.
--
-- Each function is compiled separately (sparked with 'parMap') from a
-- fresh state whose symbol table holds the top-level constants.  The
-- per-function states are then merged, and the constants are compiled
-- in the merged state.  The names free in the compiled functions are
-- passed to 'compileConsts'.  The fresh-name source is threaded through
-- 'modifyNameSource'.
compileProg ::
  (Mem rep inner, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations rep r op ->
  Imp.Space ->
  Prog rep ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog consts funs) =
  modifyNameSource $ \src ->
    let fun_states = map snd $ parMap rpar (compileFunDef' src) funs
        free_in_funs = freeIn $ foldMap stateFunctions fun_states
        (consts', final_state) =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates fun_states
     in ( ( stateWarnings final_state,
            Imp.Definitions consts' (stateFunctions final_state)
          ),
          stateNameSource final_state
        )
  where
    -- Compile a single function, starting from a fresh state whose
    -- symbol table contains the top-level constants.
    compileFunDef' src fdef =
      runImpM (compileFunDef fdef) r ops space $
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the states of the separately compiled functions.  Functions
    -- go through a map keyed by name, so a later duplicate replaces an
    -- earlier one and the result is ordered by name.  Warnings and name
    -- sources are combined with their monoid instances.
    combineStates states =
      let Imp.Functions all_funs = foldMap stateFunctions states
          merged_src = foldMap stateNameSource states
       in (newState merged_src)
            { stateFunctions = Imp.Functions . M.toList $ M.fromList all_funs,
              stateWarnings = foldMap stateWarnings states
            }

-- | Compile the top-level constants of a program.
--
-- The statements are compiled into initialisation code.  Memory and
-- scalar declarations that appear at the top level of that code (that
-- is, reachable only through 'Imp.:>>:') and whose names are in
-- @used_consts@ are removed from the code.  They are returned as the
-- parameters of the resulting 'Imp.Constants' instead.
compileConsts :: Names -> Stms rep -> ImpM rep r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect . compileStms used_consts stms $ pure ()
  pure . uncurry Imp.Constants . first DL.toList $ extract code
  where
    -- Fish out those top-level declarations in the constant
    -- initialisation code that are free in the functions.
    extract = \case
      x Imp.:>>: y ->
        extract x <> extract y
      Imp.DeclareMem name space
        | name `nameIn` used_consts ->
          (DL.singleton $ Imp.MemParam name space, mempty)
      Imp.DeclareScalar name _ t
        | name `nameIn` used_consts ->
          (DL.singleton $ Imp.ScalarParam name t, mempty)
      other ->
        (mempty, other)

-- | Translate a function parameter.
--
-- Scalar and memory parameters become an 'Imp.Param' directly ('Left').
-- Array parameters become an 'ArrayDecl' ('Right') that records the
-- memory block, shape and index function.  Accumulator parameters are
-- rejected with 'error'.
compileInParam ::
  Mem rep inner =>
  FParam rep ->
  ImpM rep r op (Either Imp.Param ArrayDecl)
compileInParam fparam =
  case paramDec fparam of
    MemPrim bt ->
      pure . Left $ Imp.ScalarParam name bt
    MemMem space ->
      pure . Left $ Imp.MemParam name space
    MemArray bt shape _ (ArrayIn mem ixfun) ->
      pure . Right . ArrayDecl name bt $ MemLoc mem (shapeDims shape) ixfun
    MemAcc {} ->
      error "Functions may not have accumulator parameters."
  where
    name = paramName fparam

-- | An array parameter: its name, element type and memory location.
-- 'compileInParam' builds the location from the parameter's memory
-- block, shape dimensions and index function.
data ArrayDecl = ArrayDecl VName PrimType MemLoc

-- | Compile the parameters of a function.
--
-- Returns three things:
--
-- * the scalar and memory parameters;
-- * the array parameters;
-- * an external-value description for each entry point parameter.
--
-- The final @sum (map (entryPointSize . entryParamType) eparams)@
-- parameters are the value parameters that the entry point parameters
-- describe.  The parameters before them are context parameters.
compileInParams ::
  Mem rep inner =>
  [FParam rep] ->
  [EntryParam] ->
  ImpM rep r op ([Imp.Param], [ArrayDecl], [(Name, Imp.ExternalValue)])
compileInParams params eparams = do
  let num_ctx = length params - sum (map (entryPointSize . entryParamType) eparams)
      (_, val_params) = splitAt num_ctx params
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds

      mem_spaces = M.fromList $ mapMaybe memSummary params

      findMemInfo :: VName -> Maybe Space
      findMemInfo mem = M.lookup mem mem_spaces

      -- Describe one value parameter.  Arrays need the space of their
      -- memory block to be known; scalars are described directly.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLoc mem shape _)), _) -> do
            memspace <- findMemInfo mem
            pure $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt) ->
            Just . Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      transparentExt v u fparam signedness =
        maybeToList $ (v,) . Imp.TransparentValue u <$> mkValueDesc fparam signedness

      -- Walk the entry point parameters, consuming the value parameters
      -- that each one covers.
      mkExts (EntryParam v (TypeOpaque u desc n) : epts) fparams =
        let (opaque_params, rest) = splitAt n fparams
         in (v, Imp.OpaqueValue u desc $ mapMaybe (`mkValueDesc` Imp.TypeDirect) opaque_params)
              : mkExts epts rest
      mkExts (EntryParam v (TypeUnsigned u) : epts) (fparam : fparams) =
        transparentExt v u fparam Imp.TypeUnsigned ++ mkExts epts fparams
      mkExts (EntryParam v (TypeDirect u) : epts) (fparam : fparams) =
        transparentExt v u fparam Imp.TypeDirect ++ mkExts epts fparams
      mkExts _ _ = []

  pure (inparams, arrayds, mkExts eparams val_params)
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

    memSummary param = case paramDec param of
      MemMem space -> Just (paramName param, space)
      _ -> Nothing

-- | Create the output parameter (if any) and the destination for one
-- function return value.
--
-- * Scalars and memory blocks get a freshly named output parameter.
-- * Arrays get no parameter and an 'ArrayDestination' with no memory
--   location.
-- * Accumulators are rejected with 'error'.
compileOutParam ::
  FunReturns -> ImpM rep r op (Maybe Imp.Param, ValueDestination)
compileOutParam = \case
  MemPrim t -> freshOut "prim_out" (`Imp.ScalarParam` t) ScalarDestination
  MemMem space -> freshOut "mem_out" (`Imp.MemParam` space) MemoryDestination
  MemArray {} -> pure (Nothing, ArrayDestination Nothing)
  MemAcc {} -> error "Functions may not return accumulators."
  where
    freshOut base mkParam mkDest = do
      name <- newVName base
      pure (Just $ mkParam name, mkDest name)

-- | Describe the results of an entry point as external values.
--
-- The trailing @sum (map entryPointSize orig_epts)@ return types are the
-- values visible to the caller; any return types before them are context.
-- The @i@'th return type is referred to through the name of the @i@'th
-- element of @maybe_params@, which must be a 'Just' parameter (otherwise
-- 'error' is called).
--
-- Each 'TypeOpaque' entry point type consumes as many returns as its size
-- and describes all of them with 'Imp.TypeDirect' signedness. Each
-- 'TypeUnsigned'/'TypeDirect' type consumes exactly one return.
--
-- NOTE(review): description stops silently as soon as either the entry
-- point types or the value returns run out; a length mismatch is not
-- reported here.
compileExternalValues ::
  Mem rep inner =>
  [RetType rep] ->
  [EntryPointType] ->
  [Maybe Imp.Param] ->
  ImpM rep r op [Imp.ExternalValue]
compileExternalValues orig_rts orig_epts maybe_params =
  externals (length ctx_rts) orig_epts val_rts
  where
    (ctx_rts, val_rts) =
      splitAt (length orig_rts - sum (map entryPointSize orig_epts)) orig_rts

    -- Name of the output parameter that holds return number i.
    outName i =
      case maybeNth i maybe_params of
        Nothing -> error $ "Param " ++ show i ++ " does not exist."
        Just Nothing -> error $ "Output " ++ show i ++ " not a param."
        Just (Just p) -> Imp.paramName p

    -- Memory block (and its space) backing an array return.  A fresh
    -- block is itself one of the outputs; an existing block's space is
    -- looked up in the environment.
    arrayMem (ReturnsNewBlock space j _ixfun) = pure (outName j, space)
    arrayMem (ReturnsInBlock mem _ixfun) = do
      space <- entryMemSpace <$> lookupMemory mem
      pure (mem, space)

    -- Existential dimensions refer to other outputs by index.
    dimSize (Free v) = v
    dimSize (Ext i) = Var $ outName i

    valueDesc _ signedness (MemArray t shape _ ret) = do
      (mem, space) <- arrayMem ret
      pure . Imp.ArrayValue mem space t signedness . map dimSize $ shapeDims shape
    valueDesc i signedness (MemPrim bt) =
      pure . Imp.ScalarValue bt signedness $ outName i
    valueDesc _ _ MemAcc {} =
      error "mkValueDesc: unexpected MemAcc output."
    valueDesc _ _ MemMem {} =
      error "mkValueDesc: unexpected MemMem output."

    -- A single non-opaque return, followed by the rest.
    transparent i signedness u ret epts rets = do
      vd <- valueDesc i signedness ret
      (Imp.TransparentValue u vd :) <$> externals (i + 1) epts rets

    -- Walk the entry point types; i is the index of the next output.
    externals i (TypeOpaque u desc n : epts) rets = do
      let (opaque_rets, rest) = splitAt n rets
      vds <-
        zipWithM (\j ret -> valueDesc j Imp.TypeDirect ret) [i ..] opaque_rets
      (Imp.OpaqueValue u desc vds :) <$> externals (i + n) epts rest
    externals i (TypeUnsigned u : epts) (ret : rets) =
      transparent i Imp.TypeUnsigned u ret epts rets
    externals i (TypeDirect u : epts) (ret : rets) =
      transparent i Imp.TypeDirect u ret epts rets
    externals _ _ _ = pure []

-- | Create one output parameter candidate and one value destination per
-- return type.
--
-- When entry point types are supplied, the results are also described as
-- external values via 'compileExternalValues'. Otherwise the list of
-- external values is empty. Only the return types that actually yield a
-- parameter (the 'Just' results of 'compileOutParam') appear in the
-- returned parameter list.
compileOutParams ::
  Mem rep inner =>
  [RetType rep] ->
  Maybe [EntryPointType] ->
  ImpM rep r op ([Imp.ExternalValue], [Imp.Param], [ValueDestination])
compileOutParams orig_rts maybe_orig_epts = do
  outs <- mapM compileOutParam orig_rts
  let (maybe_params, dests) = unzip outs
  evs <-
    maybe
      (pure [])
      (\orig_epts -> compileExternalValues orig_rts orig_epts maybe_params)
      maybe_orig_epts
  pure (evs, catMaybes maybe_params, dests)

-- | Compile a function definition and emit it (via 'emitFunction') under
-- the function's name.
--
-- While compiling, 'envFunction' is the entry point name if the function
-- is an entry point, and the function name otherwise. For a function that
-- is not an entry point, every parameter is passed to 'compileInParams' as
-- an unnamed 'TypeDirect' entry parameter, and no entry point result types
-- are passed to 'compileOutParams'. As a result, no external result
-- values are produced.
compileFunDef ::
  Mem rep inner =>
  FunDef rep ->
  ImpM rep r op ()
compileFunDef (FunDef entry _ fname rettype params body) =
  local withFunction $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    emitFunction fname $
      Imp.Function name_entry outparams inparams body' results args
  where
    (name_entry, params_entry, ret_entry) =
      case entry of
        Just (x, y, z) -> (Just x, y, Just z)
        Nothing ->
          ( Nothing,
            replicate (length params) (EntryParam "" $ TypeDirect mempty),
            Nothing
          )

    withFunction env = env {envFunction = name_entry `mplus` Just fname}

    compile = do
      (inparams, arrayds, args) <- compileInParams params params_entry
      (results, outparams, dests) <- compileOutParams rettype ret_entry
      addFParams params
      addArrays arrayds

      let Body _ stms ses = body
          -- Copy each body result into its output destination.
          copyResult (d, SubExpRes _ se) = copyDWIMDest d [] se []
      compileStms (freeIn ses) stms $ mapM_ copyResult $ zip dests ses

      pure (outparams, inparams, results, args)

-- | Compile the statements of a body, then copy each body result into the
-- corresponding destination of the given pattern.
--
-- Results and destinations are paired positionally; any surplus on either
-- side is ignored.
compileBody :: Pat rep -> Body rep -> ImpM rep r op ()
compileBody pat (Body _ stms ses) = do
  dests <- destinationFromPat pat
  compileStms (freeIn ses) stms . mapM_ copyResult $ zip dests ses
  where
    copyResult (dest, SubExpRes _ se) = copyDWIMDest dest [] se []

-- | Compile the statements of a body, then copy each body result into the
-- parameter at the same position (paired positionally; surplus ignored).
compileBody' :: [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' params (Body _ stms ses) =
  compileStms (freeIn ses) stms . sequence_ $ zipWith copyInto params ses
  where
    copyInto param (SubExpRes _ se) = copyDWIM (paramName param) [] se []

-- | Compile the body of a loop and write its results back into the loop's
-- merge parameters.
--
-- The results cannot be written to the merge parameters immediately, as
-- some results may themselves *be* merge parameters and would be
-- clobbered. So every result is first copied into a fresh temporary that
-- mirrors its merge parameter. Only after all temporaries have been set
-- are they copied into the merge parameters. This is cheap because only
-- scalar and memory-block variables are copied. Merge parameters of any
-- other type (and memory results that are not variables) are not written
-- here.
compileLoopBody :: Typed dec => [Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody mergeparams (Body _ stms ses) = do
  tmpnames <- mapM tmpName mergeparams
  compileStms (freeIn ses) stms $ do
    -- Phase one: fill every temporary.  Phase two: copy them all back.
    writebacks <- sequence $ zipWith3 stage mergeparams tmpnames ses
    sequence_ writebacks
  where
    tmpName p = newVName $ baseString (paramName p) <> "_tmp"

    -- Declare and set the temporary now; return the action that later
    -- copies the temporary into the merge parameter.
    stage p tmp (SubExpRes _ se) =
      case (typeOf p, se) of
        (Prim pt, _) -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure . emit . Imp.SetScalar (paramName p) $ Imp.var tmp pt
        (Mem space, Var v) -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          pure . emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure $ pure ()

-- | Compile statements using the statement compiler stored in the
-- environment ('envStmsCompiler').
--
-- The 'Names' argument holds the names still needed after the statements;
-- the callers in this chunk pass the names free in a body's result. The
-- final action is handed to the statement compiler unchanged. Judging from
-- callers such as 'compileBody', it is the code that consumes the
-- statements' results.
compileStms :: Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \stmsCompiler ->
    stmsCompiler alive_after_stms all_stms m

-- | The default statement compiler. It declares and compiles each
-- statement in order, then runs the continuation @m@.
--
-- Memory blocks bound by the statements' patterns are tracked. When a
-- statement's generated code is the last code to mention such a block,
-- an 'Imp.Free' for it is emitted right after that statement. "Last"
-- counts later statements, @m@, and @alive_after_stms@. This is very
-- conservative, but can cut down on lifetimes in some cases:
--
-- * a block that no later statement mentions is never freed here;
-- * a block still used by @m@ or named in @alive_after_stms@ is never
--   freed here.
defCompileStms ::
  (Mem rep inner, FreeIn op) =>
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()
defCompileStms alive_after_stms all_stms m =
  void $ go mempty $ stmsToList all_stms
  where
    -- Compile the remaining statements, given the memory blocks bound by
    -- earlier statements. Returns the names used by all code emitted from
    -- this point on, together with @alive_after_stms@.
    go _ [] = do
      tail_code <- collect m
      emit tail_code
      pure $ freeIn tail_code <> alive_after_stms
    go known_mem (Let pat aux e : rest) = do
      dVars (Just e) (patElems pat)
      stm_code <- localAttrs (stmAuxAttrs aux) . collect $ compileExp pat e
      -- Compile (but only collect) the rest first, so we know which
      -- names are still used after this statement.
      (used_later, rest_code) <-
        collect' $ go (memBoundBy pat <> known_mem) rest
      let used_here = freeIn stm_code
          lastUse (mem, _) =
            mem `nameIn` used_here && not (mem `nameIn` used_later)
      emit stm_code
      mapM_ (emit . uncurry Imp.Free) $ S.filter lastUse known_mem
      emit rest_code
      pure $ used_here <> used_later

    -- Memory blocks (with their spaces) bound directly by a pattern.
    memBoundBy = S.fromList . mapMaybe memElem . patElems
    memElem pe
      | Mem space <- patElemType pe = Just (patElemName pe, space)
      | otherwise = Nothing

-- | Compile an expression whose result is bound to the given pattern.
-- This hands the work to the expression compiler installed in the
-- environment ('envExpCompiler').
compileExp :: Pat rep -> Exp rep -> ImpM rep r op ()
compileExp pat e =
  asks envExpCompiler >>= \compiler -> compiler pat e

-- | The default expression compiler.
--
-- * 'If', 'Apply', 'DoLoop' and 'WithAcc' are compiled here.
-- * 'BasicOp' goes to 'defCompileBasicOp'.
-- * 'Op' goes to the operation compiler installed in the environment
--   ('envOpCompiler').
defCompileExp ::
  (Mem rep inner) =>
  Pat rep ->
  Exp rep ->
  ImpM rep r op ()
-- Conditional: both branches write into the same pattern.
defCompileExp pat (If cond tbranch fbranch _) =
  sIf (toBoolExp cond) (compileBody pat tbranch) (compileBody pat fbranch)
-- Function call. Only primitive values and memory blocks are passed as
-- arguments; any other argument is dropped.
defCompileExp pat (Apply fname args _ _) = do
  targets <- destinationFromPat pat >>= funcallTargets
  args' <- catMaybes <$> traverse compileArg args
  emit $ Imp.Call targets fname args'
  where
    compileArg (se, _) = toArg se <$> subExpType se
    toArg se t = case (se, t) of
      (_, Prim pt) -> Just $ Imp.ExpArg $ toExp' pt se
      (Var v, Mem {}) -> Just $ Imp.MemArg v
      _ -> Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
-- Loops. The merge parameters are declared and the rank-0 ones are
-- initialised. The loop is then run, and finally the merge parameters'
-- values are copied into the pattern.
defCompileExp pat (DoLoop merge form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.

  dFParams params
  forM_ merge $ \(p, initial) ->
    when (arrayRank (paramType p) == 0) $
      copyDWIM (paramName p) [] initial []

  let loopBody = compileLoopBody params body

  case form of
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) loopBody
    ForLoop i _ bound loopvars -> do
      bound' <- toExp bound
      dLParams $ map fst loopvars
      sFor' i bound' $ do
        mapM_ (bindLoopVar i) loopvars
        loopBody

  pat_dests <- destinationFromPat pat
  forM_ (zip pat_dests params) $ \(dest, p) ->
    copyDWIMDest dest [] (Var (paramName p)) []
  where
    params = map fst merge
    -- At the start of each iteration, a primitive-typed loop variable
    -- receives element @i@ of its array. Other loop variables are not
    -- touched here.
    bindLoopVar i (p, arr)
      | Prim _ <- paramType p =
          copyDWIM (paramName p) [] (Var arr) [DimFix $ Imp.le64 i]
      | otherwise = pure ()
-- Accumulator scope. Each accumulator parameter is recorded in
-- 'stateAccs' with its arrays and optional operator. The lambda body is
-- then compiled, and its non-accumulator results are copied into the
-- trailing pattern elements.
defCompileExp pat (WithAcc inputs lam) = do
  dLParams lam_params
  forM_ (zip inputs lam_params) $ \((_, arrs, op), p) ->
    modify $ \s ->
      s {stateAccs = M.insert (paramName p) (arrs, op) (stateAccs s)}
  compileStms mempty (bodyStms lam_body) $
    forM_ (zip nonacc_pat_names nonacc_res) $ \(v, SubExpRes _ se) ->
      copyDWIM v [] se []
  where
    lam_params = lambdaParams lam
    lam_body = lambdaBody lam
    -- Results after the first @length inputs@ are not accumulators.
    nonacc_res = drop (length inputs) (bodyResult lam_body)
    nonacc_pat_names = takeLast (length nonacc_res) (patNames pat)
defCompileExp pat (Op op) =
  asks envOpCompiler >>= \compiler -> compiler pat op

-- | Emit a trace print of a single primitive value. The output is the
-- label @s@, then @": "@, then the value of @se@ read as type @t@, then
-- a newline.
tracePrim :: String -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim s t se = emit $ Imp.TracePrint $ ErrorMsg parts
  where
    parts =
      [ ErrorString $ s <> ": ",
        ErrorVal t $ toExp' t se,
        ErrorString "\n"
      ]

-- | Emit a trace print of every element of the array @se@, which has
-- shape @shape@ and element type @t@. The output is the label @s@ and
-- @": "@, then each element (in 'sLoopNest' order) followed by a space,
-- then a newline.
traceArray :: String -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray s t shape se = do
  printParts [ErrorString (s <> ": ")]
  sLoopNest shape $ \is -> do
    elem_tv <- dPrim "arr_elem" t
    copyDWIMFix (tvVar elem_tv) [] se is
    printParts [ErrorVal t $ untyped $ tvExp elem_tv, " "]
  printParts ["\n"]
  where
    printParts = emit . Imp.TracePrint . ErrorMsg

defCompileBasicOp ::
  Mem rep inner =>
  Pat rep ->
  BasicOp ->
  ImpM rep r op ()
defCompileBasicOp :: Pat rep -> BasicOp -> ImpM rep r op ()
defCompileBasicOp (Pat [PatElemT (LetDec rep)
pe]) (SubExp DimSize
se) =
  VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe) [] DimSize
se []
defCompileBasicOp (Pat [PatElemT (LetDec rep)
pe]) (Opaque OpaqueOp
op DimSize
se) = do
  VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (PatElemT (LetDec rep) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
pe) [] DimSize
se []
  case OpaqueOp
op of
    OpaqueOp
OpaqueNil -> () -> ImpM rep r op ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
    OpaqueTrace String
s -> String -> ImpM rep r op () -> ImpM rep r op ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment (String
"Trace: " String -> ShowS
forall a. Semigroup a => a -> a -> a
<> String
s) (ImpM rep r op () -> ImpM rep r op ())
-> ImpM rep r op () -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ do
      Type
se_t <- DimSize -> ImpM rep r op Type
forall t (m :: * -> *). HasScope t m => DimSize -> m Type
subExpType DimSize
se
      case Type
se_t of
        Prim PrimType
t -> String -> PrimType -> DimSize -> ImpM rep r op ()
forall rep r op. String -> PrimType -> DimSize -> ImpM rep r op ()
tracePrim String
s PrimType
t DimSize
se
        Array PrimType
t ShapeBase DimSize
shape NoUniqueness
_ -> String
-> PrimType -> ShapeBase DimSize -> DimSize -> ImpM rep r op ()
forall rep r op.
String
-> PrimType -> ShapeBase DimSize -> DimSize -> ImpM rep r op ()
traceArray String
s PrimType
t ShapeBase DimSize
shape DimSize
se
        Type
_ ->
          [SrcLoc] -> [[SrcLoc]] -> String -> ImpM rep r op ()
forall loc rep r op.
Located loc =>
loc -> [loc] -> String -> ImpM rep r op ()
warn [SrcLoc
forall a. Monoid a => a
mempty :: SrcLoc] [[SrcLoc]]
forall a. Monoid a => a
mempty (String -> ImpM rep r op ()) -> String -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            String
s String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
": cannot trace value of this (core) type: " String -> ShowS
forall a. Semigroup a => a -> a -> a
<> Type -> String
forall a. Pretty a => a -> String
pretty Type
se_t
-- Scalar operators: translate the operand(s) to Imp expressions (left to
-- right) and assign the corresponding Imp expression to the single
-- pattern element.
defCompileBasicOp (Pat [pe]) (UnOp op e) =
  toExp e >>= \operand -> patElemName pe <~~ Imp.UnOpExp op operand
defCompileBasicOp (Pat [pe]) (ConvOp conv e) =
  toExp e >>= \operand -> patElemName pe <~~ Imp.ConvOpExp conv operand
defCompileBasicOp (Pat [pe]) (BinOp bop x y) =
  Imp.BinOpExp bop <$> toExp x <*> toExp y >>= (patElemName pe <~~)
defCompileBasicOp (Pat [pe]) (CmpOp cmp x y) =
  Imp.CmpOpExp cmp <$> toExp x <*> toExp y >>= (patElemName pe <~~)
-- Emit a run-time assertion of the condition with its error message.
-- When the attributes in scope contain warn(safety_checks), also report a
-- warning at the assertion's source location.
defCompileBasicOp _ (Assert cond msg loc) = do
  cond' <- toExp cond
  msg' <- mapM toExp msg
  emit $ Imp.Assert cond' msg' loc

  attrs <- askAttrs
  let wants_warning = AttrComp "warn" ["safety_checks"] `inAttrs` attrs
  when wants_warning $
    warn (fst loc) (snd loc) "Safety check required at run-time."
-- An index that fixes every dimension selects a single element, which is
-- copied into the pattern element.  Any other 'Index' generates no code
-- here.  NOTE(review): presumably the sliced result is expressed through
-- the pattern's memory annotation instead; confirm.
defCompileBasicOp (Pat [pe]) (Index src slice)
  | Just idxs <- sliceIndices slice =
    let fixed_idxs = [DimFix (toInt64Exp idx) | idx <- idxs]
     in copyDWIM (patElemName pe) [] (Var src) fixed_idxs
defCompileBasicOp _ Index {} =
  pure ()
-- In-place update of a slice of the pattern element with @se@.  A 'Safe'
-- update is only performed when the slice lies within the pattern
-- element's dimensions; an 'Unsafe' update is performed unconditionally.
defCompileBasicOp (Pat [pe]) (Update safety _ slice se)
  | Safe <- safety = sWhen (inBounds idx_slice arr_dims) do_write
  | otherwise = do_write
  where
    idx_slice = fmap toInt64Exp slice
    arr_dims = map toInt64Exp . arrayDims $ patElemType pe
    do_write = sUpdate (patElemName pe) idx_slice se
-- 'FlatIndex' generates no code here.
defCompileBasicOp _ FlatIndex {} =
  return ()
-- 'FlatUpdate' copies the array @v@ into the region of the pattern
-- element's memory described by the flat slice.
defCompileBasicOp (Pat [pe]) (FlatUpdate _ slice v) = do
  let array_loc = fmap entryArrayLoc . lookupArray
  dest_loc <- array_loc (patElemName pe)
  src_loc <- array_loc v
  let dest_region = flatSliceMemLoc dest_loc (fmap toInt64Exp slice)
  copy (elemType (patElemType pe)) dest_region src_loc
-- Replicate @se@ over the shape @ds@.  This generates one 'Imp.For' loop
-- per dimension, with the first dimension outermost, around a copy of @se@
-- into the element at the loop indices.  An accumulator-typed result
-- generates no code.
defCompileBasicOp (Pat [pe]) (Replicate (Shape ds) se)
  | Acc {} <- patElemType pe = pure ()
  | otherwise = do
    ds' <- mapM toExp ds
    is <- replicateM (length ds) (newVName "i")
    copy_elem <-
      collect $
        copyDWIM (patElemName pe) (map (DimFix . Imp.le64) is) se []
    -- Compose the loop wrappers right-to-left, so that the first index/
    -- dimension pair becomes the outermost loop.  'foldr' gives the same
    -- composition as the previous 'foldl' without a lazy left fold.
    emit $ foldr (.) id (zipWith Imp.For is ds') copy_elem
-- 'Scratch' generates no code here.  NOTE(review): presumably because its
-- contents need no initialisation; confirm.
defCompileBasicOp _ Scratch {} = pure ()
-- Fill the pattern element with @n@ values.  Element @i@ is @e + i * s@,
-- computed in integer type @it@ with 'OverflowUndef' arithmetic.  Each
-- value is first bound to a fresh scalar "x" and then copied into place.
defCompileBasicOp (Pat [pe]) (Iota n e s it) = do
  start_e <- toExp e
  step_e <- toExp s
  let add_op = BinOpExp (Add it OverflowUndef)
      mul_op = BinOpExp (Mul it OverflowUndef)
  sFor "i" (toInt64Exp n) $ \i -> do
    x <-
      dPrimV "x" $
        TPrimExp $
          start_e `add_op` (sExt it (untyped i) `mul_op` step_e)
    copyDWIM (patElemName pe) [DimFix i] (Var (tvVar x)) []
-- 'Copy' and 'Manifest' both copy the whole source array into the pattern
-- element.  The permutation carried by 'Manifest' is not used by this
-- equation.  NOTE(review): presumably the destination's memory layout
-- already reflects it; confirm.
defCompileBasicOp (Pat [pe]) (Copy src) =
  let dest = patElemName pe
   in copyDWIM dest [] (Var src) []
defCompileBasicOp (Pat [pe]) (Manifest _ src) =
  let dest = patElemName pe
   in copyDWIM dest [] (Var src) []
-- Concatenate @x : ys@ along dimension @i@.  Each operand is copied into
-- the destination at a running offset along dimension @i@; the dimensions
-- before @i@ are copied in full.  After each copy the offset advances by
-- that operand's extent in dimension @i@.
defCompileBasicOp (Pat [pe]) (Concat i x ys _) = do
  offs_glb <- dPrimV "tmp_offs" 0

  forM_ (x : ys) $ \arr -> do
    (outer_dims, inner_dims) <- splitAt i . arrayDims <$> lookupType arr
    let rows = case inner_dims of
          r : _ -> toInt64Exp r
          [] ->
            error $ "defCompileBasicOp Concat: empty array shape for " ++ pretty arr
        dest_slice =
          [DimSlice 0 (toInt64Exp d) 1 | d <- outer_dims]
            ++ [DimSlice (tvExp offs_glb) rows 1]
    copyDWIM (patElemName pe) dest_slice (Var arr) []
    offs_glb <-- tvExp offs_glb + rows
-- Array literal.  If every element is a constant (and there is at least
-- one), the code:
--   1. declares a static array holding the values, in the destination's
--      memory space;
--   2. registers that array as a memory variable;
--   3. copies it into the destination in one go.
-- Otherwise each element is copied to its index individually.
defCompileBasicOp (Pat [pe]) (ArrayLit es _)
  | Just vs@(v : _) <- mapM isLiteral es = do
    dest_mem <- entryArrayLoc <$> lookupArray (patElemName pe)
    dest_space <- entryMemSpace <$> lookupMemory (memLocName dest_mem)
    let t = primValueType v
        len = length es
    static_array <- newVNameForFun "static_array"
    emit . Imp.DeclareArray static_array dest_space t $ Imp.ArrayValues vs
    addVar static_array . MemVar Nothing $ MemEntry dest_space
    copy t dest_mem $
      MemLoc static_array [intConst Int64 (fromIntegral len)] $
        IxFun.iota [fromIntegral len]
  | otherwise =
    sequence_
      [ copyDWIM (patElemName pe) [DimFix (fromInteger i)] e []
        | (i, e) <- zip [0 ..] es
      ]
  where
    isLiteral = \case
      Constant lit -> Just lit
      _ -> Nothing
-- 'Rearrange', 'Rotate' and 'Reshape' generate no code here.
-- NOTE(review): presumably their effect is captured by the result's
-- memory/index-function annotation; confirm.
defCompileBasicOp _ Rearrange {} =
  pure ()
defCompileBasicOp _ Rotate {} =
  pure ()
defCompileBasicOp _ Reshape {} =
  pure ()
-- Update accumulator @acc@ at indices @is@ with values @vs@, but only when
-- the indices are in bounds of the accumulator's dimensions.
defCompileBasicOp _ (UpdateAcc acc is vs) = sComment "UpdateAcc" $ do
  -- The comment wrapper is used on purpose: it causes the generated code
  -- for this operator to be enclosed in braces.  Otherwise the lambda
  -- parameters (if any) could be declared several times, because they are
  -- duplicated for every UpdateAcc on the same accumulator.
  let is' = map toInt64Exp is

  -- Determine whether this is a scatter-like accumulator or a generalised
  -- reduction.  The lookup also binds the index parameters.
  (_, _, arrs, dims, op) <- lookupAcc acc is'

  let scatter =
        -- Scatter-like: write each value directly into its array.
        mapM_ (\(arr, v) -> copyDWIMFix arr is' v []) (zip arrs vs)

      reduce lam = do
        -- Generalised reduction.  Bind the first parameters to the current
        -- array elements and the rest to the new values, run the lambda
        -- body, and write its results back into the arrays.
        let params = lambdaParams lam
            body = lambdaBody lam
            (x_params, y_params) = splitAt (length vs) $ map paramName params
        dLParams params
        forM_ (zip x_params arrs) $ \(xp, arr) ->
          copyDWIMFix xp [] (Var arr) is'
        forM_ (zip y_params vs) $ \(yp, v) ->
          copyDWIM yp [] v []
        compileStms mempty (bodyStms body) $
          forM_ (zip arrs (bodyResult body)) $ \(arr, SubExpRes _ se) ->
            copyDWIMFix arr is' se []

  sWhen (inBounds (Slice (map DimFix is')) dims) $ maybe scatter reduce op
-- Any other pattern/operation combination is a compiler bug.
defCompileBasicOp pat e =
  error $
    concat
      [ "ImpGen.defCompileBasicOp: Invalid pattern\n  ",
        pretty pat,
        "\nfor expression\n  ",
        pretty e
      ]

-- | Register each array declaration as an array variable via 'addVar',
-- using the declared memory location and element type.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM rep r op ()
addArrays decls =
  forM_ decls $ \(ArrayDecl name bt location) ->
    addVar name . ArrayVar Nothing $
      ArrayEntry
        { entryArrayLoc = location,
          entryArrayElemType = bt
        }

-- | Like 'dFParams', but does not create new declarations: each parameter
-- is only registered via 'addVar', with its uniqueness information dropped.
--
-- Note: a hack to be used only for functions.
addFParams :: Mem rep inner => [FParam rep] -> ImpM rep r op ()
addFParams fparams =
  forM_ fparams $ \fparam ->
    addVar (paramName fparam)
      . memBoundToVarEntry Nothing
      . noUniquenessReturns
      $ paramDec fparam

-- | Register @i@ as a scalar variable of integer type @it@.  Only 'addVar'
-- is called; no declaration is emitted here.
--
-- Another hack.
addLoopVar :: VName -> IntType -> ImpM rep r op ()
addLoopVar i it =
  let entry = ScalarVar Nothing (ScalarEntry (IntType it))
   in addVar i entry

-- | Declare every pattern element in the list by turning each into a
-- scope and handing it to 'dScope', together with the optional
-- expression that is stored in the resulting symbol-table entries.
dVars ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  [PatElem rep] ->
  ImpM rep r op ()
dVars e = mapM_ (dScope e . scopeOfPatElem)

-- | Declare function parameters via 'dScope', with no associated
-- expression.
dFParams :: Mem rep inner => [FParam rep] -> ImpM rep r op ()
dFParams params = dScope Nothing $ scopeOfFParams params

-- | Declare lambda parameters via 'dScope', with no associated
-- expression.
dLParams :: Mem rep inner => [LParam rep] -> ImpM rep r op ()
dLParams params = dScope Nothing $ scopeOfLParams params

-- | Declare a fresh 'Imp.Volatile' scalar (its name derived from the
-- given string), register it in the symbol table, and assign it the
-- given expression.  Returns the new variable as a 'TV'.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimVol name t e = do
  v <- newVName name
  emit $ Imp.DeclareScalar v Imp.Volatile t
  addVar v . ScalarVar Nothing $ ScalarEntry t
  v <~~ untyped e
  pure $ TV v t

-- | Emit a non-volatile scalar declaration for an already-chosen name
-- and register that name in the symbol table as a scalar of the given
-- type.
dPrim_ :: VName -> PrimType -> ImpM rep r op ()
dPrim_ name t =
  emit (Imp.DeclareScalar name Imp.Nonvolatile t)
    >> addVar name (ScalarVar Nothing (ScalarEntry t))

-- | Declare a fresh non-volatile scalar of the given type, with a name
-- derived from the given string.
--
-- The return type is polymorphic, so there is no guarantee it
-- actually matches the 'PrimType', but at least we have to use it
-- consistently.
dPrim :: String -> PrimType -> ImpM rep r op (TV t)
dPrim name t = do
  v <- newVName name
  dPrim_ v t
  pure (TV v t)

-- | Declare an already-chosen name as a scalar whose type is that of
-- the given expression, then assign the expression to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM rep r op ()
dPrimV_ name e = do
  let pt = primExpType $ untyped e
  dPrim_ name pt
  TV name pt <-- e

-- | Declare a fresh scalar (typed by the expression's 'primExpType')
-- initialised to the given expression, and return it as a 'TV'.
dPrimV :: String -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimV name e = do
  tv <- dPrim name . primExpType $ untyped e
  tv <-- e
  pure tv

-- | Like 'dPrimV', but return the new variable as a typed expression
-- rather than as a 'TV'.
dPrimVE :: String -> Imp.TExp t -> ImpM rep r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e

-- | Translate a memory-annotated type into a symbol-table entry.  The
-- optional expression is stored in the resulting entry unchanged.
-- Arrays record their memory block, shape dimensions and index
-- function in a 'MemLoc'.
memBoundToVarEntry ::
  Maybe (Exp rep) ->
  MemBound NoUniqueness ->
  VarEntry rep
memBoundToVarEntry e mb =
  case mb of
    MemPrim bt ->
      ScalarVar e $ ScalarEntry bt
    MemMem space ->
      MemVar e $ MemEntry space
    MemAcc acc ispace ts _ ->
      AccVar e (acc, ispace, ts)
    MemArray bt shape _ (ArrayIn mem ixfun) ->
      ArrayVar e $ ArrayEntry (MemLoc mem (shapeDims shape) ixfun) bt

-- | Extract the memory annotation of a name from its 'NameInfo'.
-- Uniqueness is dropped for function parameters, and index variables
-- become primitive integers of their index type.
infoDec ::
  Mem rep inner =>
  NameInfo rep ->
  MemInfo SubExp NoUniqueness MemBind
infoDec = \case
  LetName dec -> letDecMem dec
  FParamName dec -> noUniquenessReturns dec
  LParamName dec -> dec
  IndexName it -> MemPrim (IntType it)

-- | Declare a single name.  Memory blocks and scalars get an explicit
-- declaration emitted; arrays and accumulators do not.  In every case
-- the name's entry is added to the symbol table.
dInfo ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  VName ->
  NameInfo rep ->
  ImpM rep r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e $ infoDec info
  case entry of
    ScalarVar _ ScalarEntry {entryScalarType = pt} ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile pt
    MemVar _ MemEntry {entryMemSpace = space} ->
      emit $ Imp.DeclareMem name space
    ArrayVar {} ->
      pure ()
    AccVar {} ->
      pure ()
  addVar name entry

-- | Declare every name in a scope with 'dInfo', visiting them in the
-- ascending key order produced by 'M.toList'.
dScope ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  Scope rep ->
  ImpM rep r op ()
dScope e scope =
  sequence_ [dInfo e name info | (name, info) <- M.toList scope]

-- | Register an array variable in the symbol table, located in the
-- given memory block with the given shape and index function.  No code
-- is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op ()
dArray name pt shape mem ixfun =
  addVar name . ArrayVar Nothing $
    ArrayEntry
      { entryArrayLoc = MemLoc mem (shapeDims shape) ixfun,
        entryArrayElemType = pt
      }

-- | Run an action with 'envVolatility' set to 'Imp.Volatile'.
everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local makeVolatile
  where
    makeVolatile env = env {envVolatility = Imp.Volatile}

-- | Remove the array targets: scalar and memory destinations each
-- contribute their name, while array destinations contribute nothing.
funcallTargets :: [ValueDestination] -> ImpM rep r op [VName]
funcallTargets dests = pure $ concatMap target dests
  where
    target (ScalarDestination name) = [name]
    target (ArrayDestination _) = []
    target (MemoryDestination name) = [name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
--
-- The type parameter @t@ is phantom: it does not occur in either
-- field, so it only constrains how the variable may be used.  The
-- 'PrimType' field is the dynamic type, and nothing forces it to agree
-- with @t@ (see 'mkTV').
data TV t = TV VName PrimType

-- | Build a typed variable from a name and a dynamic type.  The
-- dynamic type is not checked against the inferred static type @t@;
-- the static type must merely be used consistently afterwards.
mkTV :: VName -> PrimType -> TV t
mkTV name pt = TV name pt

-- | View a typed variable as a size, i.e. a 'Var' 'SubExp'.
tvSize :: TV t -> Imp.DimSize
tvSize (TV v _) = Var v

-- | Turn a typed variable into an expression carrying the same static
-- type, reading the variable at its dynamic 'PrimType'.
tvExp :: TV t -> Imp.TExp t
tvExp tv =
  case tv of
    TV v pt -> Imp.TPrimExp $ Imp.var v pt

-- | The name of a typed variable, discarding its type information.
tvVar :: TV t -> VName
tvVar tv =
  case tv of
    TV name _ -> name

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM rep r op Imp.Exp

  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

  -- | Compile to an expression statically typed as 'Int64'.  The
  -- default implementation calls @toExp'@ at 'int64' and wraps the
  -- result in 'TPrimExp'.
  toInt64Exp :: a -> Imp.TExp Int64
  toInt64Exp = TPrimExp . toExp' int64

  -- | Compile to an expression statically typed as 'Bool'.  The
  -- default implementation calls @toExp'@ at 'Bool' and wraps the
  -- result in 'TPrimExp'.
  toBoolExp :: a -> Imp.TExp Bool
  toBoolExp = TPrimExp . toExp' Bool

-- | Constants compile to value expressions.  A variable is looked up in
-- the symbol table and must be a scalar; its recorded type is used by
-- 'toExp', while 'toExp'' trusts the type it is given.
instance ToExp SubExp where
  toExp (Constant v) = pure $ Imp.ValueExp v
  toExp (Var v) = do
    entry <- lookupVar v
    case entry of
      ScalarVar _ (ScalarEntry pt) ->
        pure $ Imp.var v pt
      _ ->
        error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty v

  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' t (Var v) = Imp.var v t

-- | A 'PrimExp' over variable names is already an 'Imp.Exp', so it is
-- returned unchanged and any supplied type is ignored.
instance ToExp (PrimExp VName) where
  toExp = return
  toExp' = const id

-- | Add an entry to the current symbol table, replacing any existing
-- entry for the same name.
addVar :: VName -> VarEntry rep -> ImpM rep r op ()
addVar name entry = modify insertEntry
  where
    insertEntry s = s {stateVTable = M.insert name entry (stateVTable s)}

-- | Run an action with 'envDefaultSpace' set to the given space.
localDefaultSpace :: Imp.Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace space = local setSpace
  where
    setSpace env = env {envDefaultSpace = space}

-- | The function name recorded in the environment ('envFunction'), if
-- any.
askFunction :: ImpM rep r op (Maybe Name)
askFunction = asks $ \env -> envFunction env

-- | Generate a fresh 'VName' from the given string, prefixed with the
-- name from 'askFunction' and a @.@ when that name exists.
newVNameForFun :: String -> ImpM rep r op VName
newVNameForFun s = do
  mfun <- askFunction
  let prefix = maybe "" ((++ ".") . nameToString) mfun
  newVName $ prefix ++ s

-- | Build a 'Name' from the given string, prefixed with the name from
-- 'askFunction' and a @.@ when that name exists.  No fresh-name
-- generation is involved.
nameForFun :: String -> ImpM rep r op Name
nameForFun s = finish <$> askFunction
  where
    finish mfun = maybe "" (<> ".") mfun <> nameFromString s

-- | The @r@ component of the environment, as stored in 'envEnv'.
askEnv :: ImpM rep r op r
askEnv = asks $ \env -> envEnv env

-- | Run an action with the @r@ component of the environment ('envEnv')
-- transformed by the given function.
localEnv :: (r -> r) -> ImpM rep r op a -> ImpM rep r op a
localEnv f = local modifyEnv
  where
    modifyEnv env = env {envEnv = f (envEnv env)}

-- | The active attributes, including those for the statement
-- currently being compiled (the environment's 'envAttrs').
askAttrs :: ImpM rep r op Attrs
askAttrs = asks $ \env -> envAttrs env

-- | Run an action with extra attributes added to those returned by
-- 'askAttrs'.  The new attributes are combined with the existing ones
-- using '<>', with the new ones on the left.
localAttrs :: Attrs -> ImpM rep r op a -> ImpM rep r op a
localAttrs attrs = local addAttrs
  where
    addAttrs env = env {envAttrs = attrs <> envAttrs env}

-- | Run an action using the expression, statement, copy, operation
-- and allocation compilers from the given 'Operations' in place of
-- those in the current environment.
localOps :: Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps ops = local withOps
  where
    withOps env =
      env
        { envExpCompiler = opsExpCompiler ops,
          envStmsCompiler = opsStmsCompiler ops,
          envCopyCompiler = opsCopyCompiler ops,
          envOpCompiler = opsOpCompiler ops,
          envAllocCompilers = opsAllocCompilers ops
        }

-- | Read the current symbol table from the compiler state.
getVTable :: ImpM rep r op (VTable rep)
getVTable = gets $ \s -> stateVTable s

-- | Replace the current symbol table wholesale.
putVTable :: VTable rep -> ImpM rep r op ()
putVTable vtable = modify replaceVTable
  where
    replaceVTable s = s {stateVTable = vtable}

-- | Run an action with a modified symbol table.  The table in effect
-- before the call is restored after the action finishes, so any
-- changes the action makes to the symbol table are discarded.
localVTable :: (VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable f m = do
  saved <- getVTable
  putVTable (f saved)
  result <- m
  result <$ putVTable saved

-- | Look up a name in the current symbol table, calling 'error' if it
-- is not present.
lookupVar :: VName -> ImpM rep r op (VarEntry rep)
lookupVar name =
  gets (M.lookup name . stateVTable)
    >>= maybe (error $ "Unknown variable: " ++ pretty name) return

-- | Look up a name that must be bound to an array, calling 'error' if
-- it is bound to anything else.
lookupArray :: VName -> ImpM rep r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name

-- | Look up the memory-block entry bound to a name.
--
-- Fails with an internal error if the name is not in the vtable (via
-- 'lookupVar'), or if it is bound to something other than a memory block.
lookupMemory :: VName -> ImpM rep r op MemEntry
lookupMemory name =
  lookupVar name >>= \case
    MemVar _ entry -> pure entry
    -- 'lookupVar' has already succeeded, so the name is known; report that
    -- it is the wrong kind of variable rather than claiming it is unknown.
    _ -> error $ "ImpGen.lookupMemory: not a memory block: " ++ pretty name

-- | Find the memory space of the memory block that backs the named array.
lookupArraySpace :: VName -> ImpM rep r op Space
lookupArraySpace name = do
  mem <- memLocName . entryArrayLoc <$> lookupArray name
  entryMemSpace <$> lookupMemory mem

-- | Look up an accumulator by name.
--
-- Returns the underlying accumulator name, the memory space of its first
-- array, the arrays it covers, its shape dimensions as 'Imp.TExp Int64',
-- and its combining operator, if it has one.
--
-- In the case of a histogram-like accumulator (one with an operator), this
-- also binds the first @length is@ parameters of the operator to the given
-- indices. It returns the operator with those index parameters removed.
--
-- Fails with an internal error if the name is not an accumulator, if the
-- accumulator is not registered in the state, or if it has no arrays.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda rep))
lookupAcc name is =
  lookupVar name >>= \case
    AccVar _ (acc, ispace, _) ->
      gets (M.lookup acc . stateAccs) >>= \case
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ pretty name
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ pretty name
        Just (arrs@(arr : _), combine) -> do
          space <- lookupArraySpace arr
          -- Only histogram-like accumulators carry an operator; for those
          -- the index parameters are bound and stripped.
          op' <- traverse bindIndexParams combine
          pure (acc, space, arrs, map toInt64Exp (shapeDims ispace), op')
    _ -> error $ "ImpGen.lookupAcc: not an accumulator: " ++ pretty name
  where
    -- Bind the leading index parameters of the operator to the given
    -- indices, and return the operator without them.
    bindIndexParams (op, _) = do
      let (i_params, ps) = splitAt (length is) $ lambdaParams op
      zipWithM_ dPrimV_ (map paramName i_params) is
      pure op {lambdaParams = ps}

-- | Compute a 'ValueDestination' for every element of a pattern.
--
-- The destination is chosen by the kind of variable the element's name is
-- bound to in the vtable:
--
-- * array and accumulator variables get an 'ArrayDestination' with no
--   memory location;
-- * memory blocks get a 'MemoryDestination';
-- * scalars get a 'ScalarDestination'.
destinationFromPat :: Pat rep -> ImpM rep r op [ValueDestination]
destinationFromPat pat = traverse destinationFor (patElems pat)
  where
    destinationFor pe = do
      let name = patElemName pe
      lookupVar name >>= \case
        ArrayVar _ (ArrayEntry MemLoc {} _) -> pure $ ArrayDestination Nothing
        AccVar {} -> pure $ ArrayDestination Nothing
        MemVar {} -> pure $ MemoryDestination name
        ScalarVar {} -> pure $ ScalarDestination name

-- | Look up the named array and index its memory location with
-- 'fullyIndexArray''.
--
-- The result is the backing memory block, its memory space, and the
-- element offset produced by the array's index function.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices =
  lookupArray name >>= \arr -> fullyIndexArray' (entryArrayLoc arr) indices

-- | Index a memory location.
--
-- Returns the memory block name, the memory space that block lives in
-- (looked up via 'lookupMemory'), and the element offset obtained by
-- applying the location's index function to the given indices.
fullyIndexArray' ::
  MemLoc ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLoc mem _ ixfun) indices =
  withSpace . entryMemSpace <$> lookupMemory mem
  where
    withSpace space = (mem, space, elements $ IxFun.index ixfun indices)

-- More complicated read/write operations that use index functions.

-- | Copy between two memory locations by delegating to the copy compiler
-- stored in the environment.
--
-- Nothing is emitted when both locations use the same memory block and
-- their index functions are 'IxFun.equivalent'.
copy :: CopyCompiler rep r op
copy bt dest src
  | isNoOp = pure ()
  | otherwise = do
      copyCompiler <- asks envCopyCompiler
      copyCompiler bt dest src
  where
    isNoOp =
      memLocName dest == memLocName src
        && memLocIxFun dest `IxFun.equivalent` memLocIxFun src

-- | Is this copy really a mapping with transpose?
--
-- Recognised when one index function is a rearrangement (according to
-- 'IxFun.rearrangeWithOffset') and the other is linear (according to
-- 'IxFun.linearWithOffset'). The permutation must also be accepted by
-- 'isMapTranspose'. The case where the destination is the rearranged
-- side is tried first.
--
-- On success returns @(dest_offset, src_offset, num_arrays, size_x,
-- size_y)@. @num_arrays@ is the product of the first @r1@ dimensions of
-- the rearranged shape. @size_x@ and @size_y@ are the products of the
-- next @r2@ dimensions and of the remaining dimensions, swapped when the
-- destination is the rearranged side.
isMapTransposeCopy ::
  PrimType ->
  MemLoc ->
  MemLoc ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy bt (MemLoc _ _ destIxFun) (MemLoc _ _ srcIxFun)
  -- Destination rearranged, source linear.
  | Just (dest_offset, dest_perm_shape) <- IxFun.rearrangeWithOffset destIxFun bt_size,
    (perm, destshape) <- unzip dest_perm_shape,
    Just src_offset <- IxFun.linearWithOffset srcIxFun bt_size,
    Just (r1, r2, _) <- isMapTranspose perm =
      Just $ transposeParams destshape (\(a, b) -> (b, a)) r1 r2 dest_offset src_offset
  -- Destination linear, source rearranged.
  | Just dest_offset <- IxFun.linearWithOffset destIxFun bt_size,
    Just (src_offset, src_perm_shape) <- IxFun.rearrangeWithOffset srcIxFun bt_size,
    (perm, srcshape) <- unzip src_perm_shape,
    Just (r1, r2, _) <- isMapTranspose perm =
      Just $ transposeParams srcshape id r1 r2 dest_offset src_offset
  | otherwise = Nothing
  where
    bt_size = primByteSize bt

    -- Split the shape at r1, then split the remainder at r2 (with the two
    -- halves possibly reordered by 'arrange'). Multiply out each part.
    transposeParams shape arrange r1 r2 dest_offset src_offset =
      let (mapped, notmapped) = splitAt r1 shape
          (pretrans, posttrans) = arrange $ splitAt r2 notmapped
       in (dest_offset, src_offset, product mapped, product pretrans, product posttrans)

-- | Name of the map-transpose function specialised to a primitive type:
-- @map_transpose_@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = "map_transpose_" <> pretty bt

-- | Get the name of the builtin map-transpose function for a primitive
-- type.
--
-- If no function of that name has been emitted yet, its definition
-- (from 'mapTransposeFunction') is emitted first.
mapTransposeForType :: PrimType -> ImpM rep r op Name
mapTransposeForType bt = do
  alreadyEmitted <- hasFunction fname
  unless alreadyEmitted $
    emitFunction fname (mapTransposeFunction fname bt)
  pure fname
  where
    fname = nameFromString $ "builtin#" ++ mapTransposeName bt

-- | Use an 'Imp.Copy' if possible, otherwise 'copyElementWise'.
--
-- The strategies are tried in this order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy, emit a call to the
--    builtin map-transpose function for the element type.
--
-- 2. If both index functions are linear ('IxFun.linearWithOffset'),
--    emit one bulk 'Imp.Copy' covering every element of the source shape.
--    If either memory block lives in a 'ScalarSpace', fall back to
--    'copyElementWise' instead.
--
-- 3. Otherwise, copy element by element.
defaultCopy :: CopyCompiler rep r op
defaultCopy pt dest src
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
      isMapTransposeCopy pt dest src = do
      fname <- mapTransposeForType pt
      emit . Imp.Call [] fname $
        transposeArgs
          pt
          destmem
          (bytes destoffset)
          srcmem
          (bytes srcoffset)
          num_arrays
          size_x
          size_y
  | Just destoffset <- IxFun.linearWithOffset dest_ixfun pt_size,
    Just srcoffset <- IxFun.linearWithOffset src_ixfun pt_size = do
      srcspace <- spaceOf srcmem
      destspace <- spaceOf destmem
      if isScalarSpace srcspace || isScalarSpace destspace
        then copyElementWise pt dest src
        else
          emit $
            Imp.Copy
              destmem
              (bytes destoffset)
              destspace
              srcmem
              (bytes srcoffset)
              srcspace
              (num_elems `withElemType` pt)
  | otherwise = copyElementWise pt dest src
  where
    pt_size = primByteSize pt
    num_elems = Imp.elements . product . IxFun.shape $ memLocIxFun src

    spaceOf mem = entryMemSpace <$> lookupMemory mem

    MemLoc destmem _ dest_ixfun = dest
    MemLoc srcmem _ src_ixfun = src

    isScalarSpace = \case
      ScalarSpace {} -> True
      _ -> False

-- | Copy an array element by element.
--
-- Emits one 'Imp.For' loop per dimension of the source index function's
-- shape. The innermost body declares a fresh scalar, reads the source
-- element into it, and writes it to the corresponding destination
-- element. The declaration, read and write all use the volatility from
-- the environment.
copyElementWise :: CopyCompiler rep r op
copyElementWise bt dest src = do
  loop_vars <- replicateM (length bounds) (newVName "i")
  let indices = map Imp.le64 loop_vars
  (destmem, destspace, destidx) <- fullyIndexArray' dest indices
  (srcmem, srcspace, srcidx) <- fullyIndexArray' src indices
  vol <- asks envVolatility
  tmp <- newVName "tmp"
  -- Compose the loops so that the first dimension is the outermost loop.
  let loopNest = foldr (.) id $ zipWith Imp.For loop_vars (map untyped bounds)
      body =
        mconcat
          [ Imp.DeclareScalar tmp vol bt,
            Imp.Read tmp srcmem srcidx bt srcspace vol,
            Imp.Write destmem destidx bt destspace vol (Imp.var tmp bt)
          ]
  emit $ loopNest body
  where
    bounds = IxFun.shape $ memLocIxFun src

-- | Copy from here to there; both destination and source may be
-- indexed.
copyArrayDWIM ::
  PrimType ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op (Imp.Code op)
copyArrayDWIM :: PrimType
-> MemLoc
-> [DimIndex (TExp Int64)]
-> MemLoc
-> [DimIndex (TExp Int64)]
-> ImpM rep r op (Code op)
copyArrayDWIM
  PrimType
bt
  destlocation :: MemLoc
destlocation@(MemLoc VName
_ [DimSize]
destshape IxFun (TExp Int64)
_)
  [DimIndex (TExp Int64)]
destslice
  srclocation :: MemLoc
srclocation@(MemLoc VName
_ [DimSize]
srcshape IxFun (TExp Int64)
_)
  [DimIndex (TExp Int64)]
srcslice
    | Just [TExp Int64]
destis <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
destslice,
      Just [TExp Int64]
srcis <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
srcslice,
      [TExp Int64] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
srcis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
srcshape,
      [TExp Int64] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
destis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
destshape = do
      (VName
targetmem, Space
destspace, Count Elements (TExp Int64)
targetoffset) <-
        MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
destlocation [TExp Int64]
destis
      (VName
srcmem, Space
srcspace, Count Elements (TExp Int64)
srcoffset) <-
        MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
srclocation [TExp Int64]
srcis
      Volatility
vol <- (Env rep r op -> Volatility) -> ImpM rep r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env rep r op -> Volatility
forall rep r op. Env rep r op -> Volatility
envVolatility
      ImpM rep r op () -> ImpM rep r op (Code op)
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (ImpM rep r op () -> ImpM rep r op (Code op))
-> ImpM rep r op () -> ImpM rep r op (Code op)
forall a b. (a -> b) -> a -> b
$ do
        VName
tmp <- TV Any -> VName
forall t. TV t -> VName
tvVar (TV Any -> VName) -> ImpM rep r op (TV Any) -> ImpM rep r op VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> PrimType -> ImpM rep r op (TV Any)
forall rep r op t. String -> PrimType -> ImpM rep r op (TV t)
dPrim String
"tmp" PrimType
bt
        Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code op
forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Imp.Read VName
tmp VName
srcmem Count Elements (TExp Int64)
srcoffset PrimType
bt Space
srcspace Volatility
vol
        Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
targetmem Count Elements (TExp Int64)
targetoffset PrimType
bt Space
destspace Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
tmp PrimType
bt
    | Bool
otherwise = do
      let destslice' :: Slice (TExp Int64)
destslice' = [TExp Int64] -> [DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. Num d => [d] -> [DimIndex d] -> Slice d
fullSliceNum ((DimSize -> TExp Int64) -> [DimSize] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [DimSize]
destshape) [DimIndex (TExp Int64)]
destslice
          srcslice' :: Slice (TExp Int64)
srcslice' = [TExp Int64] -> [DimIndex (TExp Int64)] -> Slice (TExp Int64)
forall d. Num d => [d] -> [DimIndex d] -> Slice d
fullSliceNum ((DimSize -> TExp Int64) -> [DimSize] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [DimSize]
srcshape) [DimIndex (TExp Int64)]
srcslice
          destrank :: Int
destrank = [TExp Int64] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([TExp Int64] -> Int) -> [TExp Int64] -> Int
forall a b. (a -> b) -> a -> b
$ Slice (TExp Int64) -> [TExp Int64]
forall d. Slice d -> [d]
sliceDims Slice (TExp Int64)
destslice'
          srcrank :: Int
srcrank = [TExp Int64] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([TExp Int64] -> Int) -> [TExp Int64] -> Int
forall a b. (a -> b) -> a -> b
$ Slice (TExp Int64) -> [TExp Int64]
forall d. Slice d -> [d]
sliceDims Slice (TExp Int64)
srcslice'
          destlocation' :: MemLoc
destlocation' = MemLoc -> Slice (TExp Int64) -> MemLoc
sliceMemLoc MemLoc
destlocation Slice (TExp Int64)
destslice'
          srclocation' :: MemLoc
srclocation' = MemLoc -> Slice (TExp Int64) -> MemLoc
sliceMemLoc MemLoc
srclocation Slice (TExp Int64)
srcslice'
      if Int
destrank Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
srcrank
        then
          String -> ImpM rep r op (Code op)
forall a. HasCallStack => String -> a
error (String -> ImpM rep r op (Code op))
-> String -> ImpM rep r op (Code op)
forall a b. (a -> b) -> a -> b
$
            String
"copyArrayDWIM: cannot copy to "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> String
forall a. Pretty a => a -> String
pretty (MemLoc -> VName
memLocName MemLoc
destlocation)
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" from "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> String
forall a. Pretty a => a -> String
pretty (MemLoc -> VName
memLocName MemLoc
srclocation)
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" because ranks do not match ("
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Pretty a => a -> String
pretty Int
destrank
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" vs "
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Pretty a => a -> String
pretty Int
srcrank
              String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"
        else
          if MemLoc
destlocation' MemLoc -> MemLoc -> Bool
forall a. Eq a => a -> a -> Bool
== MemLoc
srclocation'
            then Code op -> ImpM rep r op (Code op)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Code op
forall a. Monoid a => a
mempty -- Copy would be no-op.
            else ImpM rep r op () -> ImpM rep r op (Code op)
forall rep r op. ImpM rep r op () -> ImpM rep r op (Code op)
collect (ImpM rep r op () -> ImpM rep r op (Code op))
-> ImpM rep r op () -> ImpM rep r op (Code op)
forall a b. (a -> b) -> a -> b
$ CopyCompiler rep r op
forall rep r op. CopyCompiler rep r op
copy PrimType
bt MemLoc
destlocation' MemLoc
srclocation'

-- | Like 'copyDWIM', but the target is an explicit 'ValueDestination'
-- rather than a variable name.
--
-- A constant source must not be indexed. It may only be written to a
-- scalar destination, or to an array destination whose slice fixes
-- every index.
--
-- For a variable source, the pair of destination and source entry
-- decides the action:
--
-- * set a memory block;
-- * assign a scalar;
-- * read or write a single array element;
-- * delegate an array-to-array copy to 'copyArrayDWIM';
-- * do nothing, when the destination is @'ArrayDestination' Nothing@
--   or the source is an accumulator.
--
-- All other combinations call 'error'.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIMDest _ _ (Constant v) (_ : _) =
  error $
    unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]
copyDWIMDest pat dest_slice (Constant v) [] =
  -- The slice is inspected before the destination, as in a nested case.
  case (mapM dimFix dest_slice, pat) of
    (Nothing, _) ->
      error $
        unwords ["copyDWIMDest: constant source", pretty v, "with slice destination."]
    (Just _, ScalarDestination name) ->
      emit . Imp.SetScalar name $ Imp.ValueExp v
    (Just _, MemoryDestination {}) ->
      error $
        unwords ["copyDWIMDest: constant source", pretty v, "cannot be written to memory destination."]
    (Just dest_is, ArrayDestination (Just dest_loc)) -> do
      (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
      vol <- asks envVolatility
      emit . Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
    (Just _, ArrayDestination Nothing) ->
      error "copyDWIMDest: ArrayDestination Nothing"
  where
    bt = primValueType v
copyDWIMDest dest dest_slice (Var src) src_slice =
  lookupVar src >>= \src_entry ->
    -- Alternative order matters: earlier guards pre-empt later patterns.
    case (dest, src_entry) of
      (MemoryDestination mem, MemVar _ (MemEntry space)) ->
        emit (Imp.SetMem mem src space)
      (MemoryDestination {}, _) ->
        error $
          unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]
      (_, MemVar {}) ->
        error $
          unwords ["copyDWIMDest: source", pretty src, "is a memory block."]
      (_, ScalarVar _ (ScalarEntry _))
        | not (null src_slice) ->
          error $
            unwords ["copyDWIMDest: prim-typed source", pretty src, "with slice", pretty src_slice]
      (ScalarDestination name, _)
        | not (null dest_slice) ->
          error $
            unwords ["copyDWIMDest: prim-typed target", pretty name, "with slice", pretty dest_slice]
      (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
        emit . Imp.SetScalar name $ Imp.var src pt
      (ScalarDestination name, ArrayVar _ arr)
        | Just src_is <- mapM dimFix src_slice,
          length src_slice == length (entryArrayShape arr) -> do
          -- Fully indexed array source: read one element into the scalar.
          (mem, space, i) <- fullyIndexArray' (entryArrayLoc arr) src_is
          vol <- asks envVolatility
          emit (Imp.Read name mem i (entryArrayElemType arr) space vol)
        | otherwise ->
          error $
            unwords
              [ "copyDWIMDest: prim-typed target",
                pretty name,
                "and array-typed source",
                pretty src,
                "of shape",
                pretty (entryArrayShape arr),
                "sliced with",
                pretty src_slice
              ]
      (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) ->
        emit
          =<< copyArrayDWIM
            (entryArrayElemType src_arr)
            dest_loc
            dest_slice
            (entryArrayLoc src_arr)
            src_slice
      (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt))
        | Just dest_is <- mapM dimFix dest_slice,
          length dest_is == length (memLocShape dest_loc) -> do
          -- Fully indexed array target: write the scalar into one element.
          (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit (Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt))
        | otherwise ->
          error $
            unwords
              [ "copyDWIMDest: array-typed target and prim-typed source",
                pretty src,
                "with slice",
                pretty dest_slice
              ]
      (ArrayDestination Nothing, _) ->
        -- Nothing to do; something else set some memory somewhere.
        pure ()
      (_, AccVar {}) ->
        -- Nothing to do; accumulators are phantoms.
        pure ()

-- | Copy from one variable to another, where both the destination and
-- the source may be indexed by a slice.
--
-- When slicing is involved, the variables had better be arrays of
-- sufficient rank. This function generally Does What I Mean and Does
-- The Right Thing. Both the destination and the source must be in
-- scope.
--
-- The destination's variable entry is turned into a 'ValueDestination'
-- and the work is handed to 'copyDWIMDest'.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- toDestination <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    toDestination = \case
      ScalarVar _ _ -> ScalarDestination dest
      ArrayVar _ (ArrayEntry loc@MemLoc {} _) -> ArrayDestination (Just loc)
      MemVar _ _ -> MemoryDestination dest
      -- Does not matter; accumulators are phantoms.
      AccVar {} -> ArrayDestination Nothing

-- | Like 'copyDWIM', but the indexes on both sides are plain element
-- indexes, each of which is wrapped in 'DimFix' before copying.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM rep r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (fixed dest_is) src (fixed src_is)
  where
    fixed = map DimFix

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@ and binds the result to the sole element of @pat@.
--
-- If the environment registers an allocation compiler for @space@
-- ('envAllocCompilers'), that compiler is invoked with the pattern
-- element's name and the byte count. Otherwise a plain 'Imp.Allocate'
-- is emitted.
--
-- Calls 'error' if @pat@ does not have exactly one element.
compileAlloc ::
  Mem rep inner => Pat rep -> SubExp -> Space -> ImpM rep r op ()
compileAlloc (Pat [mem]) size space = do
  let name = patElemName mem
      size' = Imp.bytes $ toInt64Exp size
  allocator <- asks $ M.lookup space . envAllocCompilers
  maybe
    (emit $ Imp.Allocate name size' space)
    (\allocate -> allocate name size')
    allocator
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat

-- | The number of bytes needed to represent a value of the given type in
-- a straightforward contiguous layout, as an t'Int64' expression.
--
-- This is the byte size of the element type multiplied by the product
-- of all array dimensions. A type with no dimensions yields just the
-- element size.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes (elemBytes * elemCount)
  where
    elemBytes = primByteSize (elemType t)
    elemCount = product $ map toInt64Exp $ arrayDims t

-- | Is this indexing in-bounds for an array of the given shape?  This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
--
-- Each slice component is paired with the corresponding dimension, and
-- the per-dimension conditions are conjoined left to right:
--
-- * a 'DimFix' index @i@ into a dimension of size @d@ requires
--   @0 <= i@ and @i < d@;
-- * a 'DimSlice' @i n s@ requires @0 <= i@ and @i + (n-1)*s < d@.
--
-- NOTE(review): the 'DimSlice' condition only bounds the first and last
-- touched elements correctly for a non-negative stride. Confirm that
-- callers never pass negative strides here.
--
-- NOTE(review): because the conditions are combined with 'foldl1', an
-- empty slice (or empty shape) is a runtime 'error' rather than
-- trivially true.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds (Slice slice) dims =
  foldl1 (.&&.) [inDim idx d | (idx, d) <- zip slice dims]
  where
    inDim (DimFix i) d =
      0 .<=. i .&&. i .<. d
    inDim (DimSlice i n s) d =
      0 .<=. i .&&. i + (n - 1) * s .<. d

--- Building blocks for constructing code.

-- | Emit an 'Imp.For' loop over the given loop variable and bound.
--
-- The body is the code generated by the given action. Before generating
-- it, the loop variable is registered with 'addLoopVar' at the integer
-- type of the bound. Calls 'error' if the bound is not of an integer
-- type.
sFor' :: VName -> Imp.Exp -> ImpM rep r op () -> ImpM rep r op ()
sFor' i bound body = do
  let it = boundIntType (primExpType bound)
  addLoopVar i it
  emit . Imp.For i bound =<< collect body
  where
    boundIntType (IntType bound_t) = bound_t
    boundIntType t =
      error $ "sFor': bound " ++ pretty bound ++ " is of type " ++ pretty t

-- | @sFor name bound body@ emits a loop bounded by @bound@.
--
-- The counter is a fresh variable whose name is derived from @name@. It
-- is passed to @body@ as a typed variable expression with the same
-- primitive type as @bound@.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor i bound body = do
  i' <- newVName i
  let bound' = untyped bound
  sFor' i' bound' . body . TPrimExp . Imp.var i' $ primExpType bound'

-- | Emit an 'Imp.While' loop guarded by @cond@, whose body is the code
-- generated by the given action.
sWhile :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile cond body = emit . Imp.While cond =<< collect body

-- | Emit the code generated by @code@ wrapped in an 'Imp.Comment' node
-- carrying the text @s@.
sComment :: String -> ImpM rep r op () -> ImpM rep r op ()
sComment s code = collect code >>= emit . Imp.Comment s

-- | Emit a two-armed conditional.
--
-- Both arms are always collected, so any effects of building them (such
-- as fresh-name generation) still happen.  When @cond@ is structurally
-- equal to 'true' or 'false', only the selected arm's code is emitted and
-- no branch is generated.
sIf :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf cond tbranch fbranch = do
  then_code <- collect tbranch
  else_code <- collect fbranch
  emit $ choose then_code else_code
  where
    -- Avoid generating a branch if the condition is known statically.
    choose then_code else_code
      | cond == true = then_code
      | cond == false = else_code
      | otherwise = Imp.If cond then_code else_code

-- | A one-armed 'sIf': run @tbranch@ only when @cond@ holds.
sWhen :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen cond tbranch = sIf cond tbranch (pure ())

-- | A one-armed 'sIf': run the given action only when @cond@ does not hold.
sUnless :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch

-- | Emit a single backend-specific operation.
sOp :: op -> ImpM rep r op ()
sOp o = emit (Imp.Op o)

-- | Declare (but do not allocate) a fresh memory block in @space@.
--
-- The name is derived from @name@.  The block is recorded in the variable
-- table as a memory variable, and the fresh name is returned.
sDeclareMem :: String -> Space -> ImpM rep r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit $ Imp.DeclareMem mem space
  addVar mem . MemVar Nothing $ MemEntry space
  pure mem

-- | Allocate @size'@ bytes for the already-declared memory block @name'@.
--
-- If the environment's 'envAllocCompilers' contains an allocator for
-- @space@, that allocator is invoked (note that it is not passed
-- @space@).  Otherwise a plain 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op ()
sAlloc_ name' size' space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Just allocator -> allocator name' size'
    Nothing -> emit $ Imp.Allocate name' size' space

-- | Declare a fresh memory block (see 'sDeclareMem'), allocate @size@
-- bytes for it (see 'sAlloc_'), and return its name.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  mem <$ sAlloc_ mem size space

-- | Declare a fresh array (name derived from @name@) via 'dArray'.
--
-- The array has element type @bt@ and shape @shape@, lives in memory
-- block @mem@, and is accessed through @ixfun@.  Returns the fresh name.
sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op VName
sArray name bt shape mem ixfun = do
  arr <- newVName name
  arr <$ dArray arr bt shape mem ixfun

-- | Declare an array in row-major order in the given memory block.
--
-- The index function is an 'IxFun.iota' over the dimensions of @shape@,
-- each converted to an int64 expression.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape mem (IxFun.iota dims)
  where
    dims = map (isInt64 . primExpFromSubExp int64) (shapeDims shape)

-- | Like 'sAllocArray', but permute the in-memory representation of the indices as specified.
--
-- Allocates a fresh memory block named @name ++ "_mem"@, sized by
-- 'typeSize' of the full array type.  It then declares the array over that
-- block with an index function built from an 'IxFun.iota' over the
-- dimensions rearranged by @perm@, permuted by the inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM rep r op VName
sAllocArrayPerm name pt shape space perm = do
  mem <- sAlloc (name ++ "_mem") (typeSize (Array pt shape NoUniqueness)) space
  let layout = IxFun.permute (IxFun.iota permuted_dims) (rearrangeInverse perm)
  sArray name pt shape mem layout
  where
    permuted_dims =
      map (isInt64 . primExpFromSubExp int64) (rearrangeShape perm (shapeDims shape))

-- | Uses linear/iota index function.
--
-- Equivalent to 'sAllocArrayPerm' with the identity permutation
-- @[0 .. rank - 1]@.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identity_perm
  where
    identity_perm = take (shapeRank shape) [0 ..]

-- | Uses linear/iota index function.
--
-- Declares a static array with contents @vs@ in a fresh memory block.  The
-- block is named from @name ++ "_mem"@ via 'newVNameForFun' and registered
-- as a memory variable.  A one-dimensional array of the corresponding
-- length is then declared over it.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM rep r op VName
sStaticArray name space pt vs = do
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem . MemVar Nothing $ MemEntry space
  sArray name pt shape mem $ IxFun.iota [fromIntegral num_elems]
  where
    num_elems :: Int
    num_elems = case vs of
      Imp.ArrayValues vals -> length vals
      Imp.ArrayZeros n -> n
    shape = Shape [intConst Int64 $ toInteger num_elems]

-- | Write the scalar @v@ into @arr@ at the full index @is@.
--
-- The element type is taken from @v@ itself, and the volatility is the
-- environment's current 'envVolatility'.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  store <- asks $ \env ->
    Imp.Write mem offset (primExpType v) space (envVolatility env) v
  emit store

-- | Copy @v@ into the region of @arr@ selected by @slice@, using 'copyDWIM'
-- with an empty source slice.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate arr slice v = copyDWIM arr dest_slice v []
  where
    dest_slice = unSlice slice

-- | Emit a perfect loop nest with one 'sFor' per dimension of the shape,
-- outermost dimension first.
--
-- The callback is invoked in the innermost body with the loop indices
-- ordered from outermost to innermost.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopNest shape f = go [] (shapeDims shape)
  where
    -- Indices are accumulated innermost-first, hence the final 'reverse'.
    go is [] = f (reverse is)
    go is (d : ds) = sFor "nest_i" (toInt64Exp d) $ \i -> go (i : is) ds

-- | Untyped assignment: emit an 'Imp.SetScalar' storing @e@ into @x@.
(<~~) :: VName -> Imp.Exp -> ImpM rep r op ()
(<~~) x e = emit (Imp.SetScalar x e)

infixl 3 <~~

-- | Typed assignment: emit an 'Imp.SetScalar' storing the untyped form of
-- @e@ into the variable wrapped by the 'TV'.
(<--) :: TV t -> Imp.TExp t -> ImpM rep r op ()
(<--) (TV x _) e = emit (Imp.SetScalar x (untyped e))

infixl 3 <--

-- | Constructing an ad-hoc function that does not
-- correspond to any of the IR functions in the input program.
--
-- The body @m@ is compiled in an environment whose 'envFunction' is
-- @Just fname@.  All output and input parameters are first registered in
-- the variable table as memory or scalar variables.  The result is passed
-- to 'emitFunction' as an 'Imp.Function' with 'Nothing' in its leading
-- optional-name field and empty external-value lists.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM rep r op () ->
  ImpM rep r op ()
function fname outputs inputs m =
  local (\env -> env {envFunction = Just fname}) $ do
    body <- collect $ mapM_ addParam (outputs ++ inputs) *> m
    emitFunction fname $ Imp.Function Nothing outputs inputs body [] []
  where
    addParam (Imp.MemParam name space) =
      addVar name . MemVar Nothing $ MemEntry space
    addParam (Imp.ScalarParam name bt) =
      addVar name . ScalarVar Nothing $ ScalarEntry bt

-- | Compute row-major strides for an index space with the given dimension
-- sizes.
--
-- Element @i@ of the result is the product of all dimensions after @i@, so
-- the last stride is the literal @1@ and the result has as many elements
-- as the input.  Every partial product is bound to a fresh variable via
-- 'dPrimVE' with the description @"slice"@.  This includes the product of
-- all dimensions, which is computed but discarded.
dSlices :: [Imp.TExp Int64] -> ImpM rep r op [Imp.TExp Int64]
dSlices dims = drop 1 . snd <$> go dims
  where
    -- Returns (product of the list, strides list prefixed by that product).
    go [] = pure (1, [1])
    go (n : ns) = do
      (inner, strides) <- go ns
      stride <- dPrimVE "slice" $ n * inner
      pure (stride, stride : strides)

-- | @dIndexSpace vs_ds j@ unflattens the flat index @j@ into one index per
-- dimension of a row-major index space.
--
-- Each element of @vs_ds@ pairs a variable name with the size of the
-- corresponding dimension, outermost first.  Each variable is declared via
-- 'dPrimV_' and bound to its index.  Intermediate remainders are bound to
-- fresh variables described as @"remnant"@.  Nothing is returned; the
-- indices are available only through the named variables.
dIndexSpace ::
  [(VName, Imp.TExp Int64)] ->
  Imp.TExp Int64 ->
  ImpM rep r op ()
dIndexSpace vs_ds j = do
  slices <- dSlices (map snd vs_ds)
  loop (zip (map fst vs_ds) slices) j
  where
    -- Divide by the stride to get this dimension's index, then carry the
    -- remainder on to the next (inner) dimension.
    loop ((v, size) : rest) i = do
      dPrimV_ v (i `quot` size)
      i' <- dPrimVE "remnant" $ i - Imp.le64 v * size
      loop rest i'
    loop _ _ = pure ()

-- | Like 'dIndexSpace', but invent some new names for the indexes
-- based on the given template.
--
-- One fresh name (derived from @desc@) is created per dimension in @ds@.
-- The names are bound via 'dIndexSpace', and the indices are returned as
-- int64 expressions in dimension order.
dIndexSpace' ::
  String ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  ImpM rep r op [Imp.TExp Int64]
dIndexSpace' desc ds j = do
  ivs <- mapM (const (newVName desc)) ds
  map Imp.le64 ivs <$ dIndexSpace (zip ivs ds) j