{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.GPU.Simplify
  ( simplifyGPU,
    simplifyLambda,
    GPU,

    -- * Building blocks
    simplifyKernelOp,
  )
where

import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Analysis.UsageTable qualified as UT
import Futhark.IR.GPU
import Futhark.IR.SOACS.Simplify qualified as SOAC
import Futhark.MonadFreshNames
import Futhark.Optimise.Simplify qualified as Simplify
import Futhark.Optimise.Simplify.Engine qualified as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Tools
import Futhark.Util (focusNth)

-- | The simplifier operations for the 'GPU' representation: the
-- bindable operations from "Futhark.Optimise.Simplify", with
-- operations simplified by 'simplifyKernelOp', which in turn hands
-- any embedded SOAC to 'SOAC.simplifySOAC'.
simpleGPU :: Simplify.SimpleOps GPU
simpleGPU = Simplify.bindableSimpleOps $ simplifyKernelOp SOAC.simplifySOAC

-- | Simplify an entire 'GPU' program, using 'simpleGPU' as the
-- simplifier operations, 'kernelRules' as the rule book, and no
-- extra hoist blockers.
simplifyGPU :: Prog GPU -> PassM (Prog GPU)
simplifyGPU =
  Simplify.simplifyProg simpleGPU kernelRules Simplify.noExtraHoistBlockers

-- | Simplify a single 'GPU' lambda in the current scope, using the
-- same simplifier operations ('simpleGPU') and rule book
-- ('kernelRules') as 'simplifyGPU', and no extra hoist blockers.
simplifyLambda ::
  (HasScope GPU m, MonadFreshNames m) =>
  Lambda GPU ->
  m (Lambda GPU)
simplifyLambda =
  Simplify.simplifyLambda simpleGPU kernelRules Engine.noExtraHoistBlockers

-- | Simplify a single 'HostOp'.
--
-- The first argument is the simplifier for the payload of an
-- 'OtherOp'; it is ignored for every other constructor.  The result
-- is the simplified operation paired with the statements that were
-- produced or hoisted while simplifying it (empty for the 'SizeOp'
-- cases).
simplifyKernelOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ ()
  ) =>
  Simplify.SimplifyOp rep (op (Wise rep)) ->
  HostOp op (Wise rep) ->
  Engine.SimpleM rep (HostOp op (Wise rep), Stms (Wise rep))
