-- | Change 'DefaultSpace' in a program to some other memory space.
-- This is needed because the GPU backends use 'DefaultSpace' to refer
-- to GPU memory for most of the pipeline, but final code generation
-- assumes that 'DefaultSpace' is CPU memory.
module Futhark.CodeGen.SetDefaultSpace
       ( setDefaultSpace
       )
       where

import Futhark.CodeGen.ImpCode

-- | Set all uses of 'DefaultSpace' in the given definitions to another
-- memory space.
--
-- Both the constants (their parameters and their initialisation code)
-- and every function definition are rewritten.  Function names are
-- kept as they are.
setDefaultSpace :: Space -> Definitions op -> Definitions op
setDefaultSpace space (Definitions (Constants ps consts) (Functions fundecs)) =
  Definitions
    (Constants (map (setParamSpace space) ps) (setBodySpace space consts))
    (Functions [ (fname, setFunctionSpace space func)
               | (fname, func) <- fundecs ])

-- | Rewrite 'DefaultSpace' throughout a single function.
--
-- The following parts are rewritten:
--
-- * the output and input parameters;
-- * the body;
-- * the external descriptions of the results and arguments.
--
-- The leading 'Bool' field (bound here as @entry@) is passed through
-- unchanged.
setFunctionSpace :: Space -> Function op -> Function op
setFunctionSpace space (Function entry outputs inputs body results args) =
  Function entry
    (map (setParamSpace space) outputs)
    (map (setParamSpace space) inputs)
    (setBodySpace space body)
    (map (setExtValueSpace space) results)
    (map (setExtValueSpace space) args)

-- | Move a memory parameter declared in 'DefaultSpace' to the given
-- space.
--
-- A memory parameter in any other space, and any other kind of
-- parameter, is returned unchanged.
setParamSpace :: Space -> Param -> Param
setParamSpace space (MemParam name DefaultSpace) = MemParam name space
setParamSpace _ param = param

-- | Apply 'setValueSpace' to every value description contained in an
-- 'ExternalValue'.
--
-- The opaque description string is preserved.
setExtValueSpace :: Space -> ExternalValue -> ExternalValue
setExtValueSpace space (OpaqueValue desc vs) =
  OpaqueValue desc (map (setValueSpace space) vs)
setExtValueSpace space (TransparentValue v) =
  TransparentValue (setValueSpace space v)

-- | Set the memory space of an array value description to the given
-- space.
--
-- Scalar value descriptions carry no space and are returned unchanged.
--
-- NOTE(review): unlike 'setParamSpace' and 'setSpace', this overwrites
-- the array's recorded space unconditionally, not only when it is
-- 'DefaultSpace'.  Confirm that external arrays can never be in a
-- non-default space at this point.
setValueSpace :: Space -> ValueDesc -> ValueDesc
setValueSpace space (ArrayValue mem _ bt ept shape) =
  ArrayValue mem space bt ept shape
setValueSpace _ v@ScalarValue{} = v

-- | Rewrite 'DefaultSpace' throughout a piece of imperative code.
--
-- * Spaces recorded on 'Allocate', 'Free', 'DeclareMem', 'Copy', 'Write'
--   and 'SetMem' are mapped through 'setSpace', so only 'DefaultSpace'
--   changes.
-- * Every embedded expression is rewritten with 'setExpSpace'.
-- * Nested code is processed recursively.
-- * 'Op' nodes are opaque here and are returned unchanged.
setBodySpace :: Space -> Code op -> Code op
setBodySpace space = go
  where
    go (Allocate v e old_space) =
      Allocate v (onExp <$> e) (onSpace old_space)
    go (Free v old_space) =
      Free v (onSpace old_space)
    go (DeclareMem name old_space) =
      DeclareMem name (onSpace old_space)
    -- NOTE(review): the declared space is replaced unconditionally,
    -- not only when it was 'DefaultSpace'; confirm this is intended.
    go (DeclareArray name _ t vs) =
      DeclareArray name space t vs
    go (Copy dest dest_offset dest_space src src_offset src_space n) =
      Copy dest (onExp <$> dest_offset) (onSpace dest_space)
           src (onExp <$> src_offset) (onSpace src_space)
           (onExp <$> n)
    go (Write dest dest_offset bt dest_space vol e) =
      Write dest (onExp <$> dest_offset) bt (onSpace dest_space)
            vol (onExp e)
    go (c1 :>>: c2) =
      go c1 :>>: go c2
    go (For i it e body) =
      For i it (onExp e) (go body)
    go (While e body) =
      While (onExp e) (go body)
    go (If e c1 c2) =
      If (onExp e) (go c1) (go c2)
    go (Comment s c) =
      Comment s (go c)
    go Skip =
      Skip
    go (DeclareScalar name vol bt) =
      DeclareScalar name vol bt
    go (SetScalar name e) =
      SetScalar name (onExp e)
    go (SetMem to from old_space) =
      SetMem to from (onSpace old_space)
    go (Call dests fname args) =
      Call dests fname (map onArg args)
    -- NOTE(review): only the asserted condition is rewritten; any
    -- expressions inside the error message are passed through as-is.
    go (Assert e msg loc) =
      Assert (onExp e) msg loc
    go (DebugPrint s v) =
      DebugPrint s (onExp <$> v)
    go (Op op) =
      Op op

    onExp = setExpSpace space
    onSpace = setSpace space

    -- Memory arguments carry no space annotation here; only expression
    -- arguments need rewriting.
    onArg (MemArg m) = MemArg m
    onArg (ExpArg e) = ExpArg (onExp e)

-- | Redirect every 'Index' leaf of an expression that reads from
-- 'DefaultSpace' so that it reads from the given space instead.
--
-- All other leaves are left unchanged.
setExpSpace :: Space -> Exp -> Exp
setExpSpace space = fmap setLeafSpace
  where
    setLeafSpace :: ExpLeaf -> ExpLeaf
    setLeafSpace (Index mem i bt DefaultSpace vol) =
      Index mem i bt space vol
    setLeafSpace e = e

-- | Replace 'DefaultSpace' by the given space.
--
-- Any other space is returned unchanged.
setSpace :: Space -> Space -> Space
setSpace space DefaultSpace = space
setSpace _ space = space