{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Pass.ExtractKernels.ToKernels
       ( getSize
       , segThread

       , soacsLambdaToKernels
       , soacsStmToKernels
       , scopeForKernels
       , scopeForSOACs
       )
       where

import Control.Monad.Identity
import Data.List ()

import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.SOACS.SOAC as SOAC
import Futhark.IR.Kernels
import Futhark.Tools

-- | Bind a host-level size query and return the resulting 'SubExp'.
--
-- A fresh 'VName' is generated from @desc@ and pretty-printed to form the
-- key of the size; the @GetSize key size_class@ operation is then wrapped
-- as a 'SizeOp' inside 'Op' and let-bound under the name hint @desc@.
getSize :: (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
           String -> SizeClass -> m SubExp
getSize desc size_class = do
  -- The key is derived from a fresh name, not from 'desc' directly.
  size_key <- nameFromString . pretty <$> newVName desc
  letSubExp desc $ Op $ SizeOp $ GetSize size_key size_class

-- | Construct a 'SegThread' level whose group count and group size are
-- obtained through 'getSize'.
--
-- The number of groups is queried under @desc ++ "_num_groups"@ with
-- class 'SizeNumGroups', and the group size under
-- @desc ++ "_group_size"@ with class 'SizeGroup'.  Both results are
-- wrapped in 'Count', and the level is built with 'SegVirt'.
segThread :: (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
             String -> m SegLevel
segThread desc =
  SegThread
    <$> (Count <$> getSize (desc ++ "_num_groups") SizeNumGroups)
    <*> (Count <$> getSize (desc ++ "_group_size") SizeGroup)
    <*> pure SegVirt

-- | Build a 'Rephraser' that leaves every decoration (expression, body,
-- let-bound, function/lambda parameter, return type and branch type)
-- unchanged and only rewrites operations.
--
-- An operation of the source lore is a 'SOAC'; it is first mapped into
-- the target lore with 'SOAC.mapSOACM' and then injected into the target
-- lore's operation type with @f@.
injectSOACS :: (Monad m,
                SameScope from to,
                ExpDec from ~ ExpDec to,
                BodyDec from ~ BodyDec to,
                RetType from ~ RetType to,
                BranchType from ~ BranchType to,
                Op from ~ SOAC from) =>
               (SOAC to -> Op to) -> Rephraser m from to
injectSOACS f = Rephraser { rephraseExpLore = return
                          , rephraseBodyLore = return
                          , rephraseLetBoundLore = return
                          , rephraseFParamLore = return
                          , rephraseLParamLore = return
                          , rephraseOp = fmap f . onSOAC
                          , rephraseRetType = return
                          , rephraseBranchType = return
                          }
  where -- Translate a single SOAC from the source lore to the target lore.
        onSOAC = SOAC.mapSOACM mapper
        -- Sub-expressions and names pass through untouched; lambdas are
        -- rephrased recursively with the same injection, so SOACs nested
        -- inside them are converted as well.
        mapper = SOAC.SOACMapper { SOAC.mapOnSOACSubExp = return
                                 , SOAC.mapOnSOACVName = return
                                 , SOAC.mapOnSOACLambda = rephraseLambda $ injectSOACS f
                                 }

-- | Rephrase a statement from the 'SOACS' lore to the 'Kernels' lore,
-- wrapping every SOAC in 'OtherOp'.  The rephrasing runs in 'Identity'.
soacsStmToKernels :: Stm SOACS -> Stm Kernels
soacsStmToKernels = runIdentity . rephraseStm (injectSOACS OtherOp)

-- | Rephrase a lambda from the 'SOACS' lore to the 'Kernels' lore,
-- wrapping every SOAC in 'OtherOp'.  The rephrasing runs in 'Identity'.
soacsLambdaToKernels :: Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels = runIdentity . rephraseLambda (injectSOACS OtherOp)

-- | Reinterpret a 'Kernels' scope as a 'SOACS' scope using 'castScope'.
scopeForSOACs :: Scope Kernels -> Scope SOACS
scopeForSOACs = castScope

-- | Reinterpret a 'SOACS' scope as a 'Kernels' scope using 'castScope'.
scopeForKernels :: Scope SOACS -> Scope Kernels
scopeForKernels = castScope