module Data.Array.Repa.Plugin.Primitives
( Primitives (..)
, slurpPrimitives)
where
import Data.Array.Repa.Plugin.ToGHC.Var
import Data.List
import Data.Maybe
import Control.Monad
import qualified HscTypes as G
import qualified CoreSyn as G
import qualified MkCore as G
import qualified DataCon as G
import qualified TyCon as G
import qualified Type as G
import qualified Var as G
import qualified OccName as Occ
import qualified Name as Name
import UniqSupply as G
import qualified UniqSet as US
-- | A Core expression paired with a type.
--   In this module every such pair is filled in by 'makeTable' with a
--   reference to a generated selector binding and the type of the
--   corresponding field of the primitives table.
type PrimSelector
        = (G.CoreExpr, G.Type)


-- | Table of the primitive types and operators extracted from the
--   @repa_primitives@ binding of the module being compiled.
data Primitives
        = Primitives
        { -- Types, taken from the field types of the table's data constructor.
          prim_Series           :: !G.Type
        , prim_Vector           :: !G.Type
        , prim_Ref              :: !G.Type

          -- Fields named after Int arithmetic operators.
        , prim_addInt           :: PrimSelector
        , prim_subInt           :: PrimSelector
        , prim_mulInt           :: PrimSelector
        , prim_divInt           :: PrimSelector
        , prim_modInt           :: PrimSelector
        , prim_remInt           :: PrimSelector

          -- Fields named after Int comparison operators.
        , prim_eqInt            :: PrimSelector
        , prim_neqInt           :: PrimSelector
        , prim_gtInt            :: PrimSelector
        , prim_geInt            :: PrimSelector
        , prim_ltInt            :: PrimSelector
        , prim_leInt            :: PrimSelector

          -- Fields named after Ref operators.
        , prim_newRefInt        :: PrimSelector
        , prim_readRefInt       :: PrimSelector
        , prim_writeRefInt      :: PrimSelector

        , prim_newRefInt_T2     :: PrimSelector
        , prim_readRefInt_T2    :: PrimSelector
        , prim_writeRefInt_T2   :: PrimSelector

          -- Fields named after Vector operators.
        , prim_newVectorInt     :: PrimSelector
        , prim_readVectorInt    :: PrimSelector
        , prim_writeVectorInt   :: PrimSelector
        , prim_sliceVectorInt   :: PrimSelector

          -- Fields named after loop / control operators.
        , prim_loop             :: PrimSelector
        , prim_guard            :: PrimSelector

          -- Fields named after Series operators.
        , prim_rateOfSeries     :: PrimSelector
        , prim_nextInt          :: PrimSelector
        , prim_nextInt_T2       :: PrimSelector
        }
-- | Names of the primitive types.
--   NOTE(review): nothing in the visible part of this module refers to this
--   list; the leading underscore keeps the unused-binding warning quiet.
_primitive_types :: [String]
_primitive_types
        = ["Series", "Vector", "Ref"]
-- | Names of the operator fields of the primitives table.
--   'makeTable' passes this list to 'makeSelectors', which creates one
--   selector binding per name, in this order.
--
--   Every name carries the @prim_@ prefix, so only the suffixes are
--   spelled out below.
primitive_ops :: [String]
primitive_ops
 = map ("prim_" ++)
 $ concat
        [ -- Int arithmetic.
          [ "addInt", "subInt", "mulInt", "divInt", "modInt", "remInt" ]

          -- Int comparisons.
        , [ "eqInt", "neqInt", "gtInt", "geInt", "ltInt", "leInt" ]

          -- Ref operators.
        , [ "newRefInt",    "readRefInt",    "writeRefInt" ]
        , [ "newRefInt_T2", "readRefInt_T2", "writeRefInt_T2" ]

          -- Vector operators.
        , [ "newVectorInt", "readVectorInt", "writeVectorInt", "sliceVectorInt" ]

          -- Loop / control operators.
        , [ "loop", "guard" ]

          -- Series operators.
        , [ "rateOfSeries", "nextInt", "nextInt_T2" ]
        ]
-- | Try to slurp the primitives table from a GHC module.
--
--   Looks through the top-level bindings of the module for a non-recursive
--   binding named @repa_primitives@. If one is found and a 'Primitives'
--   table can be built from it, the result carries that table together
--   with updated 'G.ModGuts' in which:
--
--   * the generated selector bindings are inserted directly after the
--     table binding, and
--
--   * the names of those selector binders are added to @mg_used_names@.
--
--   Returns 'Nothing' when there is no such binding, or when 'makeTable'
--   cannot build a table from it.
slurpPrimitives
        :: G.ModGuts
        -> UniqSM (Maybe (Primitives, G.ModGuts))

