module Data.Array.Repa.Plugin.ToGHC.Wrap ( wrapLowered , unwrapResult) where import Data.Array.Repa.Plugin.ToGHC.Var import Data.Array.Repa.Plugin.GHC.Pretty () import qualified BasicTypes as G import qualified CoreSyn as G import qualified DataCon as G import qualified Type as G import qualified TypeRep as G import qualified TysPrim as G import qualified TysWiredIn as G import qualified MkId as G import qualified UniqSupply as G import Control.Monad -- | Make a wrapper to call a lowered version of a function from the original -- binding. We need to unsafely pass it the world token, as well as marshall -- between boxed and unboxed types. wrapLowered :: G.Type -- ^ Type of original version. -> G.Type -- ^ Type of lowered version. -> [Either G.Var G.CoreExpr] -- ^ Lambda bound variables in wrapper. -> G.Var -- ^ Name of lowered version. -> G.UniqSM G.CoreExpr wrapLowered tOrig tLowered vsParam vLowered -- Decend into foralls. -- Bind the type argument with a new var so we can pass it to -- the lowered function. | G.ForAllTy vOrig tOrig' <- tOrig , G.ForAllTy _ tLowered' <- tLowered = do let vsParam' = Left vOrig : vsParam xBody <- wrapLowered tOrig' tLowered' vsParam' vLowered return $ G.Lam vOrig xBody -- If the type of the lowered function says it needs -- the realworld token, then just give it one. -- This effectively unsafePerformIOs it. | G.FunTy tLowered1 tLowered2 <- tLowered , G.TyConApp tcState _ <- tLowered1 , tcState == G.statePrimTyCon = do let vsParam' = Right (G.Var G.realWorldPrimId) : vsParam wrapLowered tOrig tLowered2 vsParam' vLowered -- Descend into functions. -- Bind the argument with a new var so we can pass it to the lowered -- function. | G.FunTy tOrig1 tOrig2 <- tOrig , G.FunTy tLowered1 tLowered2 <- tLowered = do v' <- newDummyVar "arg" tOrig1 -- Convert from type 'tOrig1' to 'tLowered1' arg' <- unwrapResult tLowered1 tOrig1 (G.Var v') let vsParam' = Right arg' : vsParam xBody <- wrapLowered tOrig2 tLowered2 vsParam' vLowered return $ G.Lam v' xBody -- We've decended though all the foralls and lambdas and now need -- to call the actual lowered function, and marshall its result. | otherwise = do -- Arguments to pass to the lowered function. let xsArg = map (either (G.Type . G.TyVarTy) id) vsParam -- Actual call to the lowered function. let xLowered = foldl G.App (G.Var vLowered) $ reverse xsArg callLowered tOrig tLowered xLowered -- | Make the call site for the lowered function. callLowered :: G.Type -- ^ Type of result for original unlowered version. -> G.Type -- ^ Type of result for lowered version. -> G.CoreExpr -- ^ Exp that calls the lowered version. -> G.UniqSM G.CoreExpr callLowered tOrig tLowered xLowered -- Assume this function returns a (# World#, ts.. #) -- TODO: check this. | G.TyConApp _ (_tWorld : tsVal) <- tLowered = do vScrut <- newDummyVar "scrut" tLowered vWorld <- newDummyVar "world" G.realWorldStatePrimTy vsVal <- zipWithM (\i t -> newDummyVar ("val" ++ show i) t) [0 :: Int ..] tsVal -- Unwrap the actual result value. let tOrigVal = tOrig let tsLoweredVal = tsVal xResult <- unwrapResultBits tOrigVal tsLoweredVal (map G.Var vsVal) return $ G.Case xLowered vScrut tOrig [ (G.DataAlt (G.tupleCon G.UnboxedTuple (1 + length tsVal)) , (vWorld : vsVal) , xResult) ] | otherwise = error "repa-plugin.Wrap.callLowered: no match" unwrapResultBits :: G.Type -- ^ Type of result for original version. -> [G.Type] -- ^ Types of arguments lowered arguments -> [G.CoreExpr] -- ^ Types of components -> G.UniqSM G.CoreExpr unwrapResultBits tOrig tsBits xsBits | [tBit] <- tsBits , [xBit] <- xsBits = unwrapResult tOrig tBit xBit | G.TyConApp tcTup tsOrig <- tOrig , n <- length tsOrig , G.tupleTyCon G.BoxedTuple n == tcTup = do xsResult <- mapM (\(tOrig', tLowered, xBit) -> unwrapResult tOrig' tLowered xBit) $ zip3 tsOrig tsBits xsBits return $ G.mkConApp (G.tupleCon G.BoxedTuple n) (map G.Type tsOrig ++ xsResult) | otherwise = error "unwrapResultBits: failed" unwrapResult :: G.Type -- ^ Type of result for original unlowered version. -> G.Type -- ^ Type of result for lowered version. -> G.CoreExpr -- ^ Expression for result value. -> G.UniqSM G.CoreExpr unwrapResult tOrig tLowered xResult | G.TyConApp tcInt [] <- tOrig , tcInt == G.intTyCon , G.TyConApp tcIntU [] <- tLowered , tcIntU == G.intPrimTyCon -- TODO: do a proper check. -- Is this supposed to be a TyLit? = return $ G.App (G.Var (G.dataConWorkId G.intDataCon)) xResult | G.TyConApp tcIntU [] <- tOrig , tcIntU == G.intPrimTyCon , G.TyConApp tcInt [] <- tLowered , tcInt == G.intTyCon = do -- Case on the int constructor vScrut <- newDummyVar "scrut" tLowered v <- newDummyVar "v" tOrig return $ G.Case xResult vScrut tOrig [ (G.DataAlt G.intDataCon , [v] , G.Var v)] -- Original is a boxed tuple and lowered version is unboxed: -- raise to a boxed tuple, boxing its elements too. | G.TyConApp tcTup tins <- tOrig , G.TyConApp tcUnb touts <- tLowered , n <- length tins , G.tupleTyCon G.BoxedTuple n == tcTup , G.tupleTyCon G.UnboxedTuple n == tcUnb = do -- Case on the unboxed tuple, raise the elements, then create a boxed tuple vScrut <- newDummyVar "scrut" tLowered vs <- mapM (newDummyVar "v") touts let unwrap (t,t',v) = unwrapResult t t' (G.Var v) xs <- mapM unwrap (zip3 tins touts vs) return (G.Case xResult vScrut tOrig [ (G.DataAlt (G.tupleCon G.UnboxedTuple n) , vs, G.mkConApp (G.tupleCon G.BoxedTuple n) (map G.Type tins ++ xs))]) -- Convert boxed tuple to unboxed, maybe unbox its elements too | G.TyConApp tcUnb tins <- tOrig , G.TyConApp tcTup touts <- tLowered , n <- length tins , G.tupleTyCon G.UnboxedTuple n == tcUnb , G.tupleTyCon G.BoxedTuple n == tcTup = do -- Case on the unboxed tuple, raise the elements, then create a boxed tuple vScrut <- newDummyVar "scrut" tLowered vs <- mapM (newDummyVar "v") touts let unwrap (t,t',v) = unwrapResult t t' (G.Var v) xs <- mapM unwrap (zip3 tins touts vs) return (G.Case xResult vScrut tOrig [ (G.DataAlt (G.tupleCon G.BoxedTuple n) , vs, G.mkConApp (G.tupleCon G.UnboxedTuple n) (map G.Type tins ++ xs))]) | otherwise = return xResult