-- | Provides the plugins monadic envionment,
--   access to the environment and message printing capabilities.
module Control.Supermonad.Plugin.Environment
  ( -- * Polymonad Plugin Monad
    SupermonadPluginM
  , runSupermonadPlugin
  , runTcPlugin
    -- * Polymonad Plugin Environment Access
  , getBindClass, getReturnClass
  , getSupermonadModule
  , getGivenConstraints, getWantedConstraints
  , getInstEnvs
  , getBindInstances
  , getSupermonadFor
  , addTypeEqualities, addTypeEquality
  , addTyVarEqualities, addTyVarEquality
  , getTypeEqualities, getTyVarEqualities
  , whenNoResults
  , addWarning, displayWarnings
  , throwPluginError, throwPluginErrorSDoc, catchPluginError
    -- * Debug and Error Output
  , assert, assertM
  , printErr, printMsg, printObj, printWarn
  , printConstraints
  ) where

import Data.List ( groupBy )
import Data.Map ( Map )
import qualified Data.Map as M

import Control.Monad ( unless, forM_ )
import Control.Monad.Reader ( ReaderT, runReaderT, asks )
import Control.Monad.State  ( StateT , runStateT , gets, modify )
import Control.Monad.Except ( ExceptT, runExceptT, throwError, catchError )
import Control.Monad.Trans.Class ( lift )

import Class ( Class )
import Module ( Module )
import InstEnv ( InstEnvs, ClsInst )
import Type ( TyVar, Type )
import TyCon ( TyCon )
import TcRnTypes ( Ct )
import TcPluginM ( TcPluginM, tcPluginIO )
import qualified TcPluginM
import Outputable ( Outputable )
import SrcLoc ( srcSpanFileName_maybe )
import FastString ( unpackFS )
import qualified Outputable as O

import qualified Control.Supermonad.Plugin.Log as L
import Control.Supermonad.Plugin.Constraint
  ( GivenCt, WantedCt
  , constraintSourceLocation )
import Control.Supermonad.Plugin.Detect
  ( findSupermonadModule
  , findBindClass, findReturnClass
  , findInstancesInScope
  , bindClassName, returnClassName
  , findSupermonads
  , checkSupermonadInstances )

-- -----------------------------------------------------------------------------
-- Plugin Monad
-- -----------------------------------------------------------------------------

-- | The error type used as result if the plugin fails.
type SupermonadError = O.SDoc

-- | The plugin monad.
type SupermonadPluginM = ReaderT SupermonadPluginEnv 
                       ( StateT  SupermonadPluginState 
                       ( ExceptT SupermonadError TcPluginM
                       ) )

-- | The read-only environent of the plugin.
data SupermonadPluginEnv = SupermonadPluginEnv
  { smEnvSupermonadModule :: Module
  -- ^ The 'Control.Supermonad' module.
  , smEnvBindClass :: Class
  -- ^ The 'Bind' class.
  , smEnvReturnClass :: Class
  -- ^ The 'Return' class.
  , smEnvBindInstances :: [ClsInst]
  -- ^ Collection of all 'Bind' instances.
  , smEnvGivenConstraints  :: [GivenCt]
  -- ^ The given and derived constraints (all of them).
  , smEnvWantedConstraints :: [WantedCt]
  -- ^ The wanted constraints (all of them).
  , smEnvSupermonads :: Map TyCon (ClsInst, ClsInst)
  -- ^ The supermonads currently in scope. Associates the type constructor 
  --   of each supermonad with its 'Control.Supermonad.Bind' and 
  --   'Control.Supermonad.Return' instance.
  }

-- | The modifiable state of the plugin.
data SupermonadPluginState = SupermonadPluginState 
  { smStateTyVarEqualities :: [(Ct, TyVar, Type)]
  -- ^ Equalities between type variables and types that have been derived by the plugin.
  , smStateTypeEqualities :: [(Ct, Type, Type)]
  -- ^ Eqaulities between types that have been derived by the plugin.
  , smStateWarningQueue :: [(String, O.SDoc)]
  -- ^ A queue of warnings that are only displayed if no progress could be made.
  }

-- | Runs the given supermonad plugin solver within the type checker plugin 
--   monad.
runSupermonadPlugin 
  :: [GivenCt] -- ^ /Given/ and /derived/ constraints. 
  -> [WantedCt] -- ^ /Wanted/ constraints.
  -> SupermonadPluginM a -- ^ Plugin code to run.
  -> TcPluginM (Either SupermonadError a) -- ^ Either an error message or an actual plugin result.
