{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE TypeFamilies #-}

-- | This pass attempts to lower allocations as far towards the bottom of their
-- body as possible.
module Futhark.Pass.LowerAllocations
  ( lowerAllocationsSeqMem,
    lowerAllocationsGPUMem,
    lowerAllocationsMCMem,
  )
where

import Control.Monad.Reader
import Data.Function ((&))
import Data.Map qualified as M
import Data.Sequence (Seq (..))
import Data.Sequence qualified as Seq
import Futhark.IR.GPUMem
import Futhark.IR.MCMem
import Futhark.IR.SeqMem
import Futhark.Pass (Pass (..))

-- | Lower allocations in a 'SeqMem' program.  Inner operations are
-- left untouched.
lowerAllocationsSeqMem :: Pass SeqMem SeqMem
lowerAllocationsSeqMem =
  Pass "lower allocations" "lower allocations" $
    pure . lowerInProg pure

-- | Lower allocations in a 'GPUMem' program, also descending into the
-- bodies of host-level segmented operations.
lowerAllocationsGPUMem :: Pass GPUMem GPUMem
lowerAllocationsGPUMem =
  Pass "lower allocations gpu" "lower allocations gpu" $
    pure . lowerInProg lowerAllocationsInHostOp

-- | Lower allocations in an 'MCMem' program, also descending into the
-- bodies of multicore parallel operations.
lowerAllocationsMCMem :: Pass MCMem MCMem
lowerAllocationsMCMem =
  Pass "lower allocations mc" "lower allocations mc" $
    pure . lowerInProg lowerAllocationsInMCOp

-- | Apply the transformation to the body of every function of a
-- program, using the given handler for representation-specific inner
-- operations.
lowerInProg ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  (inner rep -> LowerM (inner rep) (inner rep)) ->
  Prog rep ->
  Prog rep
lowerInProg on_inner prog@Prog {progFuns} =
  prog {progFuns = fmap lowerInFun progFuns}
  where
    lowerInFun f@FunDef {funDefBody} =
      f {funDefBody = runReader (lowerAllocationsInBody funDefBody) (Env on_inner)}

-- | The environment of the pass: how to transform the
-- representation-specific inner operation.
newtype Env inner = Env {onInner :: inner -> LowerM inner inner}

-- | The monad in which the transformation runs.
type LowerM inner a = Reader (Env inner) a

-- | Lower the allocations among the statements of a body.  The body
-- result is kept as is.
lowerAllocationsInBody ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Body rep ->
  LowerM (inner rep) (Body rep)
lowerAllocationsInBody body = do
  stms <- lowerAllocationsInStms (bodyStms body) mempty mempty
  pure $ body {bodyStms = stms}

-- | Walk a statement sequence from the top.  Single-result allocation
-- statements are withheld in a map; each is emitted just before the
-- first later statement that mentions its name among its free
-- variables.  Allocations never mentioned are emitted at the very end
-- of the sequence.
lowerAllocationsInStms ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  -- | The input stms
  Stms rep ->
  -- | The allocations currently being lowered
  M.Map VName (Stm rep) ->
  -- | The other statements processed so far
  Stms rep ->
  LowerM (inner rep) (Stms rep)
lowerAllocationsInStms Empty allocs acc =
  pure $ acc <> Seq.fromList (M.elems allocs)
lowerAllocationsInStms (stm@(Let (Pat [PatElem vname _]) _ (Op (Alloc _ _))) :<| stms) allocs acc =
  lowerAllocationsInStms stms (M.insert vname stm allocs) acc
lowerAllocationsInStms (stm :<| stms) allocs acc = do
  e' <- lowerAllocationsInExp (stmExp stm)
  -- The free variables are taken from the statement as it was before
  -- its nested bodies were rewritten.
  let (allocs', acc') = insertLoweredAllocs (freeIn stm) allocs acc
  lowerAllocationsInStms stms allocs' (acc' :|> stm {stmExp = e'})

-- | Lower allocations inside the nested bodies of an expression.
-- Inner operations are delegated to the handler in the environment;
-- the branches of a 'Match' and the body of a 'DoLoop' are processed
-- recursively; every other expression is returned unchanged.
lowerAllocationsInExp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  Exp rep ->
  LowerM (inner rep) (Exp rep)
lowerAllocationsInExp (Op (Inner inner)) = do
  on_inner <- asks onInner
  Op . Inner <$> on_inner inner
lowerAllocationsInExp (Match cond_ses cases body dec) = do
  cases' <- mapM (\(Case pat b) -> Case pat <$> lowerAllocationsInBody b) cases
  body' <- lowerAllocationsInBody body
  pure $ Match cond_ses cases' body' dec
lowerAllocationsInExp (DoLoop params form body) =
  DoLoop params form <$> lowerAllocationsInBody body
lowerAllocationsInExp e = pure e

-- | Given the free variables of the statement about to be emitted,
-- move every withheld allocation whose name is among them to the end
-- of the accumulated statements.  Returns the allocations still being
-- withheld together with the extended accumulator.
insertLoweredAllocs ::
  Names ->
  M.Map VName (Stm rep) ->
  Stms rep ->
  (M.Map VName (Stm rep), Stms rep)
insertLoweredAllocs frees alloc acc =
  (remaining, acc <> Seq.fromList (M.elems needed))
  where
    -- A single total split of the map; no partial lookups are needed.
    (needed, remaining) =
      alloc & M.partitionWithKey (\name _ -> name `nameIn` frees)

-- | Lower the allocations among the statements of a kernel body,
-- keeping its other fields.
lowerAllocationsInKernelBody ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  KernelBody rep ->
  LowerM (inner rep) (KernelBody rep)
lowerAllocationsInKernelBody kbody = do
  stms <- lowerAllocationsInStms (kernelBodyStms kbody) mempty mempty
  pure $ kbody {kernelBodyStms = stms}

-- | Lower allocations in the kernel body of a segmented operation.
-- Only the kernel body statements are rewritten; all other components
-- of the operation are passed through as they are.
lowerAllocationsInSegOp ::
  (Mem rep inner, LetDec rep ~ LetDecMem) =>
  SegOp lvl rep ->
  LowerM (inner rep) (SegOp lvl rep)
lowerAllocationsInSegOp (SegMap lvl sp tps body) =
  SegMap lvl sp tps <$> lowerAllocationsInKernelBody body
lowerAllocationsInSegOp (SegRed lvl sp binops tps body) =
  SegRed lvl sp binops tps <$> lowerAllocationsInKernelBody body
lowerAllocationsInSegOp (SegScan lvl sp binops tps body) =
  SegScan lvl sp binops tps <$> lowerAllocationsInKernelBody body
lowerAllocationsInSegOp (SegHist lvl sp histops tps body) =
  SegHist lvl sp histops tps <$> lowerAllocationsInKernelBody body

-- | Handler for GPU host operations: segmented operations are
-- processed, every other host operation is returned unchanged.
lowerAllocationsInHostOp :: HostOp NoOp GPUMem -> LowerM (HostOp NoOp GPUMem) (HostOp NoOp GPUMem)
lowerAllocationsInHostOp (SegOp op) = SegOp <$> lowerAllocationsInSegOp op
lowerAllocationsInHostOp op = pure op

-- | Handler for multicore operations: both segmented operations of a
-- 'ParOp' (the first one only if present) are processed, every other
-- operation is returned unchanged.
lowerAllocationsInMCOp :: MCOp NoOp MCMem -> LowerM (MCOp NoOp MCMem) (MCOp NoOp MCMem)
lowerAllocationsInMCOp (ParOp par op) =
  ParOp <$> traverse lowerAllocationsInSegOp par <*> lowerAllocationsInSegOp op
lowerAllocationsInMCOp op = pure op