-- | Change 'DefaultSpace' in a program to some other memory space.
-- This is needed because the GPU backends use 'DefaultSpace' to refer
-- to GPU memory for most of the pipeline, but final code generation
-- assumes that 'DefaultSpace' is CPU memory.
module Futhark.CodeGen.SetDefaultSpace
  ( setDefaultSpace,
  )
where

import Futhark.CodeGen.ImpCode

-- | Rewrite every occurrence of 'DefaultSpace' in the given definitions
-- (the constants block as well as every function) so that it refers to
-- the given memory space instead.
setDefaultSpace :: Space -> Definitions op -> Definitions op
setDefaultSpace space (Definitions (Constants ps consts) (Functions fundecs)) =
  Definitions constants' functions'
  where
    -- Constants: both their parameters and their initialisation code.
    constants' = Constants (map (setParamSpace space) ps) (setBodySpace space consts)
    -- Functions keep their names; only the function bodies change.
    functions' = Functions (map onFunction fundecs)
    onFunction (fname, func) = (fname, setFunctionSpace space func)

-- | Apply the 'DefaultSpace' substitution to a single function: its
-- output and input parameters, its body, and the external descriptions
-- of its results and arguments.  The 'Bool' field @entry@ is passed
-- through unchanged.
setFunctionSpace :: Space -> Function op -> Function op
setFunctionSpace space (Function entry outputs inputs body results args) =
  Function entry outputs' inputs' body' results' args'
  where
    onParams = map (setParamSpace space)
    onExtValues = map (setExtValueSpace space)
    outputs' = onParams outputs
    inputs' = onParams inputs
    body' = setBodySpace space body
    results' = onExtValues results
    args' = onExtValues args

-- | Move a memory parameter that lives in 'DefaultSpace' into @space@.
-- Any other parameter, including memory parameters already placed in
-- an explicit space, is returned unchanged.
setParamSpace :: Space -> Param -> Param
setParamSpace space param =
  case param of
    MemParam name DefaultSpace -> MemParam name space
    _ -> param

-- | Apply 'setValueSpace' to every value description held by an
-- external value, whether opaque (a list of descriptions, kept together
-- with its descriptor string) or transparent (a single description).
setExtValueSpace :: Space -> ExternalValue -> ExternalValue
setExtValueSpace space ext =
  case ext of
    OpaqueValue desc vs -> OpaqueValue desc (map onValue vs)
    TransparentValue v -> TransparentValue (onValue v)
  where
    onValue = setValueSpace space

-- | Move an array value description out of 'DefaultSpace' and into
-- @space@.
--
-- An array already placed in an explicit memory space keeps that
-- space.  This matches the module's contract (only 'DefaultSpace' is
-- changed) and the behaviour of 'setParamSpace', 'setSpace' and the
-- @Index@ rewrite in 'setExpSpace'.  Scalar descriptions carry no
-- memory space and are returned as-is.
setValueSpace :: Space -> ValueDesc -> ValueDesc
setValueSpace space (ArrayValue mem old_space bt ept shape) =
  -- Previously the old space was discarded unconditionally, which
  -- also clobbered arrays that were not in 'DefaultSpace'.
  ArrayValue mem (setSpace space old_space) bt ept shape
setValueSpace _ (ScalarValue bt ept v) =
  ScalarValue bt ept v

-- | Rewrite the memory spaces mentioned by a piece of imperative code.
--
-- Every 'Space' annotation (on 'Allocate', 'Free', 'DeclareMem',
-- 'DeclareArray', 'Copy', 'Write' and 'SetMem') is passed through
-- 'setSpace', so only 'DefaultSpace' is replaced.  Embedded expressions
-- are rewritten with 'setExpSpace' / 'setTExpSpace', nested code is
-- processed recursively, and 'Op' payloads are left untouched.
setBodySpace :: Space -> Code op -> Code op
setBodySpace space (Allocate v e old_space) =
  Allocate v (fmap (setTExpSpace space) e) $ setSpace space old_space
setBodySpace space (Free v old_space) =
  Free v $ setSpace space old_space
setBodySpace space (DeclareMem name old_space) =
  DeclareMem name $ setSpace space old_space
setBodySpace space (DeclareArray name old_space t vs) =
  -- Like every other space annotation here, only 'DefaultSpace' is
  -- replaced; an array declared in an explicit space keeps it.  (This
  -- case previously overwrote the declared space unconditionally.)
  DeclareArray name (setSpace space old_space) t vs
setBodySpace space (Copy dest dest_offset dest_space src src_offset src_space n) =
  Copy
    dest
    (fmap (setTExpSpace space) dest_offset)
    (setSpace space dest_space)
    src
    (fmap (setTExpSpace space) src_offset)
    (setSpace space src_space)
    (fmap (setTExpSpace space) n)
setBodySpace space (Write dest dest_offset bt dest_space vol e) =
  Write
    dest
    (fmap (setTExpSpace space) dest_offset)
    bt
    (setSpace space dest_space)
    vol
    (setExpSpace space e)
setBodySpace space (c1 :>>: c2) =
  setBodySpace space c1 :>>: setBodySpace space c2
setBodySpace space (For i e body) =
  For i (setExpSpace space e) $ setBodySpace space body
setBodySpace space (While e body) =
  While (setTExpSpace space e) $ setBodySpace space body
setBodySpace space (If e c1 c2) =
  If (setTExpSpace space e) (setBodySpace space c1) (setBodySpace space c2)
setBodySpace space (Comment s c) =
  Comment s $ setBodySpace space c
setBodySpace _ Skip =
  Skip
setBodySpace _ (DeclareScalar name vol bt) =
  DeclareScalar name vol bt
setBodySpace space (SetScalar name e) =
  SetScalar name $ setExpSpace space e
setBodySpace space (SetMem to from old_space) =
  SetMem to from $ setSpace space old_space
setBodySpace space (Call dests fname args) =
  Call dests fname $ map setArgSpace args
  where
    setArgSpace :: Arg -> Arg
    -- Memory arguments carry no space annotation of their own.
    setArgSpace (MemArg m) = MemArg m
    setArgSpace (ExpArg e) = ExpArg $ setExpSpace space e
setBodySpace space (Assert e msg loc) =
  -- Only the asserted condition is rewritten; the error message is
  -- passed through unchanged.
  Assert (setExpSpace space e) msg loc
setBodySpace space (DebugPrint s v) =
  DebugPrint s $ fmap (setExpSpace space) v
setBodySpace _ (Op op) =
  Op op

-- | Redirect every 'Index' leaf of an expression that reads from
-- 'DefaultSpace' so that it reads from @space@ instead.  Leaves that
-- index an explicit space, and all other kinds of leaf, are kept as
-- they are.
setExpSpace :: Space -> Exp -> Exp
setExpSpace space expr = fmap relocate expr
  where
    relocate :: ExpLeaf -> ExpLeaf
    relocate leaf =
      case leaf of
        Index mem i bt DefaultSpace vol -> Index mem i bt space vol
        _ -> leaf

-- | Typed counterpart of 'setExpSpace'.  It unwraps the underlying
-- untyped expression, rewrites it, and wraps it again, so the type
-- index @t@ is preserved.
setTExpSpace :: Space -> TExp t -> TExp t
setTExpSpace space te = TPrimExp $ setExpSpace space $ untyped te

-- | Substitute @space@ for 'DefaultSpace'; any other space is returned
-- unchanged.
setSpace :: Space -> Space -> Space
setSpace space old_space =
  case old_space of
    DefaultSpace -> space
    _ -> old_space