runSupermonadPlugin givenCts wantedCts smM = do
  mSupermonadMdl <- findSupermonadModule
  mBindCls <- findBindClass
  mReturnCls <- findReturnClass
  -- Calculate the supermonads in scope and check for rogue bind and return instances.
  (smInsts, smErrors) <- case (mBindCls, mReturnCls) of
      (Just bindCls, Just returnCls) -> do
        (smInsts, smErrors) <- findSupermonads bindCls returnCls
        smCheckErrors <- checkSupermonadInstances bindCls returnCls
        return $ (smInsts, fmap snd smErrors ++ fmap snd smCheckErrors) 
      (_, _) -> return mempty
  -- Try to construct the environment or throw errors
  case (mSupermonadMdl, mBindCls, mReturnCls, smErrors) of
    (Right supermonadMdl, Just bindCls, Just returnCls, []) -> do
      let initState = SupermonadPluginState 
            { smStateTyVarEqualities = []
            , smStateTypeEqualities  = []
            , smStateWarningQueue    = [] 
            }
      bindInsts <- findInstancesInScope bindCls
      eResult <- runExceptT $ flip runStateT initState $ runReaderT smM $ SupermonadPluginEnv
        { smEnvSupermonadModule  = supermonadMdl
        , smEnvBindClass         = bindCls
        , smEnvReturnClass       = returnCls
        , smEnvBindInstances     = bindInsts
        , smEnvGivenConstraints  = givenCts
        , smEnvWantedConstraints = wantedCts
        , smEnvSupermonads       = smInsts
        }
      return $ case eResult of
        Left  err -> Left err
        Right (a, _res) -> Right a
    (Left mdlErrMsg, _, _, _) -> do
      let msg = "Could not find supermonad module:"
      L.printErr msg
      L.printErr $ L.sDocToStr mdlErrMsg
      return $ Left $ stringToSupermonadError msg O.$$ mdlErrMsg
    (_, Nothing, _, _) -> do
      let msg = "Could not find " ++ bindClassName ++ " class!"
      L.printErr msg
      return $ Left $ stringToSupermonadError msg
    (_, _, Nothing, _) -> do
      let msg = "Could not find " ++ returnClassName ++ " class!"
      L.printErr msg
      return $ Left $ stringToSupermonadError msg
    (_, _, _, _) -> do
      let msg = "Problems when finding supermonad instances:"
      let sdocErr = O.vcat smErrors
      L.printErr msg
      L.printErr $ L.sDocToStr sdocErr
      return $ Left $ stringToSupermonadError msg O.$$ sdocErr


-- | Execute the given 'TcPluginM' computation within the plugin monad.
runTcPlugin :: TcPluginM a -> SupermonadPluginM a
runTcPlugin = lift . lift . lift

-- -----------------------------------------------------------------------------
-- Plugin Environment Access
-- -----------------------------------------------------------------------------

-- | Returns the 'Control.Supermonad.Bind' class.
getBindClass :: SupermonadPluginM Class
getBindClass = asks smEnvBindClass

-- | Returns the 'Control.Supermonad.Return' class.
getReturnClass :: SupermonadPluginM Class
getReturnClass = asks smEnvReturnClass

-- | The 'Control.Supermonad' module.
getSupermonadModule :: SupermonadPluginM Module
getSupermonadModule = asks smEnvSupermonadModule

-- | Returns all of the /given/ and /derived/ constraints of this plugin call.
getGivenConstraints :: SupermonadPluginM [GivenCt]
getGivenConstraints = asks smEnvGivenConstraints

-- | Returns all of the wanted constraints of this plugin call.
getWantedConstraints :: SupermonadPluginM [WantedCt]
getWantedConstraints = asks smEnvWantedConstraints

-- | Returns all bind instances including those given by
--   'getBindApplyInstance' and 'getBindFunctorInstance'.
getBindInstances :: SupermonadPluginM [ClsInst]
getBindInstances = asks smEnvBindInstances

-- | Shortcut to access the instance environments.
getInstEnvs :: SupermonadPluginM InstEnvs
getInstEnvs = runTcPlugin TcPluginM.getInstEnvs

-- | Retrieves the supermonad bind and return instance (in that order) of the given type constructor,
--   if the type constructor represents a supermonad in scope.
getSupermonadFor :: TyCon -> SupermonadPluginM (Maybe (ClsInst, ClsInst))
getSupermonadFor tc = (return . M.lookup tc) =<< asks smEnvSupermonads

-- | Add another type variable equality to the derived equalities.
addTyVarEquality :: Ct -> TyVar -> Type -> SupermonadPluginM ()
addTyVarEquality ct tv ty = modify $ \s -> s { smStateTyVarEqualities = (ct, tv, ty) : smStateTyVarEqualities s }

-- | Add a list of type variable equalities to the derived equalities.
addTyVarEqualities :: [(Ct, TyVar, Type)] -> SupermonadPluginM ()
addTyVarEqualities = mapM_ (\(ct, tv, ty) -> addTyVarEquality ct tv ty)

