module Data.Array.Repa.Plugin.ToGHC.Wrap
( wrapLowered
, unwrapResult)
where
import Data.Array.Repa.Plugin.ToGHC.Var
import Data.Array.Repa.Plugin.GHC.Pretty ()
import qualified BasicTypes as G
import qualified CoreSyn as G
import qualified DataCon as G
import qualified Type as G
import qualified TypeRep as G
import qualified TysPrim as G
import qualified TysWiredIn as G
import qualified MkId as G
import qualified UniqSupply as G
import Control.Monad
-- | Build a wrapper with the original binding's type that calls the
--   lowered binding.
--
--   The original and lowered types are walked in lock-step:
--
--   * a @forall@ on both sides becomes a type lambda, and its type
--     variable is remembered as a type argument for the lowered binding;
--   * a @State#@ argument of the lowered type is supplied with
--     @realWorld#@;
--   * an ordinary function argument becomes a value lambda, whose
--     variable is converted to the lowered argument type with
--     'unwrapResult';
--   * once neither case applies, the lowered binding is applied to all
--     collected arguments and its result is handed to 'callLowered'.
--
--   Collected arguments are kept most-recent-first: 'Left' is a type
--   variable to pass as a type argument, 'Right' a value argument.
wrapLowered
        :: G.Type                       -- ^ Type of the original binding.
        -> G.Type                       -- ^ Type of the lowered binding.
        -> [Either G.Var G.CoreExpr]    -- ^ Arguments collected so far, most recent first.
        -> G.Var                        -- ^ The lowered binding to call.
        -> G.UniqSM G.CoreExpr
wrapLowered tOrig tLowered vsParam vLowered
        -- Type abstraction on both sides: bind the original type variable
        -- and pass it on to the lowered binding as a type argument.
        | G.ForAllTy tyVar tOrigBody <- tOrig
        , G.ForAllTy _     tLowBody  <- tLowered
        = liftM (G.Lam tyVar)
        $ wrapLowered tOrigBody tLowBody (Left tyVar : vsParam) vLowered

        -- The lowered binding expects a state token: supply realWorld#.
        | G.FunTy tArgLow tResLow <- tLowered
        , G.TyConApp tcArg _      <- tArgLow
        , tcArg == G.statePrimTyCon
        = wrapLowered tOrig tResLow
                (Right (G.Var G.realWorldPrimId) : vsParam) vLowered

        -- Ordinary value argument: bind it at its original type and
        -- convert it to the lowered argument type for the call.
        | G.FunTy tArgOrig tResOrig <- tOrig
        , G.FunTy tArgLow  tResLow  <- tLowered
        = do    vArg    <- newDummyVar "arg" tArgOrig
                xArg    <- unwrapResult tArgLow tArgOrig (G.Var vArg)
                xBody   <- wrapLowered tResOrig tResLow
                                (Right xArg : vsParam) vLowered
                return  $ G.Lam vArg xBody

        -- No more binders: apply the lowered binding to the collected
        -- arguments (restoring their original order) and convert its
        -- result back to the original result type.
        | otherwise
        = callLowered tOrig tLowered
        $ G.mkApps (G.Var vLowered)
        $ map (either (G.Type . G.TyVarTy) id)
        $ reverse vsParam
-- | Scrutinise a saturated call of the lowered binding, then rebuild a
--   value of the original result type from the components it returns.
--
--   The lowered result type must be an unboxed tuple.
--
--   * Its first component is bound as the state token and then ignored.
--   * Its remaining components are the lowered result values.
--   * Those values are converted back to the original result type by
--     'unwrapResultBits'.
callLowered
        :: G.Type       -- ^ Original result type.
        -> G.Type       -- ^ Lowered result type (an unboxed tuple).
        -> G.CoreExpr   -- ^ Saturated call of the lowered binding.
        -> G.UniqSM G.CoreExpr
callLowered tOrig tLowered xLowered
        -- The alternative below takes the result apart with the
        -- unboxed-tuple data constructor, so only accept a genuine
        -- unboxed tuple of the matching arity; anything else would
        -- silently produce ill-typed Core.
        | G.TyConApp tcTuple tsLowered <- tLowered
        , (tWorld : tsVal)             <- tsLowered
        , arity                        <- length tsLowered
        , tcTuple == G.tupleTyCon G.UnboxedTuple arity
        = do    vScrut  <- newDummyVar "scrut" tLowered

                -- Bind the state component at the type the tuple
                -- actually declares, so the binder matches the field.
                vWorld  <- newDummyVar "world" tWorld
                vsVal   <- zipWithM (\i t -> newDummyVar ("val" ++ show i) t)
                                    [0 :: Int ..] tsVal

                xResult <- unwrapResultBits tOrig tsVal (map G.Var vsVal)

                return  $ G.Case xLowered vScrut tOrig
                        [ ( G.DataAlt (G.tupleCon G.UnboxedTuple arity)
                          , vWorld : vsVal
                          , xResult ) ]

        | otherwise
        = error "repa-plugin.Wrap.callLowered: no match"
