module System.Plugin (
-- * Methods    
    pdynload
    ) where

import Control.Exception
import Control.Monad (when, liftM)
import Data.IORef
import Data.List (isPrefixOf, partition)
import GHC.Paths (libdir, ghc)
import MonadUtils (liftIO)
import System.Directory
import System.Exit
import System.IO
import System.Process
import System.Time
import Unsafe.Coerce
import qualified DynFlags
import qualified Exception
import qualified GHC
import qualified HscTypes
import qualified IOEnv
import qualified Linker
import qualified LoadIface
import qualified Maybes
import qualified Module
import qualified Name
import qualified OccName
import qualified Outputable
import qualified PackageConfig
import qualified Packages
import qualified SrcLoc
import qualified TcRnTypes
import qualified UniqSupply
import qualified Unique

-- | Outcome of the runtime type check performed by 'typeCheck'.
data TypeCheckStatus
  = TypeMatch            -- ^ The symbol has the requested type.
  | TypeMismatch String  -- ^ It does not; carries the compiler output explaining why.
  deriving (Show, Eq, Ord)

-- | Polymorphic dynamic loading.
--
-- Resolves the specified symbol to any given type.  This means linking the package
-- containing it if it is not already linked, extracting the value of that symbol,
-- and returning that value.
--
-- Here is the simplest demo for testing:
--
-- > module Main where
-- >
-- > import System.Plugin
-- > import Unsafe.Coerce
-- >
-- > main = do
-- >   val <- pdynload ("Prelude", "reverse") ("", "String -> String")
-- >   let str = case val of
-- >               Just v  -> (unsafeCoerce v :: String -> String) "hello"
-- >               Nothing -> "Load failed."
-- >   print str
--
-- 'pdynload' type-checks the symbol against @typ@ at runtime before loading
-- it.  Applying 'unsafeCoerce' to the result is therefore safe only when you
-- coerce to exactly the type you passed as @typ@; coercing to any other type
-- is undefined behaviour.
--
pdynload :: (String, String)
         -- ^ A tuple (@symbolModule@, @symbol@), specifying a symbol in module
         --
         --   @symbolModule@ is a fully-qualified module name, ie @\"Data.List\"@
         --
         --   @symbol@ is an unqualified symbol name, ie @\"reverse\"@.
         -> (String, String)
         -- ^ A tuple (@typModule@, @typ@), specifying a type in module
         --
         --   @typModule@ is a fully-qualified module name, ie @\"Prelude\"@ ,
         --   leave this string empty if the type is defined in 'Prelude'.
         --
         --   @typ@ is a type expression, ie @\"String -> String\"@.
         -> IO (Maybe a)
         -- ^ If the specified symbol is found, 'Just' its value. Otherwise, 'Nothing'.
pdynload (symbolModule, symbol) (typModule, typ) = do
  putStr "* Check type ... "
  status <- typeCheck (symbolModule, symbol) (typModule, typ)
  case status of
    TypeMismatch message -> do
      putStrLn "failed : "
      putStrLn message
      return Nothing
    TypeMatch -> do
      putStrLn "done."
      GHC.defaultErrorHandler DynFlags.defaultDynFlags $
        GHC.runGhc (Just libdir) $ do
          -- Update the DynFlags of current session.
          sessionFlags <- GHC.getSessionDynFlags
          _ <- GHC.setSessionDynFlags sessionFlags
          -- Initialise package information; 'flags' is the package-aware result.
          (flags, _) <- liftIO $ Packages.initPackages sessionFlags
          -- Initialise the dynamic linker.
          liftIO $ Linker.initDynLinker flags

          -- Get package name for loading by searching the ghc package database.
          pName <- liftIO $ do
            putStrLn $ "* Lookup package of module " ++ symbolModule
            lookupPackageName flags symbolModule

          case pName of
            Nothing -> return Nothing
            Just pn -> do
              hscEnv <- GHC.getSession
              -- The symbol may be re-exported from another module (possibly in
              -- another package), so follow re-exports to its defining location.
              resolved <- liftIO $ parseSymbol (pn, symbolModule, symbol) hscEnv flags
              case resolved of
                Nothing -> return Nothing
                Just target@(definingPackage, _, _) ->
                  -- Link the package that actually defines the symbol; it can
                  -- differ from 'pn' when the symbol is re-exported across packages.
                  load target (Module.stringToPackageId definingPackage) flags
              
