{-# LANGUAGE ConstraintKinds, FlexibleContexts, RankNTypes, TupleSections, TypeFamilies, LambdaCase #-}

module Language.Haskell.Tools.Refactor.Builtin.OrganizeExtensions
  ( module Language.Haskell.Tools.Refactor.Builtin.OrganizeExtensions
  , module Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.ExtMonad
  ) where

import Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.ExtMonad
import Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.TraverseAST
import Language.Haskell.Tools.Refactor.Builtin.ExtensionOrganizer.Utils.SupportedExtensions (unregularExts, isSupported, fullyHandledExtensions)

import Language.Haskell.Tools.Refactor hiding (LambdaCase)
import Language.Haskell.Tools.Refactor.Utils.Extensions (expandExtension)

import GHC (Ghc(..))

import Control.Reference
import Data.Char (isAlpha)
import Data.Function (on)
import Data.List
import qualified Data.Map.Strict as SMap (keys, empty)

-- NOTE: When working on the entire AST, we should build a monad
--       that will avoid unnecessary checks.
--       For example, if it has already found a record wildcard, it won't check again.

--       Pretty easy now: check whether it is already in the ExtMap.

-- | Module-level refactoring, registered under the name @OrganizeExtensions@,
-- that runs 'organizeExtensions' on a single module.
organizeExtensionsRefactoring :: RefactoringChoice
organizeExtensionsRefactoring =
  ModuleRefactoring "OrganizeExtensions" $ localRefactoring organizeExtensions

-- | Project-level refactoring, registered under the name
-- @ProjectOrganizeExtensions@, that runs 'projectOrganizeExtensions'.
projectOrganizeExtensionsRefactoring :: RefactoringChoice
projectOrganizeExtensionsRefactoring =
  ProjectRefactoring refactoringLabel projectOrganizeExtensions
  where refactoringLabel = "ProjectOrganizeExtensions"

-- | Applies 'organizeExtensions' to every module of the project.
--
-- Each @(key, module)@ pair is refactored independently. The outcome is
-- reported as a 'ContentChanged' entry for the same key.
projectOrganizeExtensions :: ProjectRefactoring
projectOrganizeExtensions = traverse $ \(key, modAST) ->
  ContentChanged . (,) key <$> localRefactoringRes id modAST (organizeExtensions modAST)

-- | Convenience entry point that runs 'organizeExtensions' through
-- 'tryRefactor'.
--
-- The argument that 'tryRefactor' hands to the refactoring is ignored. The
-- whole module is always organized.
tryOut :: String -> String -> IO ()
tryOut = tryRefactor (\_ -> localRefactoring organizeExtensions)

-- | Removes unnecessary entries from the module's @LANGUAGE@ pragmas.
--
-- An entry is dropped only when both of these hold:
--
--   * the extension is in 'fullyHandledExtensions'; and
--   * 'reduceExtensions' did not report it as needed by this module.
--
-- Entries the organizer does not fully handle are always kept.
-- @LANGUAGE@ pragmas that end up with no entries are removed afterwards.
organizeExtensions :: LocalRefactoring
organizeExtensions moduleAST = do
  exts <- liftGhc $ reduceExtensions moduleAST
  -- Normalise a pragma entry exactly like 'toExt' does before parsing it.
  -- The textual comparison against the 'show'-ed extension names then agrees
  -- with what 'reduceExtensions' saw. Previously 'takeWhile isAlpha' was
  -- skipped here, so an entry 'toExt' accepted could never match.
  let normalize = unregularExts . takeWhile isAlpha
      isRedundant e = extName `notElem` foundExts && extName `elem` handledExts
        where extName = normalize (e ^. langExt)
      handledExts = map show fullyHandledExtensions
      foundExts = map show exts

  -- remove unused extensions (only those that are fully handled)
  filePragmas & annList & lpPragmas !~ filterListSt (not . isRedundant)
        -- remove empty {-# LANGUAGE #-} pragmas
    >=> filePragmas !~ filterListSt (\case LanguagePragma (AnnList []) -> False; _ -> True)
    $ moduleAST

-- | Reduces default extension list (keeps unsupported extensions)
-- | Computes the list of extensions the module actually needs.
--
-- Supported extensions are kept only if required; unsupported ones are always
-- kept. The steps are:
--
--   1. Expand the module's @LANGUAGE@ pragmas with 'expandDefaults'.
--   2. Split the result with 'isSupported'.
--   3. Traverse the AST ('traverseModule') with the supported extensions in
--      the reader environment, to collect which of them are required.
--   4. Append the unsupported extensions unconditionally.
--   5. Remove duplicates and drop extensions implied by another extension in
--      the list.
--   6. Sort by their 'show' representation.
reduceExtensions :: UnnamedModule -> Ghc [Extension]
reduceExtensions moduleAST = do
  let expanded = expandDefaults moduleAST
      (xs, ys) = partition isSupported expanded
  xs' <- flip execStateT SMap.empty . flip runReaderT xs . traverseModule $ moduleAST
  return . sortBy (compare `on` show) . mergeInduced . nub $ (calcExts xs' ++ ys)

  where isLVar (LVar _) = True
        isLVar _        = False

        -- Reads the extensions off the keys of the collected map. This happens
        -- only when every key is a plain 'LVar'; otherwise no extension is
        -- reported. NOTE(review): presumably non-'LVar' keys are compound
        -- logical relations that are not resolved here — confirm.
        calcExts :: ExtMap -> [Extension]
        calcExts logRels
          | all isLVar ks = [ ext | LVar ext <- ks ]
          | otherwise     = []
          where ks = SMap.keys logRels

        -- Removes from the list every extension that 'expandExtension e'
        -- yields, except 'e' itself.
        rmInduced :: Extension -> [Extension] -> [Extension]
        rmInduced e = flip (\\) induced
          where induced = delete e $ expandExtension e

        -- For each extension in the input, strip the extensions it induces.
        -- Uses a strict left fold to avoid building a chain of thunks.
        mergeInduced :: [Extension] -> [Extension]
        mergeInduced exts = foldl' (flip rmInduced) exts exts


-- | Collects extensions induced by the source code (with location info)
-- | Traverses the module and collects the extensions its source code induces,
-- returned as an 'ExtMap' (with location info).
--
-- The traversal starts from an empty map. Every extension from the module's
-- expanded @LANGUAGE@ pragmas ('expandDefaults') is placed in the reader
-- environment.
collectExtensions :: UnnamedModule -> Ghc ExtMap
collectExtensions moduleAST = execStateT traversal SMap.empty
  where
    enabled   = expandDefaults moduleAST
    traversal = runReaderT (traverseModule moduleAST) enabled

-- | Collects default extension list, and expands each extension
-- | Lists the extensions named in the module's @LANGUAGE@ pragmas.
--
-- Each extension is replaced by the list 'expandExtension' returns for it.
-- Duplicate entries are removed, keeping the first occurrence.
expandDefaults :: UnnamedModule -> [Extension]
expandDefaults moduleAST =
  nub [ implied | ext <- collectDefaultExtensions moduleAST
                , implied <- expandExtension ext ]

-- | Collects extensions enabled by default
-- | Parses every extension name written in the module's @LANGUAGE@ pragmas.
--
-- Parsing is done with 'toExt', which calls 'error' on a name it cannot
-- recognise.
collectDefaultExtensions :: UnnamedModule -> [Extension]
collectDefaultExtensions moduleAST = map toExt (pragmaExtNames moduleAST)
  where
    -- all extension-name strings inside all LANGUAGE pragmas of the file
    pragmaExtNames :: UnnamedModule -> [String]
    pragmaExtNames m = m ^? (filePragmas & annList & lpPragmas & annList & langExt)

-- | Converts an extension name from a pragma into an 'Extension'.
--
-- Processing steps:
--
--   * Only the leading run of alphabetic characters of the input is used.
--   * That prefix is normalised with 'unregularExts'.
--   * The result is parsed with the 'Read' instance of 'Extension'; the
--     first successful parse wins.
--
-- Calls 'error' when the name cannot be parsed.
toExt :: String -> Extension
toExt str =
  case [ ext | (ext, _) <- reads (unregularExts alphaPrefix) ] of
    []      -> error $ "Extension '" ++ alphaPrefix ++ "' is not known."
    ext : _ -> ext
  where
    alphaPrefix = takeWhile isAlpha str