{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.Representation.Kernels.Simplify
( simplifyKernels
, simplifyLambda
, simplifyKernelOp
)
where
import Control.Monad
import Data.Foldable
import Data.List
import Data.Maybe
import qualified Data.Map.Strict as M
import Futhark.Representation.Kernels
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Rules
import Futhark.Optimise.Simplify.Lore
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass
import Futhark.Representation.SOACS.Simplify (simplifySOAC)
import qualified Futhark.Optimise.Simplify as Simplify
import Futhark.Optimise.Simplify.Rule
import qualified Futhark.Analysis.SymbolTable as ST
import qualified Futhark.Analysis.UsageTable as UT
import Futhark.Util (chunks)
import qualified Futhark.Transform.FirstOrderTransform as FOT
-- | The simplifier operations used for the 'Kernels' representation:
-- the bindable defaults, with operations simplified by
-- 'simplifyKernelOp' (which hands any 'OtherOp' payload to
-- 'simplifySOAC').
simpleKernels :: Simplify.SimpleOps Kernels
simpleKernels =
  Simplify.bindableSimpleOps (simplifyKernelOp simplifySOAC)
-- | Simplify an entire 'Kernels' program, using 'simpleKernels' and
-- 'kernelRules' and without any extra hoisting blockers.
simplifyKernels :: Prog Kernels -> PassM (Prog Kernels)
simplifyKernels prog =
  Simplify.simplifyProg
    simpleKernels
    kernelRules
    Simplify.noExtraHoistBlockers
    prog
-- | Simplify a single 'Kernels' lambda, using 'simpleKernels' and
-- 'kernelRules' and without any extra hoisting blockers.
--
-- The list argument is forwarded unchanged to the generic lambda
-- simplifier (presumably one entry per lambda parameter, naming the
-- array it is drawn from when known -- TODO confirm against
-- "Futhark.Optimise.Simplify").
simplifyLambda :: (HasScope Kernels m, MonadFreshNames m) =>
                  Lambda Kernels -> [Maybe VName] -> m (Lambda Kernels)
simplifyLambda lam args =
  Simplify.simplifyLambda
    simpleKernels
    kernelRules
    Engine.noExtraHoistBlockers
    lam
    args
-- | Simplify one 'HostOp'.
--
-- The first argument simplifies the payload of an 'OtherOp'; every
-- other constructor is handled directly.  The result pairs the
-- simplified operation with the statements returned as hoisted by the
-- nested simplifications (always 'mempty' for the operations that have
-- no nested lambda or kernel body).
simplifyKernelOp :: (Engine.SimplifiableLore lore,
                     BodyAttr lore ~ ()) =>
                    Simplify.SimplifyOp lore op
                 -> HostOp lore op
                 -> Engine.SimpleM lore (HostOp (Wise lore) (OpWithWisdom op), Stms (Wise lore))

-- Delegate to the supplied simplifier and re-wrap the result.
simplifyKernelOp simplifyOther (OtherOp other) = do
  (other', other_hoisted) <- simplifyOther other
  return (OtherOp other', other_hoisted)

simplifyKernelOp _ (SegOp (SegMap lvl space ts kbody)) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (kbody', kbody_hoisted) <- simplifyKernelBody space kbody
  return ( SegOp (SegMap lvl' space' ts' kbody')
         , kbody_hoisted )

simplifyKernelOp _ (SegOp (SegRed lvl space reds ts kbody)) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (reds', reds_hoisted) <- unzip <$> mapM simplifyRedOp reds
  (kbody', kbody_hoisted) <- simplifyKernelBody space kbody
  return ( SegOp (SegRed lvl' space' reds' ts' kbody')
         , mconcat reds_hoisted <> kbody_hoisted )
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)

    -- The operator lambda is simplified with the names bound by the
    -- segment space added to the symbol table; 'Nothing' is passed for
    -- each of the @2 * length nes@ entries.
    simplifyRedOp (SegRedOp comm lam nes shape) = do
      (lam', lam_hoisted) <-
        Engine.localVtable (<> space_vtable) $
        Engine.simplifyLambda lam (replicate (2 * length nes) Nothing)
      shape' <- Engine.simplify shape
      nes' <- mapM Engine.simplify nes
      return (SegRedOp comm lam' nes' shape', lam_hoisted)

-- The level is simplified here; everything else is shared with
-- 'simplifyRedOrScan'.
simplifyKernelOp _ (SegOp (SegScan lvl space scan_op nes ts kbody)) = do
  lvl' <- Engine.simplify lvl
  (space', scan_op', nes', ts', kbody', scan_hoisted) <-
    simplifyRedOrScan space scan_op nes ts kbody
  return ( SegOp (SegScan lvl' space' scan_op' nes' ts' kbody')
         , scan_hoisted )

simplifyKernelOp _ (SegOp (SegGenRed lvl space ops ts kbody)) = do
  (lvl', space', ts') <- Engine.simplify (lvl, space, ts)
  (ops', ops_hoisted) <- unzip <$> mapM simplifyGenRedOp ops
  (kbody', kbody_hoisted) <- simplifyKernelBody space kbody
  return ( SegOp (SegGenRed lvl' space' ops' ts' kbody')
         , mconcat ops_hoisted <> kbody_hoisted )
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)

    -- Operands first, then the operator lambda (same treatment as for
    -- the 'SegRed' operators).
    simplifyGenRedOp (GenReduceOp w arrs nes dims lam) = do
      w' <- Engine.simplify w
      arrs' <- Engine.simplify arrs
      nes' <- Engine.simplify nes
      dims' <- Engine.simplify dims
      (lam', lam_hoisted) <-
        Engine.localVtable (<> space_vtable) $
        Engine.simplifyLambda lam (replicate (2 * length nes) Nothing)
      return (GenReduceOp w' arrs' nes' dims' lam', lam_hoisted)

simplifyKernelOp _ (SplitSpace o w i elems_per_thread) = do
  o' <- Engine.simplify o
  w' <- Engine.simplify w
  i' <- Engine.simplify i
  elems_per_thread' <- Engine.simplify elems_per_thread
  return (SplitSpace o' w' i' elems_per_thread', mempty)

-- These two carry nothing that is passed to the simplifier.
simplifyKernelOp _ (GetSize key size_class) =
  pure (GetSize key size_class, mempty)

simplifyKernelOp _ (GetSizeMax size_class) =
  pure (GetSizeMax size_class, mempty)

simplifyKernelOp _ (CmpSizeLe key size_class x) =
  withOperand <$> Engine.simplify x
  where withOperand x' = (CmpSizeLe key size_class x', mempty)
-- | Simplify the components of a segmented operation that consists of
-- a segment space, one operator lambda with its neutral elements, the
-- result types and a kernel body.
--
-- The operator lambda is simplified with the names bound by the segment
-- space added to the symbol table.  The last component of the result is
-- the statements hoisted out of the lambda followed by those hoisted
-- out of the kernel body.
simplifyRedOrScan :: (Engine.SimplifiableLore lore, BodyAttr lore ~ ()) =>
                     SegSpace
                  -> Lambda lore -> [SubExp] -> [Type]
                  -> KernelBody lore
                  -> Simplify.SimpleM lore
                  (SegSpace, Lambda (Wise lore), [SubExp], [Type], KernelBody (Wise lore),
                   Stms (Wise lore))
simplifyRedOrScan space scan_op nes ts kbody = do
  space' <- Engine.simplify space
  nes' <- forM nes Engine.simplify
  ts' <- forM ts Engine.simplify

  (scan_op', op_hoisted) <-
    Engine.localVtable (<> space_vtable) $
    Engine.simplifyLambda scan_op (replicate (2 * length nes) Nothing)

  -- Note that the body is simplified relative to the original space.
  (kbody', kbody_hoisted) <- simplifyKernelBody space kbody

  return ( space', scan_op', nes', ts', kbody'
         , op_hoisted <> kbody_hoisted )
  where
    space_vtable = ST.fromScope (scopeOfSegSpace space)
-- | Simplify the body of a kernel running over the given segment space.
--
-- The statements are simplified with the names bound by the space added
-- to the symbol table and with 'ST.simplifyMemory' switched on.  A
-- statement is blocked from being hoisted out of the body when any of
-- the combined predicates holds: 'Engine.hasFree' on the names bound by
-- the space, 'Engine.isOp', the engine's 'Engine.blockHoistPar'
-- blocker, or 'Engine.isConsumed'.
--
-- The incoming body attribute is discarded; the result body is rebuilt
-- with 'mkWiseKernelBody'.  The second component is the hoisted
-- statements.
simplifyKernelBody :: (Engine.SimplifiableLore lore, BodyAttr lore ~ ()) =>
                      SegSpace -> KernelBody lore
                   -> Engine.SimpleM lore (KernelBody (Wise lore), Stms (Wise lore))
simplifyKernelBody space (KernelBody _ stms res) = do
  par_blocker <-
    Engine.asksEngineEnv (Engine.blockHoistPar . Engine.envHoistBlockers)

  let must_stay = Engine.hasFree space_names
                  `Engine.orIf` Engine.isOp
                  `Engine.orIf` par_blocker
                  `Engine.orIf` Engine.isConsumed

  ((stms', res'), hoisted) <-
    Engine.localVtable (<> space_vtable) $
    Engine.localVtable withMemorySimplification $
    Engine.blockIf must_stay $
    Engine.simplifyStms stms $ do
      -- The results are simplified with the names bound by the body's
      -- own statements passed to 'ST.hideCertified'.
      simplified_res <-
        Engine.localVtable (ST.hideCertified $ namesFromList $ M.keys $ scopeOf stms) $
        mapM Engine.simplify res
      return ((simplified_res, UT.usages $ freeIn simplified_res), mempty)

  return (mkWiseKernelBody () stms' res', hoisted)
  where
    space_vtable = ST.fromScope $ scopeOfSegSpace space
    space_names = namesFromList $ M.keys $ scopeOfSegSpace space
    withMemorySimplification vtable = vtable { ST.simplifyMemory = True }
-- | Build a 'KernelBody' in the 'Wise' lore from a body attribute,
-- statements and kernel results.
--
-- The attribute of the new body is taken from the ordinary body that
-- 'mkWiseBody' constructs for the same statements, using the
-- 'kernelResultSubExp' of each kernel result as its result.
mkWiseKernelBody :: (Attributes lore, CanBeWise (Op lore)) =>
                    BodyAttr lore -> Stms (Wise lore) -> [KernelResult] -> KernelBody (Wise lore)
mkWiseKernelBody attr stms res = KernelBody wise_attr stms res
  where
    Body wise_attr _ _ =
      mkWiseBody attr stms (map kernelResultSubExp res)
-- | 'SplitContiguous' has nothing to simplify; for 'SplitStrided' the
-- stride is simplified.
instance Engine.Simplifiable SplitOrdering where
  simplify SplitContiguous = pure SplitContiguous
  simplify (SplitStrided stride) =
    fmap SplitStrided (Engine.simplify stride)
-- | For every level the @num_groups@ and @group_size@ fields are
-- simplified (through 'traverse'), while the third field is carried
-- over untouched.
instance Engine.Simplifiable SegLevel where
  simplify (SegThread num_groups group_size virt) = do
    num_groups' <- traverse Engine.simplify num_groups
    group_size' <- traverse Engine.simplify group_size
    return $ SegThread num_groups' group_size' virt

  simplify (SegGroup num_groups group_size virt) = do
    num_groups' <- traverse Engine.simplify num_groups
    group_size' <- traverse Engine.simplify group_size
    return $ SegGroup num_groups' group_size' virt

  simplify (SegThreadScalar num_groups group_size virt) = do
    num_groups' <- traverse Engine.simplify num_groups
    group_size' <- traverse Engine.simplify group_size
    return $ SegThreadScalar num_groups' group_size' virt
-- | The first field is kept as it is; every entry of the dimension
-- list is simplified through 'traverse'.
instance Engine.Simplifiable SegSpace where
  simplify (SegSpace phys dims) = do
    dims' <- forM dims $ \dim -> traverse Engine.simplify dim
    return $ SegSpace phys dims'
-- | Every field of every kind of kernel result is simplified, in field
-- order.
instance Engine.Simplifiable KernelResult where
  simplify (Returns what) = do
    what' <- Engine.simplify what
    return $ Returns what'

  simplify (WriteReturns ws a res) = do
    ws' <- Engine.simplify ws
    a' <- Engine.simplify a
    res' <- Engine.simplify res
    return $ WriteReturns ws' a' res'

  simplify (ConcatReturns o w pte what) = do
    o' <- Engine.simplify o
    w' <- Engine.simplify w
    pte' <- Engine.simplify pte
    what' <- Engine.simplify what
    return $ ConcatReturns o' w' pte' what'

  simplify (TileReturns dims what) = do
    dims' <- Engine.simplify dims
    what' <- Engine.simplify what
    return $ TileReturns dims' what'
-- | 'BinderOps' for the 'Wise' variant of 'Kernels'.  Each method is
-- defined as the correspondingly named @bindable*@ implementation.
instance BinderOps (Wise Kernels) where
  mkExpAttrB = bindableMkExpAttrB
  mkBodyB = bindableMkBodyB
  mkLetNamesB = bindableMkLetNamesB
-- | The rule book used when simplifying 'Kernels' programs: the
-- standard rules extended with the kernel-specific top-down and
-- bottom-up rules of this module (plus 'removeUnnecessaryCopy' among
-- the bottom-up ones).
kernelRules :: RuleBook (Wise Kernels)
kernelRules = standardRules <> ruleBook top_down bottom_up
  where
    top_down =
      [ RuleOp removeInvariantKernelResults
      , RuleOp mergeSegRedOps
      , RuleOp redomapIotaToLoop
      ]
    bottom_up =
      [ RuleOp distributeKernelResults
      , RuleBasicOp removeUnnecessaryCopy
      ]
-- | Top-down rule for 'SegMap': a kernel result of the form @Returns
-- se@, where @se@ is a constant or a variable found in the symbol
-- table, is removed from the kernel.  Its pattern element is instead
-- bound to a 'Replicate' of @se@ over the dimensions of the segment
-- space.
--
-- The rule does not fire at level 'SegThreadScalar', nor when no result
-- is removed.
removeInvariantKernelResults :: TopDownRuleOp (Wise Kernels)
removeInvariantKernelResults vtable (Pattern [] kpes) attr
  (SegOp (SegMap lvl space ts (KernelBody _ kstms kres))) = Simplify $ do
  when (isThreadScalar lvl) cannotSimplify

  -- Filtering also emits the replacement binding for every result that
  -- is dropped.
  (ts', kpes', kres') <-
    unzip3 <$> filterM keepResult (zip3 ts kpes kres)

  -- Nothing was dropped, so there is nothing to rewrite.
  when (kres == kres') cannotSimplify

  addStm $ Let (Pattern [] kpes') attr $ Op $ SegOp $
    SegMap lvl space ts' (mkWiseKernelBody () kstms kres')
  where
    isThreadScalar SegThreadScalar{} = True
    isThreadScalar _ = False

    isInvariant se =
      case se of
        Constant{} -> True
        Var v -> isJust (ST.lookup v vtable)

    -- 'False' means "drop this result from the kernel".
    keepResult (_, pe, res) =
      case res of
        Returns se
          | isInvariant se -> do
              letBindNames_ [patElemName pe] $
                BasicOp $ Replicate (Shape $ segSpaceDims space) se
              return False
        _ -> return True
removeInvariantKernelResults _ _ _ _ = Skip
-- | Bottom-up rule for 'SegMap': move certain index statements whose
-- value is returned by the kernel out of the kernel.
--
-- A kernel statement qualifies when it binds a single name to @Index
-- arr slice@, the slice starts by fixing each dimension to the
-- corresponding index variable of the segment space, every name free in
-- @arr@ and in the rest of the slice is found in the symbol table
-- (presumably meaning it is available outside the kernel -- TODO
-- confirm), and exactly one kernel result is @Returns@ of the bound
-- name.  For such a statement the kernel result is removed and its
-- pattern element is instead bound, outside the kernel, to a slice of
-- @arr@ covering every dimension of the space in full, followed by the
-- rest of the original slice.
--
-- The rule does not fire when no pattern element was removed.
distributeKernelResults :: BottomUpRuleOp (Wise Kernels)
distributeKernelResults (vtable, used)
  (Pattern [] kpes) attr (SegOp (SegMap lvl space kts (KernelBody _ kstms kres))) = Simplify $ do
  -- Walk the kernel statements in order with the names of the segment
  -- space in scope.  The accumulator holds the pattern elements, types
  -- and results still belonging to the kernel, and the statements that
  -- stay in the kernel (in reverse order).
  (kpes', kts', kres', kstms_rev) <- localScope (scopeOfSegSpace space) $
    foldM distribute (kpes, kts, kres, []) kstms
  -- No pattern element was taken out, so nothing changed.
  when (kpes' == kpes)
    cannotSimplify
  addStm $ Let (Pattern [] kpes') attr $ Op $ SegOp $
    SegMap lvl space kts' $ mkWiseKernelBody () (stmsFromList $ reverse kstms_rev) kres'
  where
    -- All names free in any of the original kernel statements.
    free_in_kstms = fold $ fmap freeIn kstms
    distribute (kpes', kts', kres', kstms_rev) bnd
      | Let (Pattern [] [pe]) _ (BasicOp (Index arr slice)) <- bnd,
        -- The slice that fixes every dimension to its space index.
        space_slice <- map (DimFix . Var . fst) $ unSegSpace space,
        space_slice `isPrefixOf` slice,
        remaining_slice <- drop (length space_slice) slice,
        all (isJust . flip ST.lookup vtable) $ namesToList $
          freeIn arr <> freeIn remaining_slice,
        Just (kpe, kpes'', kts'', kres'') <- isResult kpes' kts' kres' pe = do
          -- One full-range slice (offset 0, stride 1) per dimension of
          -- the segment space.
          let outer_slice = map (\d -> DimSlice
                                       (constant (0::Int32))
                                       d
                                       (constant (1::Int32))) $
                            segSpaceDims space
              index kpe' = letBind_ (Pattern [] [kpe']) $ BasicOp $ Index arr $
                           outer_slice <> remaining_slice
          -- If the usage table says the kernel output is consumed, bind
          -- the slice to a fresh "_precopy" name and make the actual
          -- output a copy of it.
          if patElemName kpe `UT.isConsumed` used
            then do precopy <- newVName $ baseString (patElemName kpe) <> "_precopy"
                    index kpe { patElemName = precopy }
                    letBind_ (Pattern [] [kpe]) $ BasicOp $ Copy precopy
            else index kpe
          -- The statement itself stays in the kernel only if its name
          -- occurs free in the kernel statements.
          return (kpes'', kts'', kres'',
                  if patElemName pe `nameIn` free_in_kstms
                  then bnd : kstms_rev
                  else kstms_rev)
    -- Any other statement is kept in the kernel unchanged.
    distribute (kpes', kts', kres', kstms_rev) bnd =
      return (kpes', kts', kres', bnd : kstms_rev)
    -- Succeeds only when exactly one kernel result returns the name
    -- bound by @pe@; yields its pattern element together with the
    -- remaining pattern elements, types and results.
    isResult kpes' kts' kres' pe =
      case partition matches $ zip3 kpes' kts' kres' of
        ([(kpe,_,_)], kpes_and_kres)
          | (kpes'', kts'', kres'') <- unzip3 kpes_and_kres ->
              Just (kpe, kpes'', kts'', kres'')
        _ -> Nothing
      where matches (_, _, kre) = kre == Returns (Var $ patElemName pe)
distributeKernelResults _ _ _ _ = Skip
-- | Top-down rule for 'SegRed': merge adjacent reduction operators that
-- have the same 'segRedShape' into a single operator.
--
-- The rule fires only when there is more than one operator and at least
-- one run of adjacent same-shape operators has more than one member.
--
-- NOTE(review): the auxiliary information of the original statement
-- (third argument) is ignored and the new statement is created with
-- 'letBind_'; confirm that nothing attached to the statement needs to
-- be carried over here.
mergeSegRedOps :: TopDownRuleOp (Wise Kernels)
mergeSegRedOps _ (Pattern [] pes) _ (SegOp (SegRed lvl space ops ts kbody))
  | length ops > 1,
    -- Pair every operator with its own (pattern element, type, result)
    -- triples -- as many as it has neutral elements -- and group
    -- adjacent operators of equal shape.
    op_groupings <- groupBy sameShape $ zip ops $ chunks (map (length . segRedNeutral) ops) $
                    zip3 red_pes red_ts red_res,
    any ((>1) . length) op_groupings = Simplify $ do
      let (ops', aux) = unzip $ mapMaybe combineOps op_groupings
          (red_pes', red_ts', red_res') = unzip3 $ concat aux
          -- The parts after the first @segRedResults ops@ entries are
          -- appended unchanged.
          pes' = red_pes' ++ map_pes
          ts' = red_ts' ++ map_ts
          kbody' = kbody { kernelBodyResult = red_res' ++ map_res }
      letBind_ (Pattern [] pes') $ Op $ SegOp $ SegRed lvl space ops' ts' kbody'
  -- Split pattern elements, types and body results at @segRedResults
  -- ops@ (presumably the total number of reduction results -- TODO
  -- confirm).
  where (red_pes, map_pes) = splitAt (segRedResults ops) pes
        (red_ts, map_ts) = splitAt (segRedResults ops) ts
        (red_res, map_res) = splitAt (segRedResults ops) $ kernelBodyResult kbody
        sameShape (op1, _) (op2, _) = segRedShape op1 == segRedShape op2
        -- Fold a group of operators into one, left to right.
        combineOps :: [(SegRedOp (Wise Kernels), [a])]
                   -> Maybe (SegRedOp (Wise Kernels), [a])
        combineOps [] = Nothing
        combineOps (x:xs) = Just $ foldl' combine x xs
        -- Merge two operators.  The parameters of each lambda are split
        -- at its number of neutral elements; the merged lambda takes
        -- both first parts followed by both second parts, and its body
        -- is the statements and results of the first lambda followed
        -- by those of the second.  The shape is taken from the first
        -- operator.
        combine (op1, op1_aux) (op2, op2_aux) =
          let lam1 = segRedLambda op1
              lam2 = segRedLambda op2
              (op1_xparams, op1_yparams) =
                splitAt (length (segRedNeutral op1)) $ lambdaParams lam1
              (op2_xparams, op2_yparams) =
                splitAt (length (segRedNeutral op2)) $ lambdaParams lam2
              lam = Lambda { lambdaParams = op1_xparams ++ op2_xparams ++
                                            op1_yparams ++ op2_yparams
                           , lambdaReturnType = lambdaReturnType lam1 ++ lambdaReturnType lam2
                           , lambdaBody =
                               mkBody (bodyStms (lambdaBody lam1) <> bodyStms (lambdaBody lam2)) $
                               bodyResult (lambdaBody lam1) <> bodyResult (lambdaBody lam2)
                           }
          in (SegRedOp { segRedComm = segRedComm op1 <> segRedComm op2
                       , segRedLambda = lam
                       , segRedNeutral = segRedNeutral op1 ++ segRedNeutral op2
                       , segRedShape = segRedShape op1
                       },
              op1_aux ++ op2_aux)
mergeSegRedOps _ _ _ _ = Skip
-- | Top-down rule for a 'Screma' with a single input array: when
-- 'isRedomapSOAC' recognises its form and the symbol table shows the
-- input to be produced by an 'Iota', the SOAC is rewritten with
-- 'FOT.transformSOAC', under the certificates of the original
-- statement.
redomapIotaToLoop :: TopDownRuleOp (Wise Kernels)
redomapIotaToLoop vtable pat aux op =
  case op of
    OtherOp soac@(Screma _ form [arr])
      | isJust (isRedomapSOAC form),
        Just (Iota{}, _) <- ST.lookupBasicOp arr vtable ->
          Simplify $
            certifying (stmAuxCerts aux) $
              FOT.transformSOAC pat soac
    _ -> Skip