-- | Type check at runtime.
--
-- Instead of using any form of user-land dynamic types for runtime type checking,
-- we simply invoke the regular type checker on a generated test module at runtime.
-- This approach has the advantage that it is entirely independent of any extensions
-- of the type system supported by the underlying implementation.
-- It does not require any extension, but it also does not inhibit the use of any features
-- of the type system.
--
-- Example
--
-- > typeCheck ("Data.List", "reverse") ("Prelude", "String -> String")
--
-- will generate a test module like below, where @N@ is a timestamp used to
-- keep the file name unique:
--
-- > module TypeCheckN where
-- > import qualified Data.List
-- > import Prelude
-- > typecheck = Data.List.reverse :: String -> String
--
-- It then runs @ghc -e \"Prelude.const () typecheck\" TypeCheckN.hs@ in the
-- temporary directory, with ghc's stdout and stderr redirected to a log file.
-- Referencing @typecheck@ forces the module to load (and so to type-check)
-- without requiring a 'Show' instance for the checked type, which plain
-- @ghc -e typecheck@ would demand (e.g. for any function type).
-- If ghc exits successfully, 'TypeMatch' is returned; otherwise
-- 'TypeMismatch' with the contents of the log.  The temporary source and log
-- files are removed afterwards, even if an exception is thrown.
--
typeCheck :: (String, String)
          -- ^ A tuple (@symbolModule@, @symbol@), specifying a symbol in module
          --
          --   @symbolModule@ is a fully-qualified module name, ie @\"Data.List\"@
          --
          --   @symbol@ is an unqualified symbol name, ie @\"reverse\"@.
          -> (String, String)
          -- ^ A tuple (@typModule@, @typ@), specifying a type in module
          --
          --   @typModule@ is a fully-qualified module name, ie @\"Prelude\"@ ,
          --   leave this string empty if the type is defined in Prelude.
          --
          --   @typ@ is a type expression, ie @\"String -> String\"@.
          -> IO TypeCheckStatus
          -- ^ If type match return 'TypeMatch', otherwise return 'TypeMismatch' with error.
typeCheck (symbolModule, symbol) (typModule, typ) = do
  -- A picosecond timestamp keeps the temporary file names unique.
  uniqueId <- getPicoseconds
  -- Use the platform temporary directory rather than a hard-coded "/tmp/".
  tempDir <- fmap (++ "/") getTemporaryDirectory

  let tempModule      = "TypeCheck" ++ show uniqueId
      tempFile        = tempDir ++ tempModule ++ ".hs"
      errFile         = tempDir ++ tempModule ++ ".log"
      checkExpression = "typecheck"
      -- `ghc -e` prints the value it evaluates, so evaluating the binding
      -- itself fails for types without a Show instance (e.g. functions) even
      -- when the types match.  Mentioning it under `const ()` still requires
      -- the module to load and type-check, but always prints "()".
      probeExpression = "Prelude.const () " ++ checkExpression
      -- typecheck = SymbolModule.symbol :: Type
      sourceCode = unlines
        [ "module " ++ tempModule ++ " where"
        , "import qualified " ++ symbolModule
        , if null typModule
            then "import Prelude"
            else "import " ++ typModule
        , checkExpression ++ " = " ++ symbolModule ++ "." ++ symbol
            ++ " :: " ++ typ
        ]
      -- Write the module and probe it; ghc's stdout and stderr both go to
      -- errFile so that the diagnostic can be returned on failure.
      runCheck = do
        sourceHandle <- openFile tempFile WriteMode
        hWrite sourceHandle sourceCode
        result <- bracket (openFile errFile WriteMode) hClose $ \errHandle -> do
          ghcProc <- runProcess ghc ["-e", probeExpression, tempFile]
                                (Just tempDir) Nothing Nothing
                                (Just errHandle) (Just errHandle)
          waitForProcess ghcProc
        case result of
          ExitSuccess   -> return TypeMatch
          ExitFailure _ -> do
            message <- readFile errFile
            -- Force the lazily-read log before the file is deleted below.
            length message `seq` return (TypeMismatch message)

  -- Always clean up the temporary source and log files, even on exceptions.
  runCheck `finally` mapM_ tryRemoveFile [tempFile, errFile]

