module System.Plugins.Safe
  (Extension,
   Arg,
   Symbol,
   LoadStatus (..),
   loadOneValue)
  where

import Control.Exception (bracket_)
import Data.Char
import System.Directory
import System.FilePath
import Text.Printf

import Language.Haskell.Exts
import System.Plugins hiding (Module, loadModule)
import System.Plugins.Utils (Arg)
import System.Unix.Directory

-- | Upper-case the first character of a string, leaving the rest unchanged.
-- The empty string is returned as-is.
capitalize :: String -> String
capitalize str = case str of
  []          -> []
  (c : rest)  -> toUpper c : rest

-- | Parser configuration: 'defaultParseMode' with 'parseFilename' set to the
-- given name and exactly the given language extensions enabled.
parseMode :: String -> [Extension] -> ParseMode
parseMode fileName enabled =
  defaultParseMode
    { parseFilename = fileName
    , extensions    = enabled
    }

-- | Replace the name of a module; every other component of the module
-- (location, pragmas, warning text, exports, imports, declarations) is
-- carried over unchanged.
setModuleName :: String -> Module -> Module
setModuleName newName (Module srcLoc _ prags warning exps imps body) =
  Module srcLoc (ModuleName newName) prags warning exps imps body

-- | The module's name as a plain string.
getModuleName :: Module -> String
getModuleName m = case m of
  Module _ (ModuleName n) _ _ _ _ _ -> n

-- | Unwrap a 'ModuleName' to the string it holds.
fromModuleName :: ModuleName -> String
fromModuleName mn = case mn of
  ModuleName str -> str

-- | Parse a Haskell source file with the given language extensions enabled,
-- then rename its module to the capitalized base name of the file
-- (e.g. @dir/plugin.hs@ becomes module @Plugin@).
--
-- Fails in 'IO' (via 'fail') with a message giving the file, line and
-- column when the source does not parse.
loadModule :: [Extension] -> FilePath -> IO Module
loadModule exts sourcePath = do
    -- Give the parser the full path rather than just the base name. Parse
    -- errors report this name, so it should identify the real file. It also
    -- keeps the extension visible: haskell-src-exts presumably uses it to
    -- detect literate (.lhs) sources -- confirm against the library docs.
    result <- parseFileWithMode (parseMode sourcePath exts) sourcePath
    case result of
      ParseOk parsed      -> return $ setModuleName (capitalize name) parsed
      ParseFailed loc err -> fail $ errMsg loc err
  where
    -- Base name without directory or extension; becomes the module name.
    name = takeBaseName sourcePath
    errMsg loc err = printf "Parse error in %s, line %d, col. %d: %s"
                            (srcFilename loc)
                            (srcLine loc)
                            (srcColumn loc)
                            err

-- | Rewrite a parsed plugin module into its restricted form:
--
--   * the module's own pragmas are replaced by a single @LANGUAGE@ pragma
--     listing @exts@ (or by no pragma at all when @exts@ is empty);
--
--   * the module's warning text is cleared;
--
--   * only imports of modules named in @allowedImports@ are kept, and an
--     import declaration for each module in @forcedImports@ is appended;
--
--   * the export list becomes exactly @symbol@;
--
--   * @foreign import@ and @foreign export@ declarations are dropped.
fixModule :: [Extension] -> [String] -> [String] -> String -> Module -> Module
fixModule exts forcedImports allowedImports symbol (Module loc modName _ _ _ imports decls) =
    Module loc modName langPragmas Nothing exportList importList keptDecls
  where
    -- Synthetic source location attached to the generated pragma and imports.
    origin = SrcLoc (fromModuleName modName) 0 0

    langPragmas
      | null exts = []
      | otherwise = [LanguagePragma origin [Ident (show e) | e <- exts]]

    importList = kept ++ forced
      where
        kept   = [i | i <- imports, fromModuleName (importModule i) `elem` allowedImports]
        forced = [ ImportDecl origin (ModuleName m) False False Nothing Nothing Nothing
                 | m <- forcedImports ]

    exportList = Just [EVar (UnQual (Ident symbol))]

    keptDecls = filter (not . isForeign) decls

    isForeign (ForImp {}) = True
    isForeign (ForExp {}) = True
    isForeign _           = False

-- | Pretty-print a module and write the resulting source text to @path@
-- using 'writeFile'.
writeModule :: FilePath -> Module -> IO ()
writeModule path = writeFile path . prettyPrint

-- | Same shape as 'withTemporaryDirectory', but it ignores the template and
-- just runs the action with @"."@ (the current directory). It is not called
-- anywhere in this module. NOTE(review): presumably a debugging stand-in that
-- keeps generated files around; consider removing it if it is truly unused.
withTemporaryDirectory' :: template -> (FilePath -> r) -> r
withTemporaryDirectory' _template action = action "."

-- | Load one specified symbol from a Haskell source file.
-- Before compilation the source is rewritten so that it is:
--
--   * forced to use exactly the specified language extensions (the file's
--     own module pragmas are discarded);
--
--   * forced to import the specified modules;
--
--   * allowed to import only the specified set of modules;
--
--   * forbidden to import any other modules (such imports are removed);
--
--   * forbidden to use any FFI declarations.
--
-- Unsafe declarations are simply removed from the module rather than
-- reported as errors.
--
-- The rewritten module is written to a temporary directory and built there
-- with 'makeAll'. Both the module name and the written file name are the
-- capitalized base name of the source, so the source file name need not
-- start with a capital letter. It must still form a valid module name
-- after capitalization.
--
-- The process-wide working directory is switched to the temporary directory
-- while building and loading. It is restored afterwards, even if
-- compilation or loading throws.
loadOneValue ::
                [Arg]        -- ^ Any command-line arguments for compiler
             -> [FilePath]   -- ^ Include paths
             -> FilePath     -- ^ Source file name
             -> [Extension]  -- ^ Language extensions to enable
             -> [String]     -- ^ Modules the plugin is forced to import
             -> [String]     -- ^ Modules the plugin is allowed to import
             -> Symbol       -- ^ Symbol to load
             -> IO (LoadStatus a)
loadOneValue args paths sourcePath exts forcedImports allowedImports symbol = do
  parsed <- loadModule exts sourcePath
  let safeModule = fixModule exts forcedImports allowedImports symbol parsed
  pwd <- getCurrentDirectory
  withTemporaryDirectory "safe-plugin" $ \dir ->
    -- Restore the original working directory even when 'makeAll' or 'load'
    -- throws. Do it while still inside the 'withTemporaryDirectory'
    -- callback, i.e. before any cleanup of @dir@ happens.
    bracket_ (setCurrentDirectory dir) (setCurrentDirectory pwd) $ do
      let newPath = dir </> fileName
      writeModule newPath safeModule
      mst <- makeAll newPath args
      case mst of
        MakeFailure errs  -> return (LoadFailure errs)
        MakeSuccess _ obj -> load obj paths [] symbol
  where
    -- Name the written file after the capitalized module name that
    -- 'loadModule' assigns, so the module/file pair handed to GHC agrees.
    fileName = capitalize (takeFileName sourcePath)