-- | Change 'DefaultSpace' in a program to some other memory space.
-- This is needed because the GPU backends use 'DefaultSpace' to refer
-- to GPU memory for most of the pipeline, but final code generation
-- assumes that 'DefaultSpace' is CPU memory.
module Futhark.CodeGen.SetDefaultSpace
  ( setDefaultSpace,
    setDefaultCodeSpace,
  )
where

import Futhark.CodeGen.ImpCode

-- | Set uses of 'DefaultSpace' in the given definitions to another
-- memory space.
--
-- The parameters of the constants go through 'setParamSpace'. The code of
-- the constants goes through 'setCodeSpace'. Every function goes through
-- 'setFunctionSpace'. Function names are kept as they are.
setDefaultSpace :: Space -> Definitions op -> Definitions op
setDefaultSpace space (Definitions (Constants params constCode) (Functions funs)) =
  Definitions newConsts newFuns
  where
    newConsts =
      Constants (map (setParamSpace space) params) (setCodeSpace space constCode)
    -- 'fmap' on a pair maps only its second component, so each function
    -- name stays attached to its rewritten function.
    newFuns = Functions (map (fmap (setFunctionSpace space)) funs)

-- | Like 'setDefaultSpace', but rewrites a single piece of 'Code' rather
-- than a whole set of 'Definitions'.
setDefaultCodeSpace :: Space -> Code op -> Code op
setDefaultCodeSpace space code = setCodeSpace space code

-- | Rewrite memory spaces in every space-carrying component of a function:
--
-- * the output and input parameters ('setParamSpace'),
-- * the body ('setCodeSpace'),
-- * the external value descriptions of the results and of the named
--   arguments ('setExtValueSpace').
--
-- The leading @Maybe Name@ field (bound as @entry@) is copied unchanged.
setFunctionSpace :: Space -> Function op -> Function op
setFunctionSpace space (Function entry outputs inputs body results args) =
  Function
    entry
    (onParams outputs)
    (onParams inputs)
    (setCodeSpace space body)
    (map onExt results)
    (map (fmap onExt) args) -- fmap on a pair keeps the argument name
  where
    onParams = map (setParamSpace space)
    onExt = setExtValueSpace space

-- | Move a memory parameter that lives in 'DefaultSpace' into the given
-- space. Any other parameter, including memory parameters that already
-- name a different space, is returned unchanged.
setParamSpace :: Space -> Param -> Param
setParamSpace space param =
  case param of
    MemParam name DefaultSpace -> MemParam name space
    _ -> param

-- | Apply 'setValueSpace' to the value descriptions inside an external
-- value: every component of an opaque value, or the single component of a
-- transparent value. The uniqueness and the opaque description string are
-- preserved.
setExtValueSpace :: Space -> ExternalValue -> ExternalValue
setExtValueSpace space ev =
  case ev of
    OpaqueValue u desc vs -> OpaqueValue u desc (map onValue vs)
    TransparentValue u v -> TransparentValue u (onValue v)
  where
    onValue = setValueSpace space

-- | Set the memory space recorded in an array value description to the
-- given space. Scalar value descriptions have no space field and are
-- returned as they are.
--
-- NOTE(review): unlike 'setParamSpace' and 'setSpace', this replaces
-- whatever space the array had, not only 'DefaultSpace'. Confirm that
-- external arrays can only be in 'DefaultSpace' when this pass runs.
setValueSpace :: Space -> ValueDesc -> ValueDesc
setValueSpace space desc =
  case desc of
    ArrayValue mem _ bt ept shape -> ArrayValue mem space bt ept shape
    ScalarValue {} -> desc

-- | Rewrite memory spaces throughout a piece of 'Code'.
--
-- Every 'Space' stored in an allocation, a free, a memory declaration, a
-- memory assignment, a copy, a read or a write is passed through
-- 'setSpace', so only 'DefaultSpace' is replaced there. The statements
-- '(:>>:)', 'For', 'While', 'If' and 'Comment' are rewritten recursively.
-- Statements without a space are returned unchanged. 'Op' payloads are
-- opaque here and are not inspected.
--
-- NOTE(review): 'DeclareArray' is the exception. Its space is overwritten
-- with the target space whatever it was before. Confirm this is intended.
setCodeSpace :: Space -> Code op -> Code op
setCodeSpace space code =
  case code of
    -- Statements that carry a memory space.
    Allocate v e old -> Allocate v e (sp old)
    Free v old -> Free v (sp old)
    DeclareMem name old -> DeclareMem name (sp old)
    DeclareArray name _ t vs -> DeclareArray name space t vs
    SetMem to from old -> SetMem to from (sp old)
    Copy dest dest_offset dest_space src src_offset src_space n ->
      Copy dest dest_offset (sp dest_space) src src_offset (sp src_space) n
    Write mem offset bt mem_space vol e ->
      Write mem offset bt (sp mem_space) vol e
    Read x mem offset bt mem_space vol ->
      Read x mem offset bt (sp mem_space) vol
    -- Compound statements: recurse into their sub-code.
    c1 :>>: c2 -> go c1 :>>: go c2
    For i e body -> For i e (go body)
    While e body -> While e (go body)
    If e c1 c2 -> If e (go c1) (go c2)
    Comment s c -> Comment s (go c)
    -- Statements with no space in them. Each constructor is listed by name
    -- so that a new constructor triggers a missing-pattern warning.
    Skip -> code
    DeclareScalar {} -> code
    SetScalar {} -> code
    Call {} -> code
    Assert {} -> code
    DebugPrint {} -> code
    TracePrint {} -> code
    Op {} -> code
  where
    sp = setSpace space
    go = setCodeSpace space

-- | Replace 'DefaultSpace' with the given space. Any other space is kept.
setSpace :: Space -> Space -> Space
setSpace space old =
  case old of
    DefaultSpace -> space
    _ -> old