-- | Internal load function for 'pdynload'.
--
-- Links @packageId@ (and the packages it depends on), then asks the linker
-- for the value of @moduleName.symbolName@ in that package and coerces it to
-- the caller's result type.  @packageName@ is used only in diagnostic
-- messages; the caller is expected to pass the package id corresponding to it.
--
-- A 'GHC.CmdLineError' is reported as an unknown package, and a
-- 'GHC.ProgramError' raised during lookup is reported as an unknown module or
-- symbol.  In both cases the result is 'Nothing'.  Any other
-- 'GHC.GhcException' is re-thrown unchanged, rather than turning into a
-- pattern-match failure inside the handler.
load :: (GHC.GhcMonad m)
             => (String, String, String)
             -> Module.PackageId
             -> GHC.DynFlags
             -> m (Maybe a)
load (packageName, moduleName, symbolName) packageId flags =
    Exception.ghandle onPackageError $ do
      -- Link exactly the specified packages, and their dependents
      -- (unless of course they are already linked).
      -- The dependents are linked automatically,
      -- and it doesn't matter what order you specify the input packages.
      liftIO $ Linker.linkPackages flags [packageId]

      Exception.ghandle onLookupError $ do
        liftIO $ putStrLn $ "* Linking " ++ packageName ++ ":" ++ moduleName ++ "." ++ symbolName ++ " ..."
        session <- GHC.getSession
        -- Create a name which definitely originates in the given module.
        let name = Name.mkExternalName
                     (Unique.mkBuiltinUnique 0)
                     (Module.mkModule packageId (Module.mkModuleName moduleName))
                     (OccName.mkVarOcc symbolName)
                     SrcLoc.noSrcSpan
        -- Get the HValue associated with the given name.
        -- May cause loading the module that contains the name.
        result <- liftIO $ Linker.getHValue session name
        -- Relies on the caller ('pdynload') having type-checked the symbol
        -- against the type it will be used at.
        return $ Just $ unsafeCoerce result
  where
    -- Package errors become Nothing; every other GHC exception propagates.
    onPackageError exception = case exception of
      GHC.CmdLineError _ -> do
        liftIO $ putStrLn $ "Unknown package " ++ packageName
        return Nothing
      _ -> liftIO $ throwIO exception

    -- Module / symbol lookup errors become Nothing; others propagate.
    onLookupError exception = case exception of
      GHC.ProgramError string -> do
        liftIO $ putStrLn $
          if hasPrefix string "Failed to load interface "
            then "Unknown module '" ++ moduleName ++ "'"
                 ++ " in package '" ++ packageName ++ "'"
            else "Unknown symbol '" ++ symbolName ++ "'"
                 ++ " in module '" ++ moduleName ++ "'"
                 ++ " in package '" ++ packageName ++ "'"
        return Nothing
      _ -> liftIO $ throwIO exception

