module Data.Array.Repa.Plugin.Primitives
        ( Primitives (..)
        , slurpPrimitives)
where
import Data.Array.Repa.Plugin.ToGHC.Var
import Data.List
import Data.Maybe
import Control.Monad

import qualified HscTypes       as G
import qualified CoreSyn        as G
import qualified MkCore         as G
import qualified DataCon        as G
import qualified TyCon          as G
import qualified Type           as G
import qualified Var            as G
import qualified OccName        as Occ
import qualified Name           as Name

import UniqSupply               as G
import qualified UniqSet        as US


-------------------------------------------------------------------------------
-- | Table of GHC core expressions to use to invoke the primitives
--   needed by the lowering transform.
--
--   Built by `makeTable` from a user-supplied record bound to the top-level
--   name "repa_primitives". The three type fields hold the types of the
--   like-named fields of that record. Each operator field pairs an
--   expression referring to a generated top-level selector binding with
--   the type of the record field it selects.
--
--   The operator field names here must also appear in `primitive_ops`,
--   which `makeTable` uses to build the selectors.
data Primitives
        = Primitives
        { prim_Series           :: !G.Type      -- ^ Type of the table's @prim_Series@ field.
        , prim_Vector           :: !G.Type      -- ^ Type of the table's @prim_Vector@ field.
        , prim_Ref              :: !G.Type      -- ^ Type of the table's @prim_Ref@ field.

          -- Arith Int
        , prim_addInt           :: (G.CoreExpr, G.Type)
        , prim_subInt           :: (G.CoreExpr, G.Type)
        , prim_mulInt           :: (G.CoreExpr, G.Type)
        , prim_divInt           :: (G.CoreExpr, G.Type)
        , prim_modInt           :: (G.CoreExpr, G.Type)
        , prim_remInt           :: (G.CoreExpr, G.Type)

          -- Eq Int
        , prim_eqInt            :: (G.CoreExpr, G.Type)
        , prim_neqInt           :: (G.CoreExpr, G.Type)
        , prim_gtInt            :: (G.CoreExpr, G.Type)
        , prim_geInt            :: (G.CoreExpr, G.Type)
        , prim_ltInt            :: (G.CoreExpr, G.Type)
        , prim_leInt            :: (G.CoreExpr, G.Type)

          -- Ref Int
        , prim_newRefInt        :: (G.CoreExpr, G.Type)
        , prim_readRefInt       :: (G.CoreExpr, G.Type)
        , prim_writeRefInt      :: (G.CoreExpr, G.Type)
          -- Ref (Int,Int)
        , prim_newRefInt_T2     :: (G.CoreExpr, G.Type)
        , prim_readRefInt_T2    :: (G.CoreExpr, G.Type)
        , prim_writeRefInt_T2   :: (G.CoreExpr, G.Type)

          -- Vector Int
        , prim_newVectorInt     :: (G.CoreExpr, G.Type)
        , prim_readVectorInt    :: (G.CoreExpr, G.Type)
        , prim_writeVectorInt   :: (G.CoreExpr, G.Type)
        , prim_sliceVectorInt   :: (G.CoreExpr, G.Type)

          -- Loop
        , prim_loop             :: (G.CoreExpr, G.Type)
        , prim_guard            :: (G.CoreExpr, G.Type)
        , prim_rateOfSeries     :: (G.CoreExpr, G.Type)
        , prim_nextInt          :: (G.CoreExpr, G.Type)
        , prim_nextInt_T2       :: (G.CoreExpr, G.Type)
        }


-- | Names of all the primitive types.
--   Each is the name of a type field of `Primitives` above with its
--   @prim_@ prefix removed (e.g. "Series" for @prim_Series@).
--
--   Not referenced elsewhere in this module (hence the leading underscore);
--   `makeTable` looks the type fields up by their full names.
_primitive_types :: [String]
_primitive_types
 =      [ "Series"
        , "Vector"
        , "Ref" ]