-- | Rebuild a value of the original result type from the individual
--   lowered result values bound out of the lowered result tuple.
--
--   * A single lowered value is converted directly with 'unwrapResult'.
--   * Otherwise the original type must be a boxed tuple with exactly one
--     lowered value per component. Each component is converted with
--     'unwrapResult', and the results are packed into that boxed tuple.
--
--   Any other combination is an internal error.
unwrapResultBits
        :: G.Type               -- ^ Original result type.
        -> [G.Type]             -- ^ Types of the lowered result values.
        -> [G.CoreExpr]         -- ^ The lowered result values themselves.
        -> G.UniqSM G.CoreExpr
unwrapResultBits tOrig tsBits xsBits
        -- A single lowered value converts directly.
        | [tBit] <- tsBits
        , [xBit] <- xsBits
        = unwrapResult tOrig tBit xBit

        -- Several lowered values become the components of a boxed tuple.
        -- Require exactly one lowered value per component: zip3 would
        -- otherwise silently drop the surplus and we would build an
        -- under-saturated, ill-typed constructor application.
        | G.TyConApp tcTup tsOrig <- tOrig
        , n <- length tsOrig
        , G.tupleTyCon G.BoxedTuple n == tcTup
        , length tsBits == n
        , length xsBits == n
        = do    xsResult <- mapM (\(tOrig', tLowered, xBit)
                                        -> unwrapResult tOrig' tLowered xBit)
                          $ zip3 tsOrig tsBits xsBits
                return  $ G.mkConApp (G.tupleCon G.BoxedTuple n)
                                     (map G.Type tsOrig ++ xsResult)

        | otherwise
        = error "repa-plugin.Wrap.unwrapResultBits: failed"
-- | Convert an expression of one representation type into another.
--
--   Supported conversions:
--
--   * @Int#@ to @Int@, by applying the @I#@ constructor;
--   * @Int@ to @Int#@, by taking the @I#@ box apart;
--   * unboxed tuple to boxed tuple of the same arity, and the reverse,
--     converting each component recursively.
--
--   For any other pair of types the expression is returned unchanged.
unwrapResult
        :: G.Type               -- ^ Type to convert to.
        -> G.Type               -- ^ Type of the expression being converted.
        -> G.CoreExpr           -- ^ Expression to convert.
        -> G.UniqSM G.CoreExpr
unwrapResult tOrig tLowered xResult
        -- Int# to Int: apply the I# constructor.
        | isNullaryApp G.intTyCon     tOrig
        , isNullaryApp G.intPrimTyCon tLowered
        = return $ G.App (G.Var (G.dataConWorkId G.intDataCon)) xResult

        -- Int to Int#: take the I# box apart.
        | isNullaryApp G.intPrimTyCon tOrig
        , isNullaryApp G.intTyCon     tLowered
        = do    vScrut  <- newDummyVar "scrut" tLowered
                vInt    <- newDummyVar "v"     tOrig
                return  $ G.Case xResult vScrut tOrig
                        [ (G.DataAlt G.intDataCon, [vInt], G.Var vInt) ]

        -- Unboxed tuple to boxed tuple.
        | Just (tsTo, tsFrom) <- matchTuples G.BoxedTuple G.UnboxedTuple
        = convertTuple G.UnboxedTuple G.BoxedTuple tsTo tsFrom

        -- Boxed tuple to unboxed tuple.
        | Just (tsTo, tsFrom) <- matchTuples G.UnboxedTuple G.BoxedTuple
        = convertTuple G.BoxedTuple G.UnboxedTuple tsTo tsFrom

        -- Anything else passes through untouched.
        | otherwise
        = return xResult

 where
        -- Is the type this type constructor applied to no arguments?
        isNullaryApp tc (G.TyConApp tc' []) = tc' == tc
        isNullaryApp _  _                   = False

        -- When the target type is a tuple of sort 'sortTo' and the source
        -- type a tuple of sort 'sortFrom', both of the target's arity,
        -- yield the component types of each.
        matchTuples sortTo sortFrom
         | G.TyConApp tcTo   tsTo   <- tOrig
         , G.TyConApp tcFrom tsFrom <- tLowered
         , arity                    <- length tsTo
         , G.tupleTyCon sortTo   arity == tcTo
         , G.tupleTyCon sortFrom arity == tcFrom
         = Just (tsTo, tsFrom)

         | otherwise
         = Nothing

        -- Take a 'sortFrom' tuple apart, convert every component, and
        -- rebuild the result as a 'sortTo' tuple.
        convertTuple sortFrom sortTo tsTo tsFrom
         = do   let arity = length tsTo
                vScrut  <- newDummyVar "scrut" tLowered
                vsComp  <- mapM (newDummyVar "v") tsFrom
                xsComp  <- mapM (\(tTo, tFrom, v) -> unwrapResult tTo tFrom (G.Var v))
                                (zip3 tsTo tsFrom vsComp)
                return  $ G.Case xResult vScrut tOrig
                        [ ( G.DataAlt (G.tupleCon sortFrom arity)
                          , vsComp
                          , G.mkConApp (G.tupleCon sortTo arity)
                                       (map G.Type tsTo ++ xsComp) ) ]