-- | Look up, in the GHC package database described by @flags@, the package
-- that provides a module.
--
-- * No package contains the module: prints a message and returns 'Nothing'.
--
-- * Exactly one package contains it: returns that package's id string if the
--   module is exposed there; otherwise prints a message and returns 'Nothing'.
--
-- * Several packages contain it: prints the exposing candidates and returns
--   the first of them, or prints a message and returns 'Nothing' if none
--   exposes the module.
lookupPackageName :: DynFlags.DynFlags -> String -> IO (Maybe String)
lookupPackageName flags moduleName =
    case candidates of
      [] -> notFound
      [(config, exposed)]
        | exposed   -> return (Just (packageConfigIdString config))
        | otherwise -> do
            putStrLn $ "Module " ++ show moduleName ++ " hide in package "
                         ++ packageConfigIdString config
            return Nothing
      _ -> case filter snd candidates of
             [] -> notFound
             exposing@((firstConfig, _) : _) -> do
               let chosen = packageConfigIdString firstConfig
               putStrLn $ "Module " ++ show moduleName
                            ++ " expose in multiple packages :"
                            ++ concatMap (("\n   " ++) . packageConfigIdString . fst)
                                         exposing
               putStrLn $ "# Use package '" ++ chosen
                            ++ "' (Maybe you need specify package name)"
               return (Just chosen)
  where
    -- Every package containing the module, paired with whether it exposes it.
    candidates = Packages.lookupModuleInAllPackages flags (Module.mkModuleName moduleName)
    notFound = do
        putStrLn $ "Can't found module " ++ show moduleName
        return Nothing

-- | Find the module that actually defines a symbol.
--
-- Reads the interface (@.hi@) file of @moduleName@ in @packageName@ and
-- inspects its export list, which groups exported names by the module they
-- come from.  If the symbol is listed under @moduleName@ itself, the input
-- triple is returned.  If it is listed under another module (i.e. it is
-- re-exported), that module's package is looked up with 'lookupPackageName'
-- and the search continues recursively there.
--
-- Returns 'Nothing' (after printing a message) when the interface file
-- cannot be found, when the re-exporting module's package cannot be
-- resolved, or when the symbol is absent from the export list.
-- NOTE(review): there is no guard against re-export cycles; presumably the
-- export list always names the defining module so recursion terminates — confirm.
parseSymbol :: (String, String, String)
            -> HscTypes.HscEnv
            -> GHC.DynFlags
            -> IO (Maybe (String, String, String))
