-- | Convert a loop expressed with the loopn# and guard# combinators into
--   a tail recursive loop with accumulators.
--
--   ASSUMPTIONS:
--
--   * No nested loops.
--     We could support these, but we don't yet.
--
--   * Outer control flow is only defined via the loopn# and guard#
--     combinators.
--
--   * References don't escape,
--     so they're not stored in data structures or captured in closures.
--
--   * No aliasing of references,
--     so updating ref with a particular name does not affect any other ref.
--
--   * Refs holding loop counters for loopn# and entry counters for guard#
--     are not written to by any other statements.
--
--   The above assumptions are true for code generated with the lowering
--   transform, but won't be true for general code, and we don't check for
--   violations of these assumptions.
--
module DDC.Core.Flow.Transform.Wind
        ( RefInfo(..)
        , windModule)
where
import DDC.Core.Module
import DDC.Core.Exp
import DDC.Core.Flow
import DDC.Core.Flow.Prim
import DDC.Core.Compounds
import DDC.Core.Flow.Compounds
        (tNat, dcNat, dcTupleN, dcBool, tTupleN)
import qualified Data.Map       as Map
import Data.Map                 (Map)


-------------------------------------------------------------------------------
-- | Current information for a reference.
data RefInfo
        = RefInfo
        { -- | Name of the reference itself.
          --   This is the key the record is stored under in a `RefMap`.
          refInfoName           :: Name

          -- | Type of the value held in the reference.
        , refInfoType           :: Type Name

          -- | Name that was bound to the initial value of the reference
          --   when the record was created.
        , refInfoCurrent        :: Name

          -- | Version number of the value currently held in the reference.
          --   Starts at zero and is bumped each time the reference is
          --   written to, or a loop is entered or left.
        , refInfoVersionNumber  :: Int }


-- | Map of reference names to the current information about each one.
--
--   This is a @newtype@ because it only wraps a single `Map`,
--   so there is no reason to pay for an extra constructor at runtime.
newtype RefMap
        = RefMap (Map Name RefInfo)


-- | A `RefMap` that does not know about any references.
refMapZero :: RefMap
refMapZero = RefMap Map.empty


-- | Get all the `RefInfo` records in the map,
--   in ascending order of reference name.
refMapElems :: RefMap -> [RefInfo]
refMapElems (RefMap mm)
        = Map.elems mm


-- | Insert a new `RefInfo` record into the map,
--   replacing any existing record for the same reference name.
insertRefInfo :: RefInfo -> RefMap -> RefMap
insertRefInfo info (RefMap mm)
        = RefMap (Map.insert (refInfoName info) info mm)


-- | Lookup a `RefInfo` record from the map,
--   yielding `Nothing` if there is no record for the given reference name.
lookupRefInfo :: RefMap -> Name -> Maybe RefInfo
lookupRefInfo (RefMap mm) n
        = Map.lookup n mm


-- | Get the name of the current version of a value from a `RefInfo`.
--   The name is built from the name of the reference, tagged with the
--   current version number. With the current implementation the result
--   is always a `Just`.
nameOfRefInfo :: RefInfo -> Maybe Name
nameOfRefInfo info
 = let  version = show (refInfoVersionNumber info)
   in   Just (NameVarMod (refInfoName info) version)