-- Delegate the wrapped operation to the caller-supplied simplifier.
simplifyKernelOp f (OtherOp op) = do
  (op', stms) <- f op
  pure (OtherOp op', stms)
simplifyKernelOp _ (SegOp op) = do
  (op', hoisted) <- simplifySegOp op
  pure (SegOp op', hoisted)
-- 'GetSize' and 'GetSizeMax' have no operands that are simplified
-- here, so they are returned as they are.
simplifyKernelOp _ (SizeOp (GetSize key size_class)) =
  pure (SizeOp $ GetSize key size_class, mempty)
simplifyKernelOp _ (SizeOp (GetSizeMax size_class)) =
  pure (SizeOp $ GetSizeMax size_class, mempty)
simplifyKernelOp _ (SizeOp (CmpSizeLe key size_class x)) = do
  x' <- Engine.simplify x
  pure (SizeOp $ CmpSizeLe key size_class x', mempty)
-- Only the first operand is simplified; the other two are passed
-- through untouched.
simplifyKernelOp _ (SizeOp (CalcNumGroups w max_num_groups group_size)) = do
  w' <- Engine.simplify w
  pure (SizeOp $ CalcNumGroups w' max_num_groups group_size, mempty)
simplifyKernelOp _ (GPUBody ts body) = do
  ts' <- Engine.simplify ts
  -- The body is simplified with 'keepOnGPU' as the block predicate,
  -- an empty usage table, and one empty usage per result type.
  (hoisted, body') <-
    Engine.simplifyBody keepOnGPU mempty (map (const mempty) ts) body
  pure (GPUBody ts' body', hoisted)
  where
    -- The first two arguments of the block predicate are ignored;
    -- the decision depends only on the expression of the statement.
    -- No local type signature is given: the two ignored arguments
    -- have different types, which a signature reusing one type
    -- variable for both would reject.
    keepOnGPU _ _ = keepExpOnGPU . stmExp

    -- Expressions for which the predicate holds: indexing, array
    -- literals of primitive element type, and loops.
    -- NOTE(review): presumably a 'True' here prevents the statement
    -- from being hoisted out of the 'GPUBody' - confirm against
    -- 'Engine.simplifyBody'.
    keepExpOnGPU (BasicOp Index {}) = True
    keepExpOnGPU (BasicOp (ArrayLit _ t)) | primType t = True
    keepExpOnGPU DoLoop {} = True
    keepExpOnGPU _ = False

-- Statement traversal for GPU operations: the host-op traverser,
-- with 'traverseSOACStms' handling the wrapped SOAC case.
instance TraverseOpStms (Wise GPU) where
  traverseOpStms = traverseHostOpStms traverseSOACStms

-- No methods are overridden here, so the defaults of the class apply.
instance BuilderOps (Wise GPU)

-- Lets generic rules view and construct segmented operations in the
-- GPU representation: a 'SegOp' is exactly the 'SegOp' constructor of
-- the operation type, at level 'SegLevel'.
instance HasSegOp (Wise GPU) where
  type SegOpLevel (Wise GPU) = SegLevel

  -- Only the 'SegOp' constructor carries a segmented operation.
  asSegOp (SegOp op) = Just op
  asSegOp _ = Nothing

  segOp = SegOp

-- Lets the SOAC simplification rules operate on the GPU
-- representation: a SOAC is stored in the 'OtherOp' constructor of
-- the operation type.
instance SOAC.HasSOAC (Wise GPU) where
  -- Only the 'OtherOp' constructor carries a SOAC.
  asSOAC (OtherOp soac) = Just soac
  asSOAC _ = Nothing

  soacOp = OtherOp

-- | The rule book used when simplifying 'GPU' programs: the standard
-- rules, the segmented-operation rules, and a handful of extra rules
-- given as a list of top-down rules followed by a list of bottom-up
-- rules.
kernelRules :: RuleBook (Wise GPU)
kernelRules =
  standardRules
    <> segOpRules
    <> ruleBook
      -- Top-down rules.
      [ RuleOp SOAC.simplifyKnownIterationSOAC,
        RuleOp SOAC.removeReplicateMapping,
        RuleOp SOAC.liftIdentityMapping,
        RuleOp SOAC.simplifyMapIota,
        RuleOp SOAC.removeUnusedSOACInput,
        RuleBasicOp removeScalarCopy
      ]
      -- Bottom-up rules.
      [ RuleBasicOp removeUnnecessaryCopy,
        RuleOp removeDeadGPUBodyResult
      ]

-- | Remove the unused return values of a GPUBody.
--
-- The rule fires only for a 'GPUBody' where at least one name bound
-- by the pattern is not directly used according to the usage table;
-- in every other case it yields 'Skip'.
removeDeadGPUBodyResult :: BottomUpRuleOp (Wise GPU)
removeDeadGPUBodyResult (_, used) pat aux (GPUBody types body)
  | -- Figure out which of the names in 'pat' are used...
    pat_used <- map (`UT.isUsedDirectly` used) $ patNames pat,
    -- If they are not all used, then this rule applies.
    not (and pat_used) =
      -- Remove the parts of the GPUBody results that correspond to dead
      -- return value bindings.  Note that this leaves dead code in the
      -- kernel, but that will be removed later.
      let -- Keep exactly those elements whose position corresponds to
          -- a used pattern name.
          pick :: [a] -> [a]
          pick = map snd . filter fst . zip pat_used
          pat' = pick (patElems pat)
          types' = pick types
          body' = body {bodyResult = pick (bodyResult body)}
       in Simplify $
            auxing aux $
              letBind (Pat pat') $
                Op $
                  GPUBody types' body'
  | otherwise = Skip
removeDeadGPUBodyResult _ _ _ _ = Skip

-- If we see an Update with a scalar where the value to be written is
-- the result of indexing some other array, then we convert it into an
-- Update with a slice of that array.  This matters when the arrays
-- are far away (on the GPU, say), because it avoids a copy of the
-- scalar to and from the host.
removeScalarCopy :: BuilderOps rep => TopDownRuleBasicOp rep
removeScalarCopy vtable pat aux (Update safety arr_x (Slice slice_x) (Var v))
  | -- The update slice must be expressible as a list of indices
    -- (presumably: every dimension is fixed, i.e. a scalar is
    -- written - confirm against 'sliceIndices').
    Just _ <- sliceIndices (Slice slice_x),
    -- The written value must itself be the result of an 'Index'.
    Just (Index arr_y (Slice slice_y), cs_y) <- ST.lookupBasicOp v vtable,
    -- The indexed array must still be available, and must not alias
    -- the array being updated.
    ST.available arr_y vtable,
    not $ ST.aliases arr_x arr_y vtable,
    -- Split off the last dimension of each slice; it must be a fixed
    -- index.
    Just (slice_x_bef, DimFix i, []) <- focusNth (length slice_x - 1) slice_x,
    Just (slice_y_bef, DimFix j, []) <- focusNth (length slice_y - 1) slice_y = Simplify $ do
      -- Replace the trailing fixed index by a 'DimSlice' with both
      -- remaining operands equal to the constant 1.
      let slice_x' = Slice $ slice_x_bef ++ [DimSlice i (intConst Int64 1) (intConst Int64 1)]
          slice_y' = Slice $ slice_y_bef ++ [DimSlice j (intConst Int64 1) (intConst Int64 1)]
      -- NOTE(review): this new 'Index' is bound outside of
      -- 'certifying cs_y'; only the 'Update' below receives the
      -- certificates of the original 'Index'.  Confirm this is
      -- intended.
      v' <- letExp (baseString v ++ "_slice") $ BasicOp $ Index arr_y slice_y'
      certifying cs_y . auxing aux $
        letBind pat $
          BasicOp $
            Update safety arr_x slice_x' $
              Var v'
removeScalarCopy _ _ _ _ =
  Skip