slurpPrimitives guts
 = case listToMaybe $ mapMaybe findTableFromTopBind $ G.mg_binds guts of
    Nothing
     -> return Nothing

    Just vTable
     -> do  -- Inspect the result explicitly rather than binding with a
            -- refutable pattern: a failed match in a do-block would call
            -- 'fail' in UniqSM instead of yielding Nothing.
            mTable <- makeTable vTable
            case mTable of
             Nothing
              -> return Nothing

             Just (prims, bsMoar)
              -> let usedNames
                        = US.addListToUniqSet (G.mg_used_names guts)
                        $ [G.varName b | G.NonRec b _ <- bsMoar]

                     hackedGuts
                        = guts
                        { G.mg_binds      = insertAfterTable bsMoar
                                          $ G.mg_binds guts
                        , G.mg_used_names = usedNames }

                 in  return $ Just (prims, hackedGuts)
-- | Check whether a top-level binding is the primitives table.
--   Only non-recursive bindings are considered; a recursive group never
--   matches.
findTableFromTopBind :: G.CoreBind -> Maybe G.Var
findTableFromTopBind (G.NonRec b _)     = findTableFromBinding b
findTableFromTopBind G.Rec{}            = Nothing
-- | Return the binder itself if the occurrence name of its 'Name.Name' is
--   exactly @repa_primitives@, otherwise 'Nothing'.
findTableFromBinding :: G.CoreBndr -> Maybe G.Var
findTableFromBinding b
 = if binderString == "repa_primitives"
        then Just b
        else Nothing
 where  binderString
         = (Occ.occNameString . Name.nameOccName . G.varName) b
-- | Insert some bindings directly after the first non-recursive binding
--   that 'findTableFromBinding' recognises as the primitives table.
--
--   If no binding in the list is the table then the list is returned
--   as it was and the extra bindings are not inserted.
insertAfterTable :: [G.CoreBind] -> [G.CoreBind] -> [G.CoreBind]
insertAfterTable bsMore bs
 = case break isTable bs of
        -- Table never found: nothing to insert after.
        (_,      [])                    -> bs

        -- Splice the new bindings in right behind the table.
        (before, bTable : after)        -> before ++ bTable : bsMore ++ after

 where  -- Recursive groups are never the table.
        isTable (G.NonRec b _)  = isJust $ findTableFromBinding b
        isTable G.Rec{}         = False
-- | Build the 'Primitives' table from the variable bound to the
--   primitives record.
--
--   The type of the variable must be an application of an algebraic type
--   constructor with exactly one data constructor (matched as
--   @DataTyCon [dc] False@), and that constructor must have fields named
--   @prim_Series@, @prim_Vector@ and @prim_Ref@. If any of these
--   conditions does not hold the result is 'Nothing'.
--
--   Also returns the selector bindings created by 'makeSelectors' for each
--   name in 'primitive_ops'; the operator fields of the table refer to
--   those bindings.
makeTable
        :: G.Var
        -> UniqSM (Maybe (Primitives, [G.CoreBind]))

makeTable v
 | t                            <- G.varType v
 , Just tc                      <- G.tyConAppTyCon_maybe t
 , G.isAlgTyCon tc
 , G.DataTyCon [dc] False       <- G.algTyConRhs tc
 , labels                       <- G.dataConFieldLabels dc
   -- Look the type fields up in the guards, so that a table which lacks
   -- one of them falls through to the Nothing case instead of failing an
   -- irrefutable pattern when the table is eventually forced.
 , Just tySeries                <- fieldType dc labels "prim_Series"
 , Just tyVector                <- fieldType dc labels "prim_Vector"
 , Just tyRef                   <- fieldType dc labels "prim_Ref"
 = do
        (bs, selectors) <- makeSelectors v primitive_ops

        -- Every name asked for below is listed in 'primitive_ops', so the
        -- lookup is expected to succeed; say which name is at fault if a
        -- name is ever added here without being added to that list.
        let get name
                = fromMaybe
                        (error $ "repa-plugin.makeTable: no selector for " ++ name)
                        (lookup name selectors)

        let table
                = Primitives
                { prim_Series           = tySeries
                , prim_Vector           = tyVector
                , prim_Ref              = tyRef

                , prim_addInt           = get "prim_addInt"
                , prim_subInt           = get "prim_subInt"
                , prim_mulInt           = get "prim_mulInt"
                , prim_divInt           = get "prim_divInt"
                , prim_modInt           = get "prim_modInt"
                , prim_remInt           = get "prim_remInt"

                , prim_eqInt            = get "prim_eqInt"
                , prim_neqInt           = get "prim_neqInt"
                , prim_gtInt            = get "prim_gtInt"
                , prim_geInt            = get "prim_geInt"
                , prim_ltInt            = get "prim_ltInt"
                , prim_leInt            = get "prim_leInt"

                , prim_newRefInt        = get "prim_newRefInt"
                , prim_readRefInt       = get "prim_readRefInt"
                , prim_writeRefInt      = get "prim_writeRefInt"

                , prim_newRefInt_T2     = get "prim_newRefInt_T2"
                , prim_readRefInt_T2    = get "prim_readRefInt_T2"
                , prim_writeRefInt_T2   = get "prim_writeRefInt_T2"

                , prim_newVectorInt     = get "prim_newVectorInt"
                , prim_readVectorInt    = get "prim_readVectorInt"
                , prim_writeVectorInt   = get "prim_writeVectorInt"
                , prim_sliceVectorInt   = get "prim_sliceVectorInt"

                , prim_rateOfSeries     = get "prim_rateOfSeries"
                , prim_loop             = get "prim_loop"
                , prim_guard            = get "prim_guard"
                , prim_nextInt          = get "prim_nextInt"
                , prim_nextInt_T2       = get "prim_nextInt_T2" }

        return $ Just (table, bs)

 | otherwise
 = return Nothing

 where  -- Type of the field with the given name, if the constructor has one.
        fieldType dc labels str
         = liftM (G.dataConFieldType dc)
         $ find (\n -> stringOfName n == str) labels