-- | Move a `RefInfo` on to the next version number.
bumpVersionOfRefInfo :: RefInfo -> RefInfo
bumpVersionOfRefInfo info
 = let  version' = refInfoVersionNumber info + 1
   in   info { refInfoVersionNumber = version' }


-- | Move the record for the named reference on to the next version number.
--   The map is returned unchanged if it has no record for that name.
bumpVersionInRefMap :: Name -> RefMap -> RefMap
bumpVersionInRefMap n (RefMap mm)
 = RefMap (Map.adjust bumpVersionOfRefInfo n mm)


-- | Move every record in a `RefMap` on to the next version number.
bumpAllVersionsInRefMap :: RefMap -> RefMap
bumpAllVersionsInRefMap refMap
 = foldr (bumpVersionInRefMap . refInfoName) refMap (refMapElems refMap)


-------------------------------------------------------------------------------
data Context
        -- | We're currently in the body of a loop.
        = ContextLoop
        { contextLoopName       :: Name
        , contextLoopCounter    :: Name
        , contextLoopAccs       :: [Name] }

        -- | We're currently in the body of a guard.
        | ContextGuard
        { -- | Name of the entry counter,
          --   the number of times this guard has matched.
          contextGuardCounter   :: Name

          -- | Whether we're in the matching or non-matching branch.
        , contextGuardFlag      :: Bool }
        deriving Show


-- | Check if some `Context` is a `ContextLoop`.
isContextLoop :: Context -> Bool
isContextLoop ContextLoop{}     = True
isContextLoop _                 = False


-- | Build a tailcall from the current context.
--   This tells us where to go after finishing the body of a loop.
--
--   The outermost context must be a `ContextLoop`,
--   any other stack of contexts is reported with `error`.
makeTailCallFromContexts :: a -> RefMap -> [Context] -> Exp a Name
makeTailCallFromContexts a refMap contexts
 = case contexts of
        ContextLoop nLoop _ _ : _
         -> xApps a (XVar a (UName nLoop))
                    (slurpArgUpdates a refMap [] contexts)

        _ -> error $ unlines
                [ "ddc-core-flow.makeTailCallFromContexts"
                , " Can't make a tailcall for this context."
                , " context = " ++ show contexts ]


-------------------------------------------------------------------------------
-- | Slurp expressions to update each of the accumulators of the loop.
--
--   The third argument carries the updates collected so far, each paired
--   with the name it belongs to. It starts out empty and is seeded when
--   the enclosing loop context is reached.
--
--   We assume that there have been no other updates to the loop
--   counter, and we generated the code ourselves.
slurpArgUpdates
        :: a
        -> RefMap
        -> [(Name, Exp a Name)]
        -> [Context]
        -> [Exp a Name]

-- Entering the loop context with nothing collected yet:
-- seed the updates with the incremented counter, followed by the
-- current version of every accumulator.
slurpArgUpdates a refMap [] (ContextLoop _ nCounter nAccs : more)
 = slurpArgUpdates a refMap (counterUpdate : accUpdates) more
 where
        counterUpdate
         = (nCounter, xIncrement a (XVar a (UName nCounter)))

        accUpdates
         = [ (nAcc, XVar a (UName nAcc'))
           | nAcc <- nAccs
           , let Just info  = lookupRefInfo refMap nAcc
           , let Just nAcc' = nameOfRefInfo info ]

-- Inside the matching branch of a guard the entry counter of that guard
-- also gets incremented, in the other branch the updates pass through.
slurpArgUpdates a refMap args (ContextGuard nCounter flag : more)
 = slurpArgUpdates a refMap args' more
 where
        args'
         | flag                 = map bump args
         | otherwise            = args

        bump (n, x)
         | n == nCounter        = (n, xIncrement a x)
         | otherwise            = (n, x)

-- A loop context after updates have already been collected
-- means we're looking at a loop inside a loop.
slurpArgUpdates _ _ _ (ContextLoop{} : _)
 = error $ unlines
        [ "ddc-core-flow.slurpArgUpdates"
        , " Nested loops are not supported." ]

-- All contexts consumed, so hand back just the update expressions.
slurpArgUpdates _ _ args []
 = map snd args


-- | Build an expression that adds one to a natural.
xIncrement :: a -> Exp a Name -> Exp a Name
xIncrement a xx
 = xApps a xAdd [ XType a tNat, xx, XCon a (dcNat 1) ]
 where
        xAdd    = XVar a (UPrim (NamePrimArith PrimArithAdd)
                                (typePrimArith PrimArithAdd))


-- | Build an expression that subtracts the second expression from the
--   first, with the subtraction primitive instantiated at the Nat type.
xSubInt :: a -> Exp a Name -> Exp a Name -> Exp a Name
xSubInt a x1 x2
 = xApps a xSub [ XType a tNat, x1, x2 ]
 where
        xSub    = XVar a (UPrim (NamePrimArith PrimArithSub)
                                (typePrimArith PrimArithSub))


-------------------------------------------------------------------------------
-- | Apply the wind transform to a single module.
--   The module body is rewritten with `windModuleBodyX`,
--   and every other field of the module is kept as it is.
windModule :: Module () Name -> Module () Name
windModule m
 = let  body'   = windModuleBodyX (moduleBody m)
   in   m { moduleBody = body' }


-- | Do winding in the body of a module.
--
--   This walks down the chain of let-expressions that makes up the body.
--   The right hand side of each plain or recursive let binding is wound
--   as a separate function body, starting from an empty `RefMap` and an
--   empty context. Other forms of let keep their bindings unchanged,
--   and the first expression that is not a let is returned as it is.
windModuleBodyX :: Exp () Name -> Exp () Name
windModuleBodyX xx
 = case xx of
        XLet a (LLet b x1) x2
         -> let x1'     = windBodyX refMapZero [] x1
                x2'     = windModuleBodyX x2
            in  XLet a (LLet b x1') x2'

        XLet a (LRec bxs) x2
         -> let bxs'    = [(b, windBodyX refMapZero [] x) | (b, x) <- bxs]
                x2'     = windModuleBodyX x2
            in  XLet a (LRec bxs') x2'

        XLet a lts x2
         -> let x2'     = windModuleBodyX x2
            in  XLet a lts x2'

        _ -> xx


-------------------------------------------------------------------------------
-- | Do winding in the body of a function.
--
--   The alternatives below are tried in order, and several of them match
--   the same form of let-expression, so the specific rewrites for
--   new#, read#, write#, loopn# and guard# must stay ahead of the
--   boilerplate alternatives that just descend into the expression.
--
--   Calls `error` when it reaches a case-expression,
--   as these are not supported yet.
windBodyX
        :: RefMap       -- ^ Info about how references are being rewritten.
        -> [Context]    -- ^ What loops and guards we're currently inside.
        -> Exp () Name  -- ^ Rewrite this expression.
        -> Exp () Name

windBodyX refMap context xx
 = let -- Rewrite a sub-expression using the same RefMap and context.
       down = windBodyX refMap context
   in case xx of

        -----------------------------------------
        -- Detect ref allocation,
        --  to bind the initial value to a new variable.
        --
        --    ref     : Ref# type = new# [type] val
        -- => ref__0  : type      = val
        --
        XLet a (LLet (BName nRef _) x) x2
         | Just ( NameOpStore OpStoreNew
                , [XType _ tElem, xVal] )       <- takeXPrimApps x
         -> let
                -- Add the new ref record to the map.
                -- The bindings of 'info' and 'nInit' refer to each other,
                -- which is fine because 'nameOfRefInfo' only looks at the
                -- name and version number fields.
                info    = RefInfo
                        { refInfoName           = nRef
                        , refInfoType           = tElem
                        , refInfoCurrent        = nInit
                        , refInfoVersionNumber  = 0 }

                -- Rewrite the statement that creates a new ref to one
                -- that just binds the initial value.
                Just nInit  = nameOfRefInfo info
                refMap'     = insertRefInfo info refMap

            in  XLet a  (LLet (BName nInit tElem) xVal)
                        (windBodyX refMap' context x2)

        -----------------------------------------
        -- Detect ref read,
        --  and rewrite to use the current version of the variable.
        --
        --    val     : type = read# [type] ref
        -- => val     : type = ref_N
        --
        -- A read of a ref that has no record in the map does not match
        -- here, and falls through to the boilerplate alternatives.
        XLet a (LLet bResult x) x2
         | Just ( NameOpStore OpStoreRead
                , [XType _ _tElem, XVar _ (UName nRef)] )
                                                <- takeXPrimApps x
         , Just info    <- lookupRefInfo refMap nRef
         , Just nVal    <- nameOfRefInfo info
         -> XLet a  (LLet bResult (XVar a (UName nVal)))
                    (windBodyX refMap context x2)

        -----------------------------------------
        -- Detect ref write,
        --  to just bind the new value.
        --
        -- The version of the ref is bumped first, so the new value is
        -- bound to the next versioned name, and the rest of the body
        -- is rewritten with the bumped map.
        XLet a (LLet (BNone _) x) x2
         | Just ( NameOpStore OpStoreWrite
                , [XType _ _tElem, XVar _ (UName nRef), xVal])
                                                <- takeXPrimApps x
         , refMap'      <- bumpVersionInRefMap nRef refMap
         , Just info    <- lookupRefInfo refMap' nRef
         , Just nVal    <- nameOfRefInfo info
         , tVal         <- refInfoType info
         -> XLet a  (LLet (BName nVal tVal) xVal)
                    (windBodyX refMap' context x2)

        -----------------------------------------
        -- Detect loop combinator.
        XLet a (LLet (BNone _) x) x2
         | Just ( NameOpControl OpControlLoopN
                , [ XType _ tK, xLength
                  , XLam  _ bIx@(BName nIx _) xBody]) <- takeXPrimApps x
         -> let
                -- Name of the new loop function.
                -- The name is fixed, which fits with the assumption
                -- that loops are not nested.
                nLoop           = NameVar "loop"
                bLoop           = BName nLoop tLoop
                uLoop           = UName nLoop

                -- Variable that holds the length of the loop as a Nat.
                nLength         = NameVarMod nLoop "length"
                bLength         = BName nLength tNat
                uLength         = UName nLength

                -- RefMap for before the loop, in the body, and after the loop.
                -- Every ref gets a fresh version on the way into the body,
                -- and another one on the way out.
                refMap_init     = refMap
                refMap_body     = bumpAllVersionsInRefMap refMap
                refMap_final    = bumpAllVersionsInRefMap refMap_body

                -- Get binds and bounds for accumulators,
                -- to use in the body of the loop.
                bsAccs  = [ BName nVar (refInfoType info)
                          | info <- refMapElems refMap_body
                          , let Just nVar = nameOfRefInfo info ]

                usAccs  = takeSubstBoundsOfBinds bsAccs
                tsAccs  = map typeOfBind bsAccs

                -- The loop function itself will return us a tuple
                -- containing the final value of all the accumulators.
                tIndex  = typeOfBind bIx
                tResult = loopResultT tsAccs

                -- Type of the loop function.
                -- It takes the index, then one argument per accumulator.
                tLoop   = foldr tFun tResult (tIndex : tsAccs)

                -- Descend into loop body,
                -- and remember that we're doing the rewrite inside a loop context.
                context' = context
                        ++ [ ContextLoop
                              { contextLoopName    = nLoop
                              , contextLoopCounter = nIx
                              , contextLoopAccs    = map refInfoName $ refMapElems refMap_body } ]
                xBody'  = windBodyX refMap_body context' xBody

                -- Create the loop driver.
                -- This is the code that tests for the end-of-loop condition.
                -- When (length - index) is zero we return the accumulators,
                -- otherwise we run the body, which ends with the tailcall.
                xDriver = xLams a (bIx : bsAccs)
                        $ XCase a (xSubInt a (XVar a uLength) (XVar a (UName nIx)))
                                [ AAlt (PData (dcNat 0) []) xResult
                                , AAlt PDefault             xBody' ]

                xResult = loopResultX a tsAccs [XVar a u | u <- usAccs]

                -- Initial values of index and accumulators.
                -- The index starts from zero, and the accumulators from
                -- the versions that were current before the loop.
                xsInit  = XCon a (dcNat 0)
                        : [ XVar a (UName nVar)
                          | info <- refMapElems refMap_init
                          , let Just nVar = nameOfRefInfo info ]

                -- Descend into loop postlude.
                -- The results of the loop are bound to the final versions
                -- of the accumulators, and the code after the loop is
                -- rewritten to refer to those.
                bsFinal = [ BName nVar (refInfoType info)
                          | info <- refMapElems refMap_final
                          , let Just nVar = nameOfRefInfo info ]

                x2'     = windBodyX refMap_final context x2

            in  XLet a (LLet bLength (xNatOfRateNat tK xLength))
             $  XLet a (LRec [(bLoop, xDriver)])
             $  runUnpackLoop
                        a
                        tsAccs                          -- Types of accumulators.
                        (xApps a (XVar a uLoop) xsInit) -- Expression to invoke loop
                        bsFinal                         -- Binders for final accumulators
                        x2'                             -- Continuation expression

        -----------------------------------------
        -- Detect guard combinator.
        --
        -- The guard becomes a case-expression on the flag.
        -- The guarded body goes into the True alternative, with the
        -- current version of the entry counter bound to the parameter
        -- of the guard function. The code after the guard goes into
        -- the default alternative only.
        --
        -- NOTE(review): the True alternative does not continue on to x2,
        -- and x2 is rewritten with the RefMap from before the guard.
        -- Presumably the guarded body is expected to end the loop
        -- iteration by itself -- confirm against the lowering transform.
        XLet a (LLet (BNone _) x) x2
         | Just ( NameOpControl OpControlGuard
                , [ XVar _ (UName nCountRef)
                  , xFlag
                  , XLam _ bCount xBody ])      <- takeXPrimApps x
         -> let
                -- These patterns fail at runtime if the entry counter
                -- has no record in the RefMap.
                Just infoCount  = lookupRefInfo refMap nCountRef
                Just nCount     = nameOfRefInfo infoCount

                -- Remember that we're in the matching branch of the guard,
                -- so the tailcall also increments the entry counter.
                context' = context
                        ++ [ ContextGuard
                              { contextGuardCounter    = nCountRef
                              , contextGuardFlag       = True } ]

                xBody'   = XLet a (LLet bCount (XVar a (UName nCount)))
                         $ windBodyX refMap context' xBody

            in  XCase a xFlag
                        [ AAlt (PData (dcBool True) []) xBody'
                        , AAlt PDefault                 (down x2) ]

        -----------------------------------------
        -- Detect end value.
        --  If we're inside a loop and hit a Unit at the top-level of the body
        --  then we know it's time to do the recursive call.
        XCon a dc
         | any isContextLoop context
         , dc == dcUnit
         -> makeTailCallFromContexts a refMap context

        -----------------------------------------
        -- Enter into both branches of a split.
        -- The split itself is rebuilt around the rewritten branches.
        XApp{}
         | Just ( NameOpControl (OpControlSplit n)
                , [ XType _ tK, xN, xBranch1, xBranch2 ]) <- takeXPrimApps xx
         -> let xBranch1' = down xBranch1
                xBranch2' = down xBranch2
            in  xSplit n tK xN xBranch1' xBranch2'

        -- Boilerplate --------------------------
        -- Variables, constructors and any other applications
        -- are returned unchanged.
        XVar{}          -> xx
        XCon{}          -> xx
        XLAM a b x      -> XLAM a b (down x)
        XLam a b x      -> XLam a b (down x)
        XApp{}          -> xx

        -- Descend into nest let binding.
        -- We need to drop the contexts because we never do a tail-call
        -- from a nested binding.
        XLet a (LLet b x) x2
         -> XLet a (LLet b (windBodyX refMap [] x))
                   (down x2)

        XLet a (LRec bxs) x2
         -> XLet a (LRec [(b, windBodyX refMap [] x) | (b, x) <- bxs])
                   (down x2)

        -- Other forms of let keep their bindings unchanged.
        XLet a lts x2
         -> XLet a lts (down x2)

        -- The only case-expressions in the result are the ones
        -- this transform builds itself.
        XCase{}
         -> error $ unlines
                  [ "ddc-core-flow.windBodyX"
                  , " case-expressions not supported yet" ]

        XCast a c x
         -> let x' = windBodyX refMap context x
            in  XCast a c x'

        XType{}         -> xx
        XWitness{}      -> xx


-------------------------------------------------------------------------------
-- | Types whose names are Flow fragment names.
type TypeF      = Type Name

-- | Expressions with unit annotations, whose names are Flow fragment names.
type ExpF       = Exp () Name


-- | Build an application of the NatOfRateNat operator
--   to the given rate type and expression.
xNatOfRateNat :: Type Name -> Exp () Name -> Exp () Name
xNatOfRateNat tK xR
        = xApps () (xVarOpConcrete OpConcreteNatOfRateNat)
                   [XType () tK, xR]


-- | Build a variable that refers to a primitive concrete operator,
--   annotated with the type of that operator.
xVarOpConcrete :: OpConcrete -> Exp () Name
xVarOpConcrete op
        = XVar () (UPrim (NameOpConcrete op) (typeOpConcrete op))


-- | Build an application of the split operator of the given arity.
--   The arguments are a rate type, an expression, and the two branches.
xSplit  :: Int -> TypeF -> ExpF -> ExpF -> ExpF -> ExpF
xSplit n tK xRN xDownFn xTailFn
        = xApps () (xVarOpControl $ OpControlSplit n)
                   [ XType () tK, xRN, xDownFn, xTailFn ]


-- | Build a variable that refers to a primitive control operator,
--   annotated with the type of that operator.
xVarOpControl :: OpControl -> Exp () Name
xVarOpControl op
        = XVar () (UPrim (NameOpControl op) (typeOpControl op))


-------------------------------------------------------------------------------
-- | Make the type of a loop result,
--   given the types of the accumulators for that loop.
--
--   If we have no accumulators, return Unit.
--   If we have just one, return that value.
--   If more, then package them into a tuple.
--
loopResultT :: [Type Name] -> Type Name
loopResultT []          = tUnit
loopResultT [tAcc]      = tAcc
loopResultT tsAccs      = tTupleN tsAccs


-- | Make a loop result,
--   given the types of the accumulators and the expressions for them.
--
--   The shape of the result follows the number of expressions:
--   a unit value for none, the expression itself for exactly one,
--   and otherwise a tuple constructor applied to the accumulator
--   types followed by the expressions.
loopResultX :: a -> [Type Name] -> [Exp a Name] -> Exp a Name
loopResultX a _      []         = xUnit a
loopResultX _ _      [x]        = x
loopResultX a tsAccs xsAccs
 = xApps a xTuple (map (XType a) tsAccs ++ xsAccs)
 where
        -- Arity of the tuple comes from the number of accumulator types.
        xTuple  = XCon a (dcTupleN (length tsAccs))


-- | Call a loop, and unpack its result.
--
--   How the result is unpacked mirrors how the loop packaged it:
--   with no accumulators the result is discarded, with a single
--   accumulator (and a single binder) it is bound directly, and in
--   all other cases the tuple is taken apart with a case-expression.
runUnpackLoop
        :: a
        -> [Type Name]  -- ^ Types of accumulators.
        -> Exp a Name   -- ^ Expression to invoke the loop.
        -> [Bind Name]  -- ^ Binders for the accumulated values.
        -> Exp a Name   -- ^ Continuation expression.
        -> Exp a Name

runUnpackLoop a tsAccs xRunLoop bsAcc xCont
 = case (tsAccs, bsAcc) of
        -- Nothing to bind, run the loop just for its effect.
        ([], _)
         -> XLet a (LLet (BNone tUnit) xRunLoop) xCont

        -- The loop returns the single accumulator as it is.
        ([_t], [b])
         -> XLet a (LLet b xRunLoop) xCont

        -- The loop returns a tuple of all the accumulators.
        _ -> XCase a xRunLoop
                [ AAlt (PData (dcTupleN $ length tsAccs) bsAcc) xCont ]