{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.ToGPU
  ( getSize,
    segThread,
    soacsLambdaToGPU,
    soacsStmToGPU,
    scopeForGPU,
    scopeForSOACs,
    injectSOACS,
  )
where

import Control.Monad.Identity
import Data.List ()
import Futhark.IR
import Futhark.IR.GPU
import Futhark.IR.SOACS (SOACS)
import Futhark.IR.SOACS.SOAC qualified as SOAC
import Futhark.Tools

-- | Bind a host-level 'GetSize' operation and return the resulting
-- 'SubExp'.
--
-- A fresh 'VName' is generated from @desc@; its pretty-printed form is
-- turned into the 'Name' used as the size key.  The statement itself is
-- bound with 'letSubExp', also using @desc@ as the description.
getSize ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  SizeClass ->
  m SubExp
getSize desc size_class = do
  -- The key is derived from a fresh name rather than from @desc@ directly.
  size_key <- nameFromString . prettyString <$> newVName desc
  letSubExp desc $ Op $ SizeOp $ GetSize size_key size_class

-- | Construct a 'SegThread' level with 'SegVirt' virtualisation and an
-- explicit 'KernelGrid'.
--
-- Both grid components are obtained through 'getSize': the number of
-- groups under the description @desc ++ "_num_groups"@ with class
-- 'SizeNumGroups', and the group size under @desc ++ "_group_size"@ with
-- class 'SizeGroup'.
segThread ::
  (MonadBuilder m, Op (Rep m) ~ HostOp inner (Rep m)) =>
  String ->
  m SegLevel
segThread desc =
  SegThread SegVirt <$> (Just <$> kernelGrid)
  where
    -- The number of groups is requested before the group size.
    kernelGrid =
      KernelGrid
        <$> (Count <$> getSize (desc ++ "_num_groups") SizeNumGroups)
        <*> (Count <$> getSize (desc ++ "_group_size") SizeGroup)

-- | Build a 'Rephraser' between two representations whose decorations,
-- return types and branch types coincide, and where the operations of
-- the source representation are exactly SOACs.
--
-- Every decoration is passed through unchanged with 'pure'.  An
-- operation is first mapped with 'SOAC.mapSOACM' (which rephrases the
-- lambdas it contains using this same rephraser), and the resulting
-- @SOAC to@ is then wrapped into an @Op to@ by the supplied function.
injectSOACS ::
  ( Monad m,
    SameScope from to,
    ExpDec from ~ ExpDec to,
    BodyDec from ~ BodyDec to,
    RetType from ~ RetType to,
    BranchType from ~ BranchType to,
    Op from ~ SOAC from
  ) =>
  (SOAC to -> Op to) ->
  Rephraser m from to
injectSOACS f =
  Rephraser
    { rephraseExpDec = pure,
      rephraseBodyDec = pure,
      rephraseLetBoundDec = pure,
      rephraseFParamDec = pure,
      rephraseLParamDec = pure,
      rephraseOp = fmap f . onSOAC,
      rephraseRetType = pure,
      rephraseBranchType = pure
    }
  where
    onSOAC = SOAC.mapSOACM mapper
    -- Sub-expressions and names are left alone; only lambdas need to be
    -- translated, which is done recursively with the same injection.
    mapper =
      SOAC.SOACMapper
        { SOAC.mapOnSOACSubExp = pure,
          SOAC.mapOnSOACVName = pure,
          SOAC.mapOnSOACLambda = rephraseLambda $ injectSOACS f
        }

-- | Rephrase a statement from the 'SOACS' representation to the 'GPU'
-- representation, wrapping every SOAC in 'OtherOp'.  The rephrasing is
-- run in the 'Identity' monad.
soacsStmToGPU :: Stm SOACS -> Stm GPU
soacsStmToGPU = runIdentity . rephraseStm (injectSOACS OtherOp)

-- | Rephrase a lambda from the 'SOACS' representation to the 'GPU'
-- representation, wrapping every SOAC in 'OtherOp'.  The rephrasing is
-- run in the 'Identity' monad.
soacsLambdaToGPU :: Lambda SOACS -> Lambda GPU
soacsLambdaToGPU = runIdentity . rephraseLambda (injectSOACS OtherOp)

-- | Convert a 'GPU' scope to a 'SOACS' scope via 'castScope'.
scopeForSOACs :: Scope GPU -> Scope SOACS
scopeForSOACs = castScope

-- | Convert a 'SOACS' scope to a 'GPU' scope via 'castScope'.
scopeForGPU :: Scope SOACS -> Scope GPU
scopeForGPU = castScope