-- | Names of all the primitive operators.
--   These must match the operator field names of `Primitives` above.
--   They must also name fields of the user-supplied "repa_primitives"
--   table: `makeTable` builds one selector per name, and `makeSelector`
--   fails with an error if the table has no field with that name.
primitive_ops :: [String]
primitive_ops
 =      -- Arith Int
        [ "prim_addInt"
        , "prim_subInt"
        , "prim_mulInt"
        , "prim_divInt"
        , "prim_modInt"
        , "prim_remInt"

        -- Eq Int
        , "prim_eqInt"
        , "prim_neqInt"
        , "prim_gtInt"
        , "prim_geInt"
        , "prim_ltInt"
        , "prim_leInt"

        -- Ref Int
        , "prim_newRefInt"
        , "prim_readRefInt"
        , "prim_writeRefInt"
        -- Ref (Int,Int)
        , "prim_newRefInt_T2"
        , "prim_readRefInt_T2"
        , "prim_writeRefInt_T2"

        -- Vector Int
        , "prim_newVectorInt"
        , "prim_readVectorInt"
        , "prim_writeVectorInt"
        , "prim_sliceVectorInt"

        -- Loop
        , "prim_loop"
        , "prim_guard"
        , "prim_rateOfSeries"
        , "prim_nextInt"
        , "prim_nextInt_T2" ]


-------------------------------------------------------------------------------
-- | Try to slurp the primitive table from a GHC module.
--
--   The table should be in a top-level binding named "repa_primitives".
--   If we find it, then we add more top-level functions to the module 
--   that select the individual primitives, then build a table of expressions
--   that can be used to access them.
--
--   Returns `Nothing` if there is no such binding, or if `makeTable` cannot
--   build a table from it (for example, when its type is not an algebraic
--   data type with a single data constructor).
slurpPrimitives 
        :: G.ModGuts 
        -> UniqSM (Maybe (Primitives, G.ModGuts))

slurpPrimitives guts
 | Just vTable  <- listToMaybe 
                $  mapMaybe findTableFromTopBind 
                $  G.mg_binds guts
 = do   
        -- makeTable may legitimately yield Nothing. Propagate that to our
        -- caller instead of matching `Just` in the monad, where a mismatch
        -- would go through `fail` and crash the plugin.
        mTable  <- makeTable vTable
        return $ case mTable of
                  Nothing               -> Nothing
                  Just (prims, bsMoar)  -> Just (prims, addSelectors bsMoar)

 | otherwise
 =      return Nothing

 where  -- Splice the new selector bindings in directly after the table,
        -- and add their names to the module's used-names set.
        addSelectors bsMoar
         = guts
         { G.mg_binds
                = insertAfterTable bsMoar
                $ G.mg_binds guts

         , G.mg_used_names
                = US.addListToUniqSet (G.mg_used_names guts)
                $ [G.varName b | G.NonRec b _ <- bsMoar ]}
        

-------------------------------------------------------------------------------
-- | Try to find the primitive table in this top level binding.
--   Only non-recursive bindings are inspected; recursive groups never
--   yield the table.
findTableFromTopBind :: G.CoreBind -> Maybe G.Var
findTableFromTopBind (G.NonRec b _)     = findTableFromBinding b
findTableFromTopBind G.Rec{}            = Nothing


-- | Check whether this binder is the primitive table.
--   The table must be bound to a variable whose occurrence name is
--   exactly "repa_primitives"; any other binder yields `Nothing`.
findTableFromBinding :: G.CoreBndr -> Maybe G.Var
findTableFromBinding b
 = if stringOfName (G.varName b) == "repa_primitives"
        then Just b
        else Nothing


