{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.GPU.Simplify
  ( simplifyGPU,
    simplifyLambda,
    GPU,

    -- * Building blocks
    simplifyKernelOp,
  )
where

import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.GPU
import qualified Futhark.IR.SOACS.Simplify as SOAC
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify as Simplify
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rep
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT

-- | The simplifier operations for the GPU representation.
--
-- Built with 'Simplify.bindableSimpleOps' around 'simplifyKernelOp'.
-- Any SOAC nested inside a host operation (the 'OtherOp' case) is
-- handed to 'SOAC.simplifySOAC'.
simpleGPU :: Simplify.SimpleOps GPU
simpleGPU = Simplify.bindableSimpleOps hostOpSimplifier
  where
    hostOpSimplifier = simplifyKernelOp SOAC.simplifySOAC

-- | Simplify an entire GPU program.
--
-- Uses the GPU operation simplifier 'simpleGPU', the rule book
-- 'kernelRules', and no hoist blockers beyond the defaults.
simplifyGPU :: Prog GPU -> PassM (Prog GPU)
simplifyGPU prog =
  Simplify.simplifyProg simpleGPU kernelRules Simplify.noExtraHoistBlockers prog

-- | Simplify a single GPU lambda.
--
-- Uses the same operation simplifier ('simpleGPU') and rule book
-- ('kernelRules') as 'simplifyGPU', with no extra hoist blockers.
simplifyLambda ::
  (HasScope GPU m, MonadFreshNames m) =>
  Lambda GPU ->
  m (Lambda GPU)
simplifyLambda lam =
  Simplify.simplifyLambda simpleGPU kernelRules Engine.noExtraHoistBlockers lam

-- | Simplify a single 'HostOp'.
--
-- * 'OtherOp': delegated to the supplied simplifier @f@; whatever
--   statements it returns are passed through unchanged.
-- * 'SegOp': simplified with 'simplifySegOp'; the statements it hoists
--   out are returned alongside the simplified operation.
-- * 'SizeOp': only its 'SubExp' operands are simplified (see
--   'simplifySizeOpOperands'); no statements are ever hoisted.
simplifyKernelOp ::
  ( Engine.SimplifiableRep rep,
    BodyDec rep ~ ()
  ) =>
  Simplify.SimplifyOp rep op ->
  HostOp rep op ->
  Engine.SimpleM rep (HostOp (Wise rep) (OpWithWisdom op), Stms (Wise rep))
simplifyKernelOp f (OtherOp op) = do
  (op', stms) <- f op
  pure (OtherOp op', stms)
simplifyKernelOp _ (SegOp op) = do
  (op', hoisted) <- simplifySegOp op
  pure (SegOp op', hoisted)
simplifyKernelOp _ (SizeOp op) = do
  op' <- simplifySizeOpOperands op
  pure (SizeOp op', mempty)

-- | Simplify every 'SubExp' operand of a 'SizeOp'.
--
-- Size keys ('Name') and size classes are not expressions and are kept
-- as-is.  Operands are simplified left to right.
simplifySizeOpOperands ::
  Engine.SimplifiableRep rep =>
  SizeOp ->
  Engine.SimpleM rep SizeOp
simplifySizeOpOperands (SplitSpace o w i elems_per_thread) =
  SplitSpace
    <$> Engine.simplify o
    <*> Engine.simplify w
    <*> Engine.simplify i
    <*> Engine.simplify elems_per_thread
simplifySizeOpOperands (GetSize key size_class) =
  pure $ GetSize key size_class
simplifySizeOpOperands (GetSizeMax size_class) =
  pure $ GetSizeMax size_class
simplifySizeOpOperands (CmpSizeLe key size_class x) =
  CmpSizeLe key size_class <$> Engine.simplify x
simplifySizeOpOperands (CalcNumGroups w max_num_groups group_size) =
  -- The group size is a 'SubExp' just like 'w', so it is simplified
  -- too.  The previous code left it untouched, unlike every other
  -- 'SubExp' operand handled here.
  CalcNumGroups
    <$> Engine.simplify w
    <*> pure max_num_groups
    <*> Engine.simplify group_size

-- | Statement traversal for GPU operations.
--
-- Host operations are traversed with 'traverseHostOpStms'.  Any SOAC
-- nested inside them is traversed with 'traverseSOACStms'.
instance TraverseOpStms (Wise GPU) where
  traverseOpStms = traverseHostOpStms traverseSOACStms

-- | Builder operations for the simplifier's GPU representation.
-- No methods are overridden; the class's default definitions are used.
instance BuilderOps (Wise GPU) where

-- | Expose the 'SegOp' constructor of GPU host operations.
-- Segmented operations in this representation use 'SegLevel'.
instance HasSegOp (Wise GPU) where
  type SegOpLevel (Wise GPU) = SegLevel

  -- Only the 'SegOp' constructor carries a segmented operation.
  asSegOp hostop = case hostop of
    SegOp segop -> Just segop
    _ -> Nothing

  segOp segop = SegOp segop

-- | Expose the SOACs embedded in GPU host operations.
-- SOACs live in the 'OtherOp' constructor.
instance SOAC.HasSOAC (Wise GPU) where
  asSOAC hostop = case hostop of
    OtherOp soac -> Just soac
    _ -> Nothing

  soacOp soac = OtherOp soac

-- | The simplification rules used for GPU programs.
--
-- The rule book combines three parts:
--
-- * the standard rules;
-- * the generic 'segOpRules';
-- * a GPU-specific book with SOAC-level top-down rules and the
--   bottom-up 'removeUnnecessaryCopy' rule.
kernelRules :: RuleBook (Wise GPU)
kernelRules = standardRules <> segOpRules <> gpuSpecificRules
  where
    gpuSpecificRules = ruleBook topDownRules bottomUpRules

    -- Kept in the same order as before, in case application order
    -- matters.
    topDownRules =
      [ RuleOp redomapIotaToLoop,
        RuleOp SOAC.simplifyKnownIterationSOAC,
        RuleOp SOAC.removeReplicateMapping,
        RuleOp SOAC.liftIdentityMapping,
        RuleOp SOAC.simplifyMapIota
      ]

    bottomUpRules =
      [ RuleBasicOp removeUnnecessaryCopy
      ]

-- We turn reductions over (solely) iotas into do-loops, because there
-- is no useful structure here anyway.  This is mostly a hack to work
-- around the fact that loop tiling would otherwise pointlessly tile
-- them.
-- | Top-down rule that rewrites a redomap over a single iota.
--
-- It fires only when all of these hold:
--
-- * the operation is an 'OtherOp' wrapping a 'Screma' over exactly
--   one array;
-- * the Screma form is a redomap (per 'isRedomapSOAC');
-- * that array is bound to an 'Iota' (per the symbol table).
--
-- When it fires, the SOAC is replaced by the code that
-- 'FOT.transformSOAC' generates for it, under the certificates of the
-- original statement.  In every other case the rule skips.
redomapIotaToLoop :: TopDownRuleOp (Wise GPU)
redomapIotaToLoop symtab dest aux hostop
  | OtherOp soac@(Screma _ [arr] form) <- hostop,
    Just _ <- isRedomapSOAC form,
    Just (Iota {}, _) <- ST.lookupBasicOp arr symtab =
      Simplify . certifying (stmAuxCerts aux) $ FOT.transformSOAC dest soac
  | otherwise = Skip