{-# LANGUAGE MultiParamTypeClasses
           , TypeSynonymInstances
           , FlexibleInstances
           , ScopedTypeVariables
           , ViewPatterns
           , LambdaCase
           #-}
-- | Utility functions defined on the GHC AST representation.
module Language.Haskell.Tools.AST.FromGHC.GHCUtils where

import Data.Generics.Uniplate.Data ()
import Data.List

import Bag (Bag, bagToList, unionManyBags)
import ConLike (ConLike(..))
import GHC
import Id (Id, mkVanillaGlobal)
import OccName (OccName)
import Outputable (Outputable(..), OutputableBndr(..), showSDocUnsafe)
import PatSyn (patSynSig)
import RdrName (RdrName, rdrNameOcc, nameRdrName)
import SrcLoc
import Type (TyThing(..), mkFunTys)

-- | Operations that work uniformly over the name types found in the GHC AST:
-- 'RdrName' in parsed source and 'GHC.Name' in renamed source.
class OutputableBndr name => GHCName name where
  -- | Conversion to the parsed-source name form.
  rdrName :: name -> RdrName
  -- | Conversion from an 'Id' (through the 'GHC.Name' it carries).
  nameFromId :: Id -> name
  -- | Produces this name form from a 'GHC.Name'. The supplied action is used by
  -- instances that need to look the name up; other instances may ignore it.
  getFromNameUsing :: Applicative f => (Name -> Ghc (f Id)) -> Name -> Ghc (f name)
  -- | Splits value bindings into their signatures and their bindings.
  getBindsAndSigs :: HsValBinds name -> ([LSig name], LHsBinds name)
  -- | Picks the name of an occurrence, given the name as it was written and the
  -- renamer-provided 'PostRn' part; each instance returns the one it uses.
  unpackPostRn :: RdrName -> PostRn name name -> name
  -- | Eliminator for a 'PostRn' field: yields the default when this name form
  -- carries no name there, otherwise applies the function to the stored name.
  gunpackPostRn :: a -> (name -> a) -> PostRn name name -> a

-- | Parsed-source names: the name as written is used everywhere.
instance GHCName RdrName where
  rdrName = id
  -- no lookup is needed, the lookup action is ignored
  getFromNameUsing _ n = return $ pure (nameRdrName n)
  -- only 'ValBindsIn' is expected in parsed source
  getBindsAndSigs (ValBindsIn binds sigs) = (sigs, binds)
  -- the message names the function, like the 'GHC.Name' instance does
  getBindsAndSigs _ = error "getBindsAndSigs: ValBindsOut in parsed source"
  nameFromId = nameRdrName . getName
  -- keep the written name, ignoring the 'PostRn' part
  unpackPostRn rdr _ = rdr

  -- a parsed name never provides a resolved name here: always the default
  gunpackPostRn a _ _ = a

-- | The 'OccName' of any supported name, taken from its 'RdrName' form.
occName :: GHCName n => n -> OccName
occName n = rdrNameOcc (rdrName n)

-- | Renamed-source names.
instance GHCName GHC.Name where
  rdrName = nameRdrName
  nameFromId = getName
  -- runs the supplied lookup, then maps every resulting 'Id' back to its name
  getFromNameUsing lookupFn nm = fmap nameFromId <$> lookupFn nm
  getBindsAndSigs valBinds = case valBinds of
    -- bindings arrive as a list of groups; only their binding bags are kept
    ValBindsOut bindGroups sigs -> (sigs, unionManyBags (map snd bindGroups))
    _ -> error "getBindsAndSigs: ValBindsIn in renamed source"
  -- the 'PostRn' part holds the resolved name; the written one is ignored
  unpackPostRn _ resolved = resolved
  gunpackPostRn _ onName resolved = onName resolved

-- | The name of a located field occurrence, keeping the occurrence's location.
getFieldOccName :: GHCName n => Located (FieldOcc n) -> Located n
getFieldOccName located = case located of
  L loc (FieldOcc (L _ written) resolved) -> L loc (unpackPostRn written resolved)

-- | The name of an unlocated field occurrence.
getFieldOccName' :: GHCName n => FieldOcc n -> n
getFieldOccName' fieldOcc = case fieldOcc of
  FieldOcc (L _ written) resolved -> unpackPostRn written resolved



-- | Looks up a top-level name in the GHC session and produces an 'Id' for it.
--
-- * Identifiers are returned unchanged.
-- * Data constructors get a vanilla global 'Id' typed with their user type.
-- * Pattern synonyms get one typed as a plain function from their argument
--   types to their result type. The type variables and contexts returned by
--   'patSynSig' are dropped.
-- * Type constructors get one typed with their kind.
-- * Anything else, or a failed lookup, gives 'Nothing'.
getTopLevelId :: GHC.Name -> Ghc (Maybe GHC.Id)
getTopLevelId name = (>>= thingToId) <$> lookupName name
  where
    thingToId :: TyThing -> Maybe GHC.Id
    thingToId thing = case thing of
      AnId ident -> Just ident
      AConLike (RealDataCon dc) -> Just (vanilla (dataConUserType dc))
      AConLike (PatSynCon ps) -> Just (vanilla (patSynFunType ps))
      ATyCon tc -> Just (vanilla (tyConKind tc))
      _ -> Nothing

    vanilla = mkVanillaGlobal name

    patSynFunType ps = case patSynSig ps of
      (_, _, _, _, argTys, resTy) -> mkFunTys argTys resTy

-- | AST elements from which the resolved names ('GHC.Name') they mention or
-- define can be collected.
class HsHasName a where
  -- | All names found in the element, in source order.
  hsGetNames :: a -> [GHC.Name]

-- | A parsed name is not resolved, so it contributes no 'GHC.Name'.
instance HsHasName RdrName where
  hsGetNames = const []

-- | A resolved name stands for itself.
instance HsHasName Name where
  hsGetNames = (: [])

-- | An 'Id' contributes the 'GHC.Name' it carries.
instance HsHasName Id where
  hsGetNames = pure . getName

-- | Names of every element, concatenated in order.
instance HsHasName e => HsHasName [e] where
  hsGetNames = concatMap hsGetNames

-- | The location is irrelevant; only the wrapped element is inspected.
instance HsHasName e => HsHasName (Located e) where
  hsGetNames = hsGetNames . unLoc

-- | Only value bindings contribute names; other local bindings give none.
instance HsHasName n => HsHasName (HsLocalBinds n) where
  hsGetNames localBinds = case localBinds of
    HsValBinds valBinds -> hsGetNames valBinds
    _ -> []

-- | Type/class, value, foreign and instance declarations contribute names;
-- every other kind of declaration contributes none.
instance (GHCName n, HsHasName n) => HsHasName (HsDecl n) where
  hsGetNames decl = case decl of
    TyClD tyClDecl -> hsGetNames tyClDecl
    ValD bind -> hsGetNames bind
    ForD foreignDecl -> hsGetNames foreignDecl
    InstD instDecl -> hsGetNames instDecl
    _ -> []

-- | Instance declarations contribute the names of the data family instances
-- they contain; other instance forms contribute none.
instance (GHCName n, HsHasName n) => HsHasName (InstDecl n) where
  hsGetNames instDecl = case instDecl of
    ClsInstD clsInst -> hsGetNames (cid_datafam_insts clsInst)
    DataFamInstD dataInst -> hsGetNames dataInst
    _ -> []

-- | Names defined by the data definition of the instance.
instance (GHCName n, HsHasName n) => HsHasName (DataFamInstDecl n) where
  hsGetNames = hsGetNames . dfid_defn

-- | Names of the type/class declarations in the group (role annotations ignored).
instance (GHCName n, HsHasName n) => HsHasName (TyClGroup n) where
  hsGetNames = hsGetNames . group_tyclds

-- | The declared type name, plus constructor/field names for data types and
-- whatever the signatures yield for classes.
instance (GHCName n, HsHasName n) => HsHasName (TyClDecl n) where
  hsGetNames decl = case decl of
    FamDecl (FamilyDecl {fdLName = famName}) -> hsGetNames famName
    SynDecl {tcdLName = synName} -> hsGetNames synName
    DataDecl {tcdLName = dataName, tcdDataDefn = defn} ->
      hsGetNames dataName ++ hsGetNames defn
    ClassDecl {tcdLName = className, tcdSigs = classSigs} ->
      hsGetNames className ++ hsGetNames classSigs

-- | Names introduced by the constructors of a data definition.
instance (GHCName n, HsHasName n) => HsHasName (HsDataDefn n) where
  hsGetNames = hsGetNames . dd_cons

-- | Constructor names, plus record field names when record syntax is used.
instance (GHCName n, HsHasName n) => HsHasName (ConDecl n) where
  -- GADT-style constructor. With record syntax (@C :: { f :: Int } -> T@) the
  -- record is the first argument of the constructor's type, not the whole
  -- type, so the forall/context wrappers and the arrow are looked through.
  hsGetNames (ConDeclGADT {con_names = names, con_type = HsIB _ conType})
    = hsGetNames names ++ gadtRecordFields conType
    where
      gadtRecordFields :: LHsType n -> [Name]
      gadtRecordFields (L _ (HsForAllTy {hst_body = body})) = gadtRecordFields body
      gadtRecordFields (L _ (HsQualTy {hst_body = body})) = gadtRecordFields body
      gadtRecordFields (L _ (HsFunTy arg _)) = recordFields arg
      gadtRecordFields other = recordFields other

      -- Field names of a record type. Before renaming, the record may still
      -- sit inside a single-element 'HsAppsTy'.
      recordFields :: LHsType n -> [Name]
      recordFields (L _ (HsRecTy flds)) = hsGetNames flds
      recordFields (L _ (HsAppsTy [L _ (HsAppPrefix inner)])) = recordFields inner
      recordFields _ = []
  -- Haskell98-style constructor: its name plus any record field names.
  hsGetNames (ConDeclH98 {con_name = name, con_details = details})
    = hsGetNames name ++ hsGetNames details

-- | Only record constructors introduce names (their fields).
instance (GHCName n, HsHasName n) => HsHasName (HsConDeclDetails n) where
  hsGetNames details = case details of
    RecCon fields -> hsGetNames fields
    _ -> []

-- | Names of all fields declared together in one field declaration.
instance (GHCName n, HsHasName n) => HsHasName (ConDeclField n) where
  hsGetNames = hsGetNames . cd_fld_names

-- | The resolved selector name when the name form provides one, nothing otherwise.
instance (GHCName n, HsHasName n) => HsHasName (FieldOcc n) where
  hsGetNames fieldOcc = case fieldOcc of
    FieldOcc _ resolved -> gunpackPostRn [] (hsGetNames :: n -> [Name]) resolved

-- | Names given a type by a signature.
--
-- Method signatures inside class declarations are represented as 'ClassOpSig'
-- rather than 'TypeSig'. Ordinary ones (flag 'False') are therefore included,
-- so that a class declaration reports its methods. Default-method signatures
-- (flag 'True') are skipped because they repeat a method name that the
-- ordinary signature already provides. Other signatures yield no names.
instance (HsHasName n) => HsHasName (Sig n) where
  hsGetNames (TypeSig n _) = hsGetNames n
  hsGetNames (ClassOpSig False n _) = hsGetNames n
  hsGetNames (PatSynSig n _) = hsGetNames n
  hsGetNames _ = []

-- | A foreign import introduces its name; other foreign declarations do not.
instance HsHasName n => HsHasName (ForeignDecl n) where
  hsGetNames foreignDecl = case foreignDecl of
    ForeignImport importedName _ _ _ -> hsGetNames importedName
    _ -> []

-- | Names bound by value bindings, both before and after renaming.
instance HsHasName n => HsHasName (HsValBinds n) where
  hsGetNames valBinds = case valBinds of
    ValBindsIn bindBag _ -> hsGetNames bindBag
    ValBindsOut bindGroups _ -> concatMap (hsGetNames . snd) bindGroups

-- | Names of every element in the bag, in 'bagToList' order.
instance HsHasName n => HsHasName (Bag n) where
  hsGetNames bag = concatMap hsGetNames (bagToList bag)

-- | Names bound by a source-level binding. Any other binding form is rejected
-- with an error.
instance HsHasName n => HsHasName (HsBind n) where
  hsGetNames bind = case bind of
    FunBind {fun_id = funName} -> hsGetNames funName
    PatBind {pat_lhs = lhsPat} -> hsGetNames lhsPat
    VarBind {var_id = boundVar} -> hsGetNames boundVar
    PatSynBind (PSB {psb_id = synName}) -> hsGetNames synName
    _ -> error "hsGetNames: called on compiler-generated binding"

-- | Names bound by one branch of a parallel comprehension.
instance HsHasName n => HsHasName (ParStmtBlock l n) where
  hsGetNames block = case block of
    ParStmtBlock _ boundVars _ -> hsGetNames boundVars

--instance HsHasName n => HsHasName (LHsTyVarBndrs n) where
--  hsGetNames (HsQTvs kvs tvs) = hsGetNames kvs ++ hsGetNames tvs

-- | The bound type variable, with or without a kind annotation.
instance HsHasName n => HsHasName (HsTyVarBndr n) where
  hsGetNames bndr = case bndr of
    UserTyVar var -> hsGetNames var
    KindedTyVar var _ -> hsGetNames var

-- | Names bound by a statement: @let@, @pat <- e@, and the recursive ids of a
-- @rec@ block. Other statements give none.
instance HsHasName n => HsHasName (Stmt n b) where
  hsGetNames stmt = case stmt of
    LetStmt localBinds -> hsGetNames localBinds
    BindStmt boundPat _ _ _ _ -> hsGetNames boundPat
    RecStmt {recS_rec_ids = recIds} -> hsGetNames recIds
    _ -> []

-- | Variables bound by a pattern, left to right.
instance HsHasName n => HsHasName (Pat n) where
  hsGetNames pat = case pat of
    -- patterns that bind a name directly
    VarPat var -> hsGetNames var
    AsPat asName inner -> hsGetNames asName ++ hsGetNames inner
    NPlusKPat var _ _ _ _ _ -> hsGetNames var
    -- wrappers around a single sub-pattern
    LazyPat inner -> hsGetNames inner
    ParPat inner -> hsGetNames inner
    BangPat inner -> hsGetNames inner
    ViewPat _ inner _ -> hsGetNames inner
    SigPatIn inner _ -> hsGetNames inner
    SigPatOut inner _ -> hsGetNames inner
    -- patterns with several sub-patterns
    ListPat pats _ _ -> hsGetNames pats
    TuplePat pats _ _ -> hsGetNames pats
    PArrPat pats _ -> hsGetNames pats
    ConPatIn _ details -> hsGetNames (hsConPatArgs details)
    ConPatOut {pat_args = details} -> hsGetNames (hsConPatArgs details)
    -- literals, wildcards, etc. bind nothing
    _ -> []

-- | Names defined at the top of a declaration group: value bindings,
-- type/class declarations, instances and foreign declarations.
instance (GHCName n, HsHasName n) => HsHasName (HsGroup n) where
  hsGetNames grp = concat
    [ hsGetNames (hs_valds grp)
    , hsGetNames (hs_tyclds grp)
    , hsGetNames (hs_instds grp)
    , hsGetNames (hs_fords grp)
    ]

-- | Renders a parsed name the way GHC pretty-prints it.
rdrNameStr :: RdrName -> String
rdrNameStr = showSDocUnsafe . ppr


-- | Name types that can be produced from a renamed 'GHC.Name'.
class FromGHCName n where
  fromGHCName :: GHC.Name -> n

-- | Converts through the 'GHCName' conversion to 'RdrName'.
instance FromGHCName RdrName where
  fromGHCName name = rdrName name

-- | A 'GHC.Name' is already in the requested form.
instance FromGHCName GHC.Name where
  fromGHCName name = name

-- | Tries to simplify a type that still contains 'HsAppsTy' (as it does before renaming).
-- This is a heuristic and does not always produce the correct form. Operator fixities
-- are not known, so every operator is treated as having the same precedence, and the
-- operand to the right of an operator is parsed recursively. The result is therefore
-- right-associative: @a + b + c@ becomes @a + (b + c)@. Types other than 'HsAppsTy' are
-- returned unchanged.
cleanHsType :: forall n . (OutputableBndr n) => HsType n -> HsType n
cleanHsType (HsAppsTy apps) = unLoc $ guessType apps
  where -- Parses a complete application. A leading infix element can only be used
        -- as a prefix (e.g. the kind @*@), so it becomes a type variable.
        guessType :: OutputableBndr n => [LHsAppType n] -> LHsType n
        guessType (L l (HsAppInfix n) : rest) -- must be a prefix actually
          = guessType' (L l (HsTyVar n)) rest
        guessType (L _ (HsAppPrefix t) : rest) = guessType' t rest
        guessType [] = error $ "guessType: empty: " ++ showSDocUnsafe (ppr apps)

        -- Extends an already parsed head with the remaining elements.
        guessType' :: LHsType n -> [LHsAppType n] -> LHsType n
        guessType' fun (L _ (HsAppPrefix t) : rest) = guessType' (hsAppTy fun t) rest
        guessType' fun (L l (HsAppInfix n) : rest)
          -- for some reason * is considered infix, so it is applied as an argument
          -- TODO: find a better check
          | showSDocUnsafe (ppr n) == "*" = guessType' (hsAppTy fun (L l (HsTyVar n))) rest
        -- any other operator: everything to its right becomes its right operand
        guessType' left (L _ (HsAppInfix n) : right) = hsOpTy left n (guessType right)
        guessType' t [] = t

        hsAppTy :: LHsType n -> LHsType n -> LHsType n
        hsAppTy t1 t2 = L (getLoc t1 `combineSrcSpans` getLoc t2) $ HsAppTy t1 t2

        hsOpTy :: LHsType n -> Located n -> LHsType n -> LHsType n
        hsOpTy t1 n t2 = L (getLoc t1 `combineSrcSpans` getLoc t2) $ HsOpTy t1 n t2
cleanHsType t = t

-- | Merges fixity signatures that have exactly the same source location into one
-- signature listing all of their names. The location and fixity of the first one
-- are kept. Signatures at different locations stay separate, ordered by the first
-- occurrence of each location.
mergeFixityDefs :: [Located (FixitySig n)] -> [Located (FixitySig n)]
mergeFixityDefs (s@(L l _) : rest)
  = let (same, different) = partition ((== l) . getLoc) rest
     in foldl' mergeWith s (map unLoc same) : mergeFixityDefs different
  where -- appends another signature's names, keeping the accumulator's location and fixity
        mergeWith (L loc (FixitySig names fixity)) (FixitySig otherNames _)
          = L loc (FixitySig (names ++ otherNames) fixity)
mergeFixityDefs [] = []