-------------------------------------------------------------------------------
-- | Insert some top-level bindings directly after the primitive table.
--   The insertion point is the first non-recursive binding named
--   "repa_primitives". If there is no such binding, the original list is
--   returned unchanged and the new bindings are dropped.
insertAfterTable :: [G.CoreBind] -> [G.CoreBind] -> [G.CoreBind]
insertAfterTable bsMore bs
 = case break isTable bs of
        (bsBefore, bTable : bsAfter)
         -> bsBefore ++ bTable : bsMore ++ bsAfter

        (bsBefore, [])
         -> bsBefore

 where  isTable = isJust . findTableFromTopBind


-------------------------------------------------------------------------------
-- | Create top-level projection functions based on the primitive table
--   attached to this variable.
--
--   Returns `Nothing` if the variable's type is not an algebraic type with a
--   single data constructor of the expected shape. Otherwise returns the
--   table of primitives together with the new selector bindings, which the
--   caller must add to the module.
--
--   If the table's constructor lacks one of the required fields, this fails
--   with an `error` that names the missing field, rather than with an
--   anonymous irrefutable-pattern failure.
makeTable 
        :: G.Var 
        -> UniqSM (Maybe (Primitives, [G.CoreBind]))

makeTable v
 | t                      <- G.varType v
 , Just tc                <- G.tyConAppTyCon_maybe t
 , G.isAlgTyCon tc
 , G.DataTyCon [dc] False <- G.algTyConRhs tc
 = do
        let labels
                = G.dataConFieldLabels dc

        -- Load types from their proxy fields.
        --   Only the type of each such field is used, never its value.
        let fieldType name
                = maybe (error $ "repa-plugin.makeTable: can't find type field named "
                                ++ name)
                        (G.dataConFieldType dc)
                $ find (\n -> stringOfName n == name) labels

        -- Build table of selectors for all the operators.
        --   Every name in primitive_ops gets a selector, so a failed lookup
        --   here means the record below and primitive_ops disagree.
        (bs, selectors)     <- makeSelectors v primitive_ops
        let get name
                = fromMaybe (error $ "repa-plugin.makeTable: no selector for primitive named "
                                ++ name)
                $ lookup name selectors

        let table      
                = Primitives
                { prim_Series           = fieldType "prim_Series"
                , prim_Vector           = fieldType "prim_Vector"
                , prim_Ref              = fieldType "prim_Ref"

                -- Arith Int
                , prim_addInt           = get "prim_addInt"
                , prim_subInt           = get "prim_subInt"
                , prim_mulInt           = get "prim_mulInt"
                , prim_divInt           = get "prim_divInt"
                , prim_modInt           = get "prim_modInt"
                , prim_remInt           = get "prim_remInt"

                -- Eq Int
                , prim_eqInt            = get "prim_eqInt"
                , prim_neqInt           = get "prim_neqInt"
                , prim_gtInt            = get "prim_gtInt"
                , prim_geInt            = get "prim_geInt"
                , prim_ltInt            = get "prim_ltInt"
                , prim_leInt            = get "prim_leInt"

                -- Ref Int
                , prim_newRefInt        = get "prim_newRefInt"
                , prim_readRefInt       = get "prim_readRefInt"
                , prim_writeRefInt      = get "prim_writeRefInt"
                -- Ref (Int,Int)
                , prim_newRefInt_T2     = get "prim_newRefInt_T2"
                , prim_readRefInt_T2    = get "prim_readRefInt_T2"
                , prim_writeRefInt_T2   = get "prim_writeRefInt_T2"

                -- Vector Int
                , prim_newVectorInt     = get "prim_newVectorInt"
                , prim_readVectorInt    = get "prim_readVectorInt"
                , prim_writeVectorInt   = get "prim_writeVectorInt"
                , prim_sliceVectorInt   = get "prim_sliceVectorInt"

                -- Loop
                , prim_loop             = get "prim_loop"
                , prim_guard            = get "prim_guard"
                , prim_rateOfSeries     = get "prim_rateOfSeries"
                , prim_nextInt          = get "prim_nextInt"
                , prim_nextInt_T2       = get "prim_nextInt_T2" }

        return $ Just (table, bs)

 | otherwise
 = return Nothing