-- | Make a selector binding for each of the named fields of the table.
--
--   Returns the bindings, along with an association list from each field
--   name to the expression and type produced for it by 'makeSelector'.
--   Both lists follow the order of the names given.
makeSelectors
        :: G.Var
        -> [String]
        -> UniqSM ([G.CoreBind], [(String, (G.CoreExpr, G.Type))])

makeSelectors v strs
 = do   results <- mapM (makeSelector v) strs
        let binds       = map fst results
            exprTypes   = map snd results
        return (binds, zip strs exprTypes)
-- | Make the selector binding for a single named field of the table.
--
--   The type of the table variable must be an application of an algebraic
--   type constructor with a single data constructor (matched as
--   @DataTyCon [dc] False@) that has a field with the given name.
--   Calls 'error' if any of that does not hold.
makeSelector
        :: G.Var
        -> String
        -> UniqSM (G.CoreBind, (G.CoreExpr, G.Type))

makeSelector v strField
 = case lookupField of
        Just (dc, field)
         -> makeSelector' dc field (G.Var v) tTable

        Nothing
         -> error $ "repa-plugin.makeSelector: can't find primitive named " ++ strField

 where  tTable  = G.varType v

        -- Find the data constructor of the table type and the label of the
        -- wanted field, in the Maybe monad. The guard runs before
        -- algTyConRhs, so that is only applied to algebraic tycons.
        lookupField
         = do   tc      <- G.tyConAppTyCon_maybe tTable
                guard (G.isAlgTyCon tc)
                dc      <- case G.algTyConRhs tc of
                                G.DataTyCon [dc'] False -> Just dc'
                                _                       -> Nothing
                field   <- find (\n -> stringOfName n == strField)
                         $ G.dataConFieldLabels dc
                return (dc, field)
-- | Build a binding that projects one field out of the table.
--
--   The right-hand side is a case expression (made with 'G.mkWildCase')
--   that scrutinises the table expression, has a single alternative for
--   the given data constructor, and returns the variable bound to the
--   wanted field. The new binder is created with 'newDummyExportedVar',
--   named after the field label, and has the type of the field.
--
--   Returns the binding, together with a reference to the new binder and
--   the field type.
makeSelector'
        :: G.DataCon            -- ^ Data constructor of the table.
        -> G.FieldLabel         -- ^ Label of the field to project out.
        -> G.CoreExpr           -- ^ Expression for the table itself.
        -> G.Type               -- ^ Type of the table.
        -> UniqSM (G.CoreBind, (G.CoreExpr, G.Type))

makeSelector' dc labelWanted xTable tTable
 = do   (fieldBinders, vField)
                <- makeFieldBinders dc labelWanted

        vPrim   <- newDummyExportedVar (stringOfName labelWanted) tField

        let alt = (G.DataAlt dc, fieldBinders, G.Var vField)
            rhs = G.mkWildCase xTable tTable tField [alt]

        return (G.NonRec vPrim rhs, (G.Var vPrim, tField))

 where  tField  = G.dataConFieldType dc labelWanted
-- | Make the pattern binders for every field of a data constructor.
--
--   The wanted field gets a fresh variable made by 'newDummyVar'; each of
--   the other fields gets a wildcard binder of that field's type.
--   Returns all the binders in field order, plus the variable for the
--   wanted field.
makeFieldBinders
        :: G.DataCon
        -> G.FieldLabel
        -> UniqSM ([G.Var], G.Var)

makeFieldBinders dc labelWanted
 = do   vWanted <- newDummyVar "wanted" (typeOfField labelWanted)

        let binderFor label
                | label == labelWanted  = vWanted
                | otherwise             = G.mkWildValBinder (typeOfField label)

        return (map binderFor (G.dataConFieldLabels dc), vWanted)

 where  typeOfField = G.dataConFieldType dc
-- | Take the occurrence-name string of a 'Name.Name'.
stringOfName :: Name.Name -> String
stringOfName = Occ.occNameString . Name.nameOccName