-- | Change 'DefaultSpace' in a program to some other memory space.
-- This is needed because the GPU backends use 'DefaultSpace' to refer
-- to GPU memory for most of the pipeline, but final code generation
-- assumes that 'DefaultSpace' is CPU memory.
module Futhark.CodeGen.SetDefaultSpace
  ( setDefaultSpace,
    setDefaultCodeSpace,
  )
where

import Futhark.CodeGen.ImpCode

-- | Replace every use of 'DefaultSpace' in the given definitions with
-- the given memory space.  This covers the constants (their parameters
-- and initialisation code) and every function (handled by
-- 'setFunctionSpace').  The opaque type declarations are passed through
-- untouched.
setDefaultSpace :: Space -> Definitions op -> Definitions op
setDefaultSpace space (Definitions types (Constants ps consts) (Functions fundecs)) =
  Definitions types constants' functions'
  where
    constants' =
      Constants (map (setParamSpace space) ps) (setCodeSpace space consts)
    functions' = Functions (map onFunction fundecs)
    -- Function names are kept; only the function itself is rewritten.
    onFunction (fname, func) = (fname, setFunctionSpace space func)

-- | Like 'setDefaultSpace', but rewrites a single piece of 'Code'
-- instead of whole 'Definitions': every 'DefaultSpace' in the code is
-- replaced by the given space, and all other spaces are kept.
setDefaultCodeSpace :: Space -> Code op -> Code op
setDefaultCodeSpace = setCodeSpace

-- | Apply the default-space replacement to a single function.  This
-- covers its entry point description (when present), its output and
-- input parameters, and its body.
setFunctionSpace :: Space -> Function op -> Function op
setFunctionSpace space (Function entry outputs inputs body) =
  Function entry' outputs' inputs' body'
  where
    onParams = map (setParamSpace space)
    entry' = fmap (setEntrySpace space) entry
    outputs' = onParams outputs
    inputs' = onParams inputs
    body' = setCodeSpace space body

-- | Apply the default-space replacement to the external value
-- descriptions of an entry point's results and arguments.  The entry
-- point name and the annotation paired with each value (uniqueness,
-- or name plus uniqueness) are kept as they are.
setEntrySpace :: Space -> EntryPoint -> EntryPoint
setEntrySpace space (EntryPoint name results args) =
  EntryPoint
    name
    [(u, onValue v) | (u, v) <- results]
    [(p, onValue v) | (p, v) <- args]
  where
    onValue = setExtValueSpace space

-- | Move a memory parameter that lives in 'DefaultSpace' to the given
-- space.  Memory parameters in any other space, and all parameters
-- that are not 'MemParam's, are returned unchanged.
setParamSpace :: Space -> Param -> Param
setParamSpace space param =
  case param of
    MemParam name DefaultSpace -> MemParam name space
    _ -> param

-- | Apply 'setValueSpace' to every value description inside an
-- external value.  An opaque value may hold several component
-- descriptions, while a transparent value holds exactly one.  The
-- opaque value's description string is preserved.
setExtValueSpace :: Space -> ExternalValue -> ExternalValue
setExtValueSpace space ev =
  case ev of
    OpaqueValue desc vs -> OpaqueValue desc (map onValue vs)
    TransparentValue v -> TransparentValue (onValue v)
  where
    onValue = setValueSpace space

-- | Move the memory backing an array value description to the given
-- space, but only if it currently lives in 'DefaultSpace'.  Scalar
-- value descriptions carry no memory space and are returned unchanged.
--
-- Arrays already placed in an explicit space keep that space.  This
-- matches 'setParamSpace', so the memory parameter and the value
-- description that refers to it always agree on where the array lives.
setValueSpace :: Space -> ValueDesc -> ValueDesc
setValueSpace space (ArrayValue mem old_space bt ept shape) =
  ArrayValue mem (setSpace space old_space) bt ept shape
setValueSpace _ (ScalarValue bt ept v) =
  ScalarValue bt ept v

-- | Rewrite every 'DefaultSpace' occurring in a piece of code to the
-- given space, leaving explicitly chosen spaces alone.
--
-- Spaces are rewritten in the following statements:
--
-- * memory declarations, allocations and frees
-- * 'SetMem' and array declarations
-- * both ends of copies, and reads and writes
--
-- The transformation recurses into sequencing, loops, conditionals and
-- comments.  All other statements, including 'Op', are returned
-- unchanged.
setCodeSpace :: Space -> Code op -> Code op
setCodeSpace space (Allocate v e old_space) =
  Allocate v e $ setSpace space old_space
setCodeSpace space (Free v old_space) =
  Free v $ setSpace space old_space
setCodeSpace space (DeclareMem name old_space) =
  DeclareMem name $ setSpace space old_space
setCodeSpace space (DeclareArray name old_space t vs) =
  -- Like every other space-carrying statement, only 'DefaultSpace' is
  -- replaced; an explicitly given space is preserved.
  DeclareArray name (setSpace space old_space) t vs
setCodeSpace space (Copy t dest dest_offset dest_space src src_offset src_space n) =
  Copy t dest dest_offset dest_space' src src_offset src_space' n
  where
    dest_space' = setSpace space dest_space
    src_space' = setSpace space src_space
setCodeSpace space (Write dest dest_offset bt dest_space vol e) =
  Write dest dest_offset bt (setSpace space dest_space) vol e
setCodeSpace space (Read x dest dest_offset bt dest_space vol) =
  Read x dest dest_offset bt (setSpace space dest_space) vol
setCodeSpace space (c1 :>>: c2) =
  setCodeSpace space c1 :>>: setCodeSpace space c2
setCodeSpace space (For i e body) =
  For i e $ setCodeSpace space body
setCodeSpace space (While e body) =
  While e $ setCodeSpace space body
setCodeSpace space (If e c1 c2) =
  If e (setCodeSpace space c1) (setCodeSpace space c2)
setCodeSpace space (Comment s c) =
  Comment s $ setCodeSpace space c
setCodeSpace _ Skip =
  Skip
setCodeSpace _ (DeclareScalar name vol bt) =
  DeclareScalar name vol bt
setCodeSpace _ (SetScalar name e) =
  SetScalar name e
setCodeSpace space (SetMem to from old_space) =
  SetMem to from $ setSpace space old_space
setCodeSpace _ (Call dests fname args) =
  Call dests fname args
setCodeSpace _ (Assert e msg loc) =
  Assert e msg loc
setCodeSpace _ (DebugPrint s v) =
  DebugPrint s v
setCodeSpace _ (TracePrint msg) =
  TracePrint msg
setCodeSpace _ (Op op) =
  Op op

-- | Resolve a single space: 'DefaultSpace' becomes the given target
-- space, and every other space is kept as it is.
setSpace :: Space -> Space -> Space
setSpace target old =
  case old of
    DefaultSpace -> target
    _ -> old