{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.ToGPU
  ( getSize,
    segThread,
    soacsLambdaToGPU,
    soacsStmToGPU,
    scopeForGPU,
    scopeForSOACs,
    injectSOACS,
  )
where

import Control.Monad.Identity
import Data.List ()
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.GPU
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.SOACS.SOAC as SOAC
import Futhark.Tools

-- | Bind the result of a 'GetSize' size operation to a new variable
-- (named after @desc@) and return that variable as a 'SubExp'.
--
-- The size key passed to 'GetSize' is the pretty-printed form of a
-- freshly generated 'VName' based on @desc@, so each call produces its
-- own key.
getSize ::
  (MonadBuilder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
  -- | Base name used for both the fresh key and the bound variable.
  String ->
  -- | Which kind of size is being requested.
  SizeClass ->
  m SubExp
getSize desc size_class = do
  -- A fresh name guarantees the key does not collide with other size keys.
  size_key <- nameFromString . pretty <$> newVName desc
  letSubExp desc $ Op $ SizeOp $ GetSize size_key size_class

-- | Construct a 'SegThread' level with 'SegVirt' virtualisation.
--
-- The number of groups and the group size are each obtained through
-- 'getSize', using keys derived from @desc@ with the suffixes
-- @"_num_groups"@ and @"_group_size"@ respectively.
segThread ::
  (MonadBuilder m, Op (Rep m) ~ HostOp (Rep m) inner) =>
  -- | Prefix for the names of the generated size variables.
  String ->
  m SegLevel
segThread desc =
  SegThread
    <$> (Count <$> getSize (desc ++ "_num_groups") SizeNumGroups)
    <*> (Count <$> getSize (desc ++ "_group_size") SizeGroup)
    <*> pure SegVirt

-- | A 'Rephraser' from one representation to another that shares its
-- decorations.
--
-- All decorations are passed through unchanged; the constraints require
-- the two representations to agree on them. Every SOAC operation is
-- rephrased into the target representation and then wrapped with @f@.
-- Lambdas nested inside SOACs are rephrased recursively with the same
-- injection.
injectSOACS ::
  ( Monad m,
    SameScope from to,
    ExpDec from ~ ExpDec to,
    BodyDec from ~ BodyDec to,
    RetType from ~ RetType to,
    BranchType from ~ BranchType to,
    Op from ~ SOAC from
  ) =>
  -- | How to embed a target-representation SOAC as an 'Op' of @to@.
  (SOAC to -> Op to) ->
  Rephraser m from to
injectSOACS f =
  Rephraser
    { rephraseExpDec = pure,
      rephraseBodyDec = pure,
      rephraseLetBoundDec = pure,
      rephraseFParamDec = pure,
      rephraseLParamDec = pure,
      rephraseOp = fmap f . onSOAC,
      rephraseRetType = pure,
      rephraseBranchType = pure
    }
  where
    -- These bindings are left unannotated: their types mention the
    -- outer type variables, which cannot be named here without
    -- ScopedTypeVariables.
    onSOAC = SOAC.mapSOACM mapper
    mapper =
      SOAC.SOACMapper
        { SOAC.mapOnSOACSubExp = pure,
          SOAC.mapOnSOACVName = pure,
          -- Nested lambdas may themselves contain SOACs, so recurse.
          SOAC.mapOnSOACLambda = rephraseLambda $ injectSOACS f
        }

-- | Convert a 'SOACS' statement to a 'GPU' statement.
--
-- Each SOAC becomes an 'OtherOp' of 'HostOp'; everything else is carried
-- over unchanged.
soacsStmToGPU :: Stm SOACS -> Stm GPU
soacsStmToGPU = runIdentity . rephraseStm (injectSOACS OtherOp)

-- | Convert a 'SOACS' lambda to a 'GPU' lambda.
--
-- Each SOAC becomes an 'OtherOp' of 'HostOp'; everything else is carried
-- over unchanged.
soacsLambdaToGPU :: Lambda SOACS -> Lambda GPU
soacsLambdaToGPU = runIdentity . rephraseLambda (injectSOACS OtherOp)

-- | Reinterpret a 'GPU' scope as a 'SOACS' scope via 'castScope'.
scopeForSOACs :: Scope GPU -> Scope SOACS
scopeForSOACs = castScope

-- | Reinterpret a 'SOACS' scope as a 'GPU' scope via 'castScope'.
scopeForGPU :: Scope SOACS -> Scope GPU
scopeForGPU = castScope