parseSymbol (packageName, moduleName, symbolName) hscEnv flags = do
    putStrLn $ "* Parse " ++ packageName ++ ":" ++ moduleName ++ "." ++ symbolName ++ " ..."
    -- Build a unique supply; the IOEnv environment below needs one.
    uniqueSupply <- UniqSupply.mkSplitUniqSupply 'a'
    uniqueSupplyIORef <- newIORef uniqueSupply
    -- Build a minimal IOEnv environment around the current session.
    -- The global and local environments are left as (); presumably
    -- findAndReadIface does not need them outside the type checker — confirm.
    let packageId = Module.stringToPackageId packageName
        module' = Module.mkModule packageId $ Module.mkModuleName moduleName
        environment = TcRnTypes.Env {
                        TcRnTypes.env_top = hscEnv,
                        TcRnTypes.env_us = uniqueSupplyIORef,
                        TcRnTypes.env_gbl = (),
                        TcRnTypes.env_lcl = ()}
    -- Find and read the interface file of module'.
    -- NOTE(review): the False argument presumably means "not a boot
    -- interface", and Outputable.empty is the context doc for errors — confirm.
    iface <- IOEnv.runIOEnv environment
                   $ LoadIface.findAndReadIface Outputable.empty module' False
    case iface of
      -- Return Nothing if the interface file can't be found.
      Maybes.Failed _ -> do
        putStrLn $ "Can't found interface file of " ++ packageName ++ ":" ++ moduleName ++ "." ++ symbolName
        return Nothing
      -- Interface found: search its export list for the symbol.
      Maybes.Succeeded (moduleInterface, hiFile) -> do
        putStrLn $ "Scan interface file " ++ hiFile ++ " ..."
        -- Export list of current module, grouped by originating module.
        let ifaceExport = HscTypes.mi_exports moduleInterface
            -- [(module name, names exported from it)]; for a type/class
            -- entry (AvailTC) every listed sub-name is included.
            exports = map (\ (mod, items) ->
                                 (Module.moduleNameString $ Module.moduleName mod
                                 ,concatMap (\item ->
                                                 case item of
                                                   HscTypes.Avail name -> [OccName.occNameString name]
                                                   HscTypes.AvailTC _ list ->
                                                       map OccName.occNameString list
                                            ) items)
                            ) ifaceExport
            -- Split names from the current module from re-exported ones.
            (currentExports, otherExports) = partition (\ (mName, _) -> mName == moduleName) exports
        case findSymbolInExportList currentExports symbolName of
          -- Return current module if found symbol in export list of current module.
          Just _ -> do
                 putStrLn $ "'" ++ symbolName ++ "' defined in " ++ packageName ++ ":" ++ moduleName
                 return $ Just (packageName, moduleName, symbolName)
          Nothing ->
              -- Parse recursively if symbol is re-export from external module.
              case findSymbolInExportList otherExports symbolName of
                Just mn -> do
                  putStrLn $ "'" ++ symbolName ++ "' is re-export from module " ++ mn
                  -- Lookup new package of external module.
                  newPackageName <- do
                            putStrLn $ "* Lookup package of module " ++ mn
                            lookupPackageName flags mn
                  case newPackageName of
                    -- Parse symbol in new package.
                    Just npn -> parseSymbol (npn, mn, symbolName) hscEnv flags
                    -- Return Nothing if package not found.
                    Nothing -> return Nothing
                -- Return Nothing if the symbol is not in the interface file's
                -- export list at all.
                Nothing -> do
                  putStrLn $ "Can't found symbol " ++ symbolName ++ " in " ++ hiFile
                  return Nothing
                 
-- | Render the package id of a package configuration as a plain string
-- (the form used as a "package name" throughout this module).
packageConfigIdString :: Packages.PackageConfig -> String
packageConfigIdString config =
    Module.packageIdString (PackageConfig.packageConfigId config)

-- | @hasPrefix string prefix@ is 'True' when @string@ starts with @prefix@.
--
-- Note the argument order: the string being searched comes first.
-- Delegates to 'isPrefixOf' rather than re-implementing it.
hasPrefix :: String -> String -> Bool
hasPrefix string prefix = prefix `isPrefixOf` string

-- | Search a list of @(module, exported symbols)@ groups for a symbol.
--
-- Returns the module of the first group whose symbol list contains it,
-- or 'Nothing' if no group does.
findSymbolInExportList :: Eq b => [(a, [b])] -> b -> Maybe a
findSymbolInExportList groups sym = foldr step Nothing groups
  where
    step (owner, symbols) rest
      | sym `elem` symbols = Just owner
      | otherwise          = rest

-- | Current clock time (from 'getClockTime') expressed in picoseconds.
-- Used to give temporary Haskell files a unique name.
getPicoseconds :: IO Integer
getPicoseconds = fmap toPicoseconds getClockTime
  where
    toPicoseconds (TOD seconds picoseconds) = seconds * 10 ^ 12 + picoseconds

-- | Write a string to a handle, then close the handle.
--
-- The handle is closed even if writing throws, so a failed write does not
-- leak the file descriptor.
hWrite :: Handle -> String -> IO ()
hWrite hdl src = hPutStr hdl src `finally` hClose hdl

-- | Delete a file if it exists at the time of the check; otherwise do nothing.
tryRemoveFile :: FilePath -> IO ()
tryRemoveFile filepath =
  doesFileExist filepath >>= \present -> when present (removeFile filepath)