{-# LANGUAGE CPP #-}

module HsDev.Inspect.Definitions (
        getSymbols,
        getDecl
        ) where

import Control.Lens
import Control.Monad
import Data.Data (Data)
import Data.Generics.Uniplate.Data
import Data.List
import Data.Maybe
import Data.Function
import Data.Ord
import Data.String
import Data.Text (Text)
import qualified Language.Haskell.Exts as H

import HsDev.Symbols.Types
import HsDev.Symbols.Parsed
import HsDev.Symbols.Resolve (symbolUniqId)

-- | Get top symbols
getSymbols :: [H.Decl Ann] -> [Symbol]
getSymbols decls =
        map mergeSymbols .
        groupBy ((==) `on` symbolUniqId) .
        sortBy (comparing symbolUniqId) $
        concatMap getDecl decls
        where
                mergeSymbols :: [Symbol] -> Symbol
                mergeSymbols [] = error "impossible"
                mergeSymbols [s] = s
                mergeSymbols ss@(s:_) = Symbol
                        (view symbolId s)
                        (msum $ map (view symbolDocs) ss)
                        (msum $ map (view symbolPosition) ss)
                        (foldr1 mergeInfo $ map (view symbolInfo) ss)

                mergeInfo :: SymbolInfo -> SymbolInfo -> SymbolInfo
                mergeInfo (Function lt) (Function rt) = Function $ lt `mplus` rt
                mergeInfo (PatConstructor las lt) (PatConstructor ras rt) = PatConstructor (if null las then ras else las) (lt `mplus` rt)
                mergeInfo (Selector lt lp lc) (Selector rt rp rc)
                        | lt == rt && lp == rp = Selector lt lp (nub $ lc ++ rc)
                        | otherwise = Selector lt lp lc
                mergeInfo l _ = l

-- | Get symbols from declarations
getDecl :: H.Decl Ann -> [Symbol]
getDecl decl' = case decl' of
        H.TypeDecl _ h _ -> [mkSymbol (tyName h) (Type (tyArgs h) [])]
        H.TypeFamDecl _ h _ _ -> [mkSymbol (tyName h) (TypeFam (tyArgs h) [] Nothing)]
        H.ClosedTypeFamDecl _ h _ _ _ -> [mkSymbol (tyName h) (TypeFam (tyArgs h) [] Nothing)]
        H.DataDecl _ dt mctx h dcons _ -> mkSymbol nm ((getCtor dt) (tyArgs h) (getCtx mctx)) : concatMap (getConDecl nm) dcons where
                nm = tyName h
        H.GDataDecl _ dt mctx h _ gcons _ -> mkSymbol nm ((getCtor dt) (tyArgs h) (getCtx mctx)) : concatMap (getGConDecl nm) gcons where
                nm = tyName h
        H.DataFamDecl _ mctx h _ -> [mkSymbol (tyName h) (DataFam (tyArgs h) (getCtx mctx) Nothing)]
        H.ClassDecl _ mctx h _ clsDecls -> mkSymbol nm (Class (tyArgs h) (getCtx mctx)) : concatMap (getClassDecl nm) (fromMaybe [] clsDecls) where
                nm = tyName h
        H.TypeSig _ ns tsig -> [mkSymbol n (Function (Just $ oneLinePrint tsig)) | n <- ns]
        H.PatSynSig _ ns mas _ _ t -> [mkSymbol n (PatConstructor (maybe [] (map prp) mas) (Just $ oneLinePrint t)) | n <- ns'] where
#if MIN_VERSION_haskell_src_exts(1,20,0)
                ns' = ns
#else
                ns' = [ns]
#endif
        H.FunBind _ ms -> [mkSymbol (matchName m) (Function Nothing) | m <- ms] where
                matchName (H.Match _ n _ _ _) = n
                matchName (H.InfixMatch _ _ n _ _ _) = n
        H.PatBind _ p _ _ -> [mkSymbol n (Function Nothing) | n <- patNames p] where
                patNames :: H.Pat Ann -> [H.Name Ann]
                patNames = childrenBi
        H.PatSyn _ p _ _ -> case p of
                H.PInfixApp _ _ qn _ -> [mkSymbol (qToName qn) (PatConstructor [] Nothing)]
                H.PApp _ qn _ -> [mkSymbol (qToName qn) (PatConstructor [] Nothing)]
                H.PRec _ qn fs -> mkSymbol (qToName qn) (PatConstructor [] Nothing) :
                        [mkSymbol (qToName n) (PatSelector Nothing Nothing (prp $ qToName qn)) | n <- (universeBi fs :: [H.QName Ann])]
                _ -> []
                where
                        qToName (H.Qual _ _ n) = n
                        qToName (H.UnQual _ n) = n
                        qToName _ = error "invalid qname"
        _ -> []
        where
                tyName :: H.DeclHead Ann -> H.Name Ann
                tyName = head . universeBi
                tyArgs :: Data (ast Ann) => ast Ann -> [Text]
                tyArgs n = map prp (universeBi n :: [H.TyVarBind Ann])
                getCtx :: Maybe (H.Context Ann) -> [Text]
                getCtx mctx = map prp (universeBi mctx :: [H.Asst Ann])
                getCtor (H.DataType _) = Data
                getCtor (H.NewType _) = NewType

getConDecl :: H.Name Ann -> H.QualConDecl Ann -> [Symbol]
getConDecl ptype (H.QualConDecl _ _ _ cdecl) = case cdecl of
        H.ConDecl _ n ts -> [mkSymbol n (Constructor (map prp ts) (prp ptype))]
        H.InfixConDecl _ lt n rt -> [mkSymbol n (Constructor (map prp [lt, rt]) (prp ptype))]
        H.RecDecl _ n fs -> mkSymbol n (Constructor [prp t | H.FieldDecl _ _ t <- fs] (prp ptype)) :
                [mkSymbol fn (Selector (Just $ prp ft) (prp ptype) [prp n]) | H.FieldDecl _ fns ft <- fs, fn <- fns]

getGConDecl :: H.Name Ann -> H.GadtDecl Ann -> [Symbol]
getGConDecl _ (H.GadtDecl _ n Nothing t) = [mkSymbol n (Constructor (map prp as) (prp res))] where
        (as, res) = tyFunSplit t
        tyFunSplit = go [] where
                go as' (H.TyFun _ arg' res') = go (arg' : as') res'
                go as' t' = (reverse as', t')
getGConDecl ptype (H.GadtDecl _ n (Just fs) t) = mkSymbol n (Constructor [prp ft | H.FieldDecl _ _ ft <- fs] (prp t)) :
        [mkSymbol fn (Selector (Just $ prp ft) (prp ptype) [prp n]) | H.FieldDecl _ fns ft <- fs, fn <- fns]

getClassDecl :: H.Name Ann -> H.ClassDecl Ann -> [Symbol]
getClassDecl pclass (H.ClsDecl _ (H.TypeSig _ ns tsig)) = [mkSymbol n (Method (Just $ oneLinePrint tsig) (prp pclass)) | n <- ns]
getClassDecl _ _ = []

prp :: H.Pretty a => a -> Text
prp = fromString . H.prettyPrint


mkSymbol :: H.Name Ann -> SymbolInfo -> Symbol
mkSymbol nm = Symbol (SymbolId (fromName_ $ void nm) (ModuleId (fromString "") noLocation)) Nothing (nm ^? binders . defPos)


-- | Print something in one line
oneLinePrint :: (H.Pretty a, IsString s) => a -> s
oneLinePrint = fromString . H.prettyPrintStyleMode (H.style { H.mode = H.OneLineMode }) H.defaultMode