-------------------------------------------------------------------------------
-- | Make the selector table.
--
--   Builds one top-level selector binding per primitive name. Each name is
--   paired with the expression and type that `makeSelector` produced for it.
makeSelectors
        :: G.Var                -- ^ Core variable bound to our primitive table.
        -> [String]             -- ^ Names of all the primitives.
        -> UniqSM ([G.CoreBind], [(String, (G.CoreExpr, G.Type))])

makeSelectors v strs
 = liftM split $ mapM (makeSelector v) strs

 where  split results
         = ( map fst results
           , zip strs (map snd results))


-------------------------------------------------------------------------------
-- | Build a CoreExpr that produces the primitive with the given name.
--
--   Yields the new top-level selector binding, paired with an expression
--   referring to it and its type. Calls `error` if the table's type is not
--   a single-constructor data type with a field of that name.
makeSelector
        :: G.Var                -- ^ Core variable bound to our primitive table.
        -> String               -- ^ Name of the primitive we want.
        -> UniqSM (G.CoreBind, (G.CoreExpr, G.Type))

makeSelector v strField
 = case findField of
        Just (dc, field)
         -> makeSelector' dc field (G.Var v) tTable

        Nothing
         -> error $ "repa-plugin.makeSelector: can't find primitive named " ++ strField

 where  tTable  = G.varType v

        -- The table's only data constructor, and the wanted field label.
        findField
         | Just tc                <- G.tyConAppTyCon_maybe tTable
         , G.isAlgTyCon tc
         , G.DataTyCon [dc] False <- G.algTyConRhs tc
         = fmap ((,) dc)
         $ find (\n -> stringOfName n == strField)
         $ G.dataConFieldLabels dc

         | otherwise
         = Nothing


-- | Build the top-level selector binding for one field of the primitive
--   table, returning it with an expression referring to it and its type.
makeSelector'
        :: G.DataCon            -- ^ Data constructor for the primitive table.
        -> G.FieldLabel         -- ^ Name of the field to project out.
        -> G.CoreExpr           -- ^ Expression to produce the table.
        -> G.Type               -- ^ Type of the table.
        -> UniqSM (G.CoreBind, (G.CoreExpr, G.Type))

makeSelector' dc labelWanted xTable tTable
 = do   -- One binder per constructor field: the wanted field gets a real
        -- variable, every other field a wildcard.
        (bsFields, vField) <- makeFieldBinders dc labelWanted

        -- Binder for the new top-level selector, named after the field.
        vSel    <- newDummyExportedVar (stringOfName labelWanted) tField

        -- case xTable of { dc b1 .. bn -> vField }
        let xSel = G.mkWildCase xTable tTable tField
                        [(G.DataAlt dc, bsFields, G.Var vField)]

        return (G.NonRec vSel xSel, (G.Var vSel, tField))

 where  tField  = G.dataConFieldType dc labelWanted
                

-- | Make one binder for each field of the data constructor, in field order.
--   The binder for the wanted field is a variable created with
--   `newDummyVar`; all the other fields get wildcard binders.
makeFieldBinders 
        :: G.DataCon               -- ^ Data constructor for the primitive table.
        -> G.FieldLabel            -- ^ The field we want to project out.
        -> UniqSM ([G.Var], G.Var) -- ^ All binders, and the one for our desired field.

makeFieldBinders dc labelWanted
 = do   vWanted <- newDummyVar "wanted" (fieldType labelWanted)
        let binderFor label
                | label == labelWanted  = vWanted
                | otherwise             = G.mkWildValBinder (fieldType label)
        return (map binderFor (G.dataConFieldLabels dc), vWanted)

 where  fieldType = G.dataConFieldType dc


-- Utils ----------------------------------------------------------------------
-- | Convert a GHC name to a string: the string of its occurrence name.
stringOfName :: Name.Name -> String
stringOfName = Occ.occNameString . Name.nameOccName