-- | Add another type equality to the derived equalities.
addTypeEquality :: Ct -> Type -> Type -> SupermonadPluginM ()
addTypeEquality ct ta tb = modify $ \s -> s { smStateTypeEqualities = (ct, ta, tb) : smStateTypeEqualities s }

-- | Add a list of type equality to the derived equalities.
addTypeEqualities :: [(Ct, Type, Type)] -> SupermonadPluginM ()
addTypeEqualities = mapM_ (\(ct, ta, tb) -> addTypeEquality ct ta tb)

-- | Returns all derived type variable equalities that were added to the results thus far.
getTyVarEqualities :: SupermonadPluginM [(Ct, TyVar, Type)]
getTyVarEqualities = gets $ smStateTyVarEqualities

-- | Returns all derived type variable equalities that were added to the results thus far.
getTypeEqualities :: SupermonadPluginM [(Ct, Type, Type)]
getTypeEqualities = gets $ smStateTypeEqualities

-- | Add a warning to the queue of warnings that will be displayed when no progress could be made.
addWarning :: String -> O.SDoc -> SupermonadPluginM ()
addWarning msg details = modify $ \s -> s { smStateWarningQueue = (msg, details) : smStateWarningQueue s }

-- | Execute the given plugin code only if no plugin results were produced so far.
whenNoResults :: SupermonadPluginM () -> SupermonadPluginM ()
whenNoResults m = do
  tyVarEqs <- getTyVarEqualities
  tyEqs <- getTypeEqualities
  if null tyVarEqs && null tyEqs 
    then m 
    else return ()

-- | Displays the queued warning messages if no progress has been made.
displayWarnings :: SupermonadPluginM ()
displayWarnings = whenNoResults $ do
  warns <- gets smStateWarningQueue
  forM_ warns $ \(msg, details) -> do
    printWarn msg
    internalPrint $ L.smObjMsg $ L.sDocToStr details

-- -----------------------------------------------------------------------------
-- Plugin debug and error printing
-- -----------------------------------------------------------------------------

stringToSupermonadError :: String -> SupermonadError
stringToSupermonadError = O.text

-- | Assert the given condition. If the condition does not
--   evaluate to 'True', an error with the given message will
--   be thrown the plugin aborts.
assert :: Bool -> String -> SupermonadPluginM ()
assert cond msg = unless cond $ throwPluginError msg

-- | Assert the given condition. Same as 'assert' but with
--   a monadic condition.
assertM :: SupermonadPluginM Bool -> String -> SupermonadPluginM ()
assertM condM msg = do
  cond <- condM
  assert cond msg

-- | Throw an error with the given message in the plugin.
--   This will abort all further actions.
throwPluginError :: String -> SupermonadPluginM a
throwPluginError = throwError . stringToSupermonadError

-- | Throw an error with the given message in the plugin.
--   This will abort all further actions.
throwPluginErrorSDoc :: O.SDoc -> SupermonadPluginM a
throwPluginErrorSDoc = throwError

-- | Catch an error that was thrown by the plugin.
catchPluginError :: SupermonadPluginM a -> (SupermonadError -> SupermonadPluginM a) -> SupermonadPluginM a
catchPluginError = catchError

-- | Print some generic outputable object from the plugin (Unsafe).
printObj :: Outputable o => o -> SupermonadPluginM ()
printObj = internalPrint . L.smObjMsg . L.pprToStr

-- | Print a message from the plugin.
printMsg :: String -> SupermonadPluginM ()
printMsg = internalPrint . L.smDebugMsg

-- | Print an error message from the plugin.
printErr :: String -> SupermonadPluginM ()
printErr = internalPrint . L.smErrMsg

-- | Print a warning message from the plugin.
printWarn :: String -> SupermonadPluginM ()
printWarn = internalPrint . L.smWarnMsg

-- | Internal function for printing from within the monad.
internalPrint :: String -> SupermonadPluginM ()
internalPrint = runTcPlugin . tcPluginIO . putStr

-- | Print the given string as if it was an object. This allows custom
--   formatting of object.
printFormattedObj :: String -> SupermonadPluginM ()
printFormattedObj = internalPrint . L.smObjMsg

-- | Print the given constraints in the plugins custom format.
printConstraints :: [Ct] -> SupermonadPluginM ()
printConstraints cts =
  forM_ groupedCts $ \(file, ctGroup) -> do
    printFormattedObj $ maybe "From unknown file:" (("From " ++) . (++":") . unpackFS) file
    mapM_ (printFormattedObj . L.formatConstraint) ctGroup
  where
    groupedCts = (\ctGroup -> (getCtFile $ head ctGroup, ctGroup)) <$> groupBy eqFileName cts
    eqFileName ct1 ct2 = getCtFile ct1 == getCtFile ct2
    getCtFile = srcSpanFileName_maybe . constraintSourceLocation