{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-redundant-constraints -Wno-unused-matches #-}

module Calligraphy.Compat.Lib
  ( sourceInfo,
    showContextInfo,
    readHieFileCompat,
    isInstanceNode,
    isTypeSignatureNode,
    isInlineNode,
    isMinimalNode,
    isDerivingNode,
    showAnns,
    mergeSpans,
    isPointSpan,
    getHieFiles,
  )
where

import qualified Calligraphy.Compat.GHC as GHC
import Calligraphy.Util.Lens
import Data.IORef
import qualified Data.Set as Set
import Control.Monad

#if MIN_VERSION_ghc(9,0,0)
import GHC.Iface.Ext.Binary
import GHC.Iface.Ext.Types
import GHC.Types.Name.Cache
import GHC.Types.SrcLoc
import GHC.Utils.Outputable (ppr, showSDocUnsafe)
import qualified Data.Map as Map
#else
import HieBin
import HieTypes
import NameCache
import SrcLoc
#endif

getHieFiles :: [FilePath] -> IO [HieFile]
#if MIN_VERSION_ghc(9,4,0)

-- GHC >= 9.4: the name cache is created in IO, stored in a single IORef,
-- and that same reference is handed to every file read, in list order.
getHieFiles filePaths = do
  nameCache <- GHC.initNameCache 'z' []
  cacheRef <- newIORef nameCache
  mapM (readHieFileWithWarning cacheRef) filePaths

#else

-- GHC < 9.4: the name cache is built from a unique supply created with the
-- character 'z'. One IORef holding that cache is shared by every file read.
-- (The signature lives above the CPP conditional, so it is not repeated here.)
getHieFiles filePaths = do
    uniqSupply <- GHC.mkSplitUniqSupply 'z'
    ref <- newIORef (GHC.initNameCache uniqSupply [])
    forM filePaths (readHieFileWithWarning ref)

#endif

-- | Read a single HIE file through 'readHieFileCompat' and return its
-- 'GHC.HieFile' payload.
--
-- If the HIE version recorded in the file differs from 'GHC.hieVersion',
-- a four-line warning is printed to stdout and the file is returned anyway;
-- a mismatch is never treated as an error here.
readHieFileWithWarning :: IORef GHC.NameCache -> FilePath -> IO GHC.HieFile
readHieFileWithWarning ref path = do
  GHC.HieFileResult fileHieVersion fileGHCVersion hie <- readHieFileCompat ref path
  when (GHC.hieVersion /= fileHieVersion) $ do
    putStrLn $ "WARNING: version mismatch in " <> path
    putStrLn $ "    The hie files in this project were generated with GHC version: " <> show fileGHCVersion
    putStrLn $ "    This version of calligraphy was compiled with GHC version: " <> show GHC.hieVersion
    putStrLn "    Optimistically continuing anyway..."
  pure hie

{-# INLINE sourceInfo #-}
-- | Traversal from a HIE AST node to its 'NodeInfo'. The implementation is
-- chosen by CPP below: on GHC >= 9.0 it targets only the entry stored under
-- the 'SourceInfo' key (and visits nothing when that entry is absent); on
-- older GHC it targets the node's single info field.
sourceInfo :: Traversal' (HieAST a) (NodeInfo a)
-- | Render a 'ContextInfo' as a 'String': through 'ppr' on GHC >= 9.0,
-- through 'show' on older GHC.
showContextInfo :: ContextInfo -> String
-- | Read the HIE file at the given path using the name cache held in the
-- 'IORef'. How the cache is handed to 'readHieFile' differs per GHC version;
-- see the CPP branches below.
readHieFileCompat :: IORef NameCache -> FilePath -> IO HieFileResult

#if MIN_VERSION_ghc(9,4,0)

-- GHC >= 9.4: node info is a map keyed by origin; only the 'SourceInfo'
-- entry is visited, and the node is rebuilt around the updated map.
sourceInfo f (Node (SourcedNodeInfo infoByOrigin) sp children) =
  rebuild <$> Map.alterF visit SourceInfo infoByOrigin
  where
    rebuild infoByOrigin' = Node (SourcedNodeInfo infoByOrigin') sp children
    -- A missing entry stays missing; a present one is passed through 'f'.
    visit Nothing = pure Nothing
    visit (Just info) = Just <$> f info

-- GHC >= 9.4: pretty-print with 'ppr', then render the document to a String.
showContextInfo ctx = showSDocUnsafe (ppr ctx)

-- GHC >= 9.4: 'readHieFile' takes the name cache value itself, so it is
-- read out of the IORef first and passed along with the path.
readHieFileCompat ref path =
  readIORef ref >>= \nameCache -> readHieFile nameCache path

#elif MIN_VERSION_ghc(9,0,0)

-- GHC 9.0 up to (not including) 9.4: node info is a map keyed by origin.
-- Only the 'SourceInfo' entry is visited; if it is absent the map is left
-- unchanged. The signature lives above the CPP conditional.
sourceInfo f (Node (SourcedNodeInfo inf) sp children) =
  (\inf' -> Node (SourcedNodeInfo inf') sp children)
    <$> Map.alterF (maybe (pure Nothing) (fmap Just . f)) SourceInfo inf

-- GHC 9.0 up to (not including) 9.4: pretty-print with 'ppr' and render the
-- resulting document to a String. The signature lives above the CPP
-- conditional.
showContextInfo = showSDocUnsafe . ppr

-- GHC 9.0 up to (not including) 9.4: 'readHieFile' takes a name-cache
-- updater, built here by wrapping 'atomicModifyIORef' on the shared
-- reference in 'NCU'. The signature lives above the CPP conditional.
readHieFileCompat ref = readHieFile (NCU (atomicModifyIORef ref))

#else

-- GHC < 9.0: a node carries its info directly, so the traversal visits that
-- one field and rebuilds the node around the result.
sourceInfo f (Node info sp children) = fmap rebuild (f info)
  where
    rebuild info' = Node info' sp children

-- GHC < 9.0: rendered through the 'Show' instance.
showContextInfo ctx = show ctx

-- GHC < 9.0: 'readHieFile' returns the result together with an updated name
-- cache, which is written back to the IORef before the result is returned.
readHieFileCompat ref fp = do
  (hieResult, updatedCache) <- readIORef ref >>= \cache -> readHieFile cache fp
  hieResult <$ writeIORef ref updatedCache

#endif

-- | True when the node's annotations contain @ClsInstD@/@InstDecl@ or
-- @DerivDecl@/@DerivDecl@.
isInstanceNode :: NodeInfo a -> Bool
-- | True when the node's annotations contain @TypeSig@/@Sig@.
isTypeSignatureNode :: NodeInfo a -> Bool
-- | True when the node's annotations contain @InlineSig@/@Sig@.
isInlineNode :: NodeInfo a -> Bool
-- | True when the node's annotations contain @MinimalSig@/@Sig@.
isMinimalNode :: NodeInfo a -> Bool
-- | True when the node's annotations contain
-- @HsDerivingClause@/@HsDerivingClause@.
isDerivingNode :: NodeInfo a -> Bool
-- | Space-separated rendering (via 'show') of every annotation on the node.
showAnns :: NodeInfo a -> String
#if MIN_VERSION_ghc(9,2,0)

-- GHC >= 9.2: annotations are 'NodeAnnotation' values rather than tuples.
-- A node counts as an instance node if it carries either annotation below.
-- The signature lives above the CPP conditional.
isInstanceNode (NodeInfo anns _ _) = any (flip Set.member anns) [NodeAnnotation "ClsInstD" "InstDecl", NodeAnnotation "DerivDecl" "DerivDecl"]

-- GHC >= 9.2: membership test for the TypeSig/Sig annotation.
-- The signature lives above the CPP conditional.
isTypeSignatureNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "TypeSig" "Sig") anns

-- GHC >= 9.2: membership test for the InlineSig/Sig annotation.
-- The signature lives above the CPP conditional.
isInlineNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "InlineSig" "Sig") anns

-- GHC >= 9.2: membership test for the MinimalSig/Sig annotation.
-- The signature lives above the CPP conditional.
isMinimalNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "MinimalSig" "Sig") anns

-- GHC >= 9.2: membership test for the HsDerivingClause/HsDerivingClause
-- annotation. The signature lives above the CPP conditional.
isDerivingNode (NodeInfo anns _ _) = Set.member (NodeAnnotation "HsDerivingClause" "HsDerivingClause") anns

-- GHC >= 9.2: each 'NodeAnnotation' is unpacked into a pair of its two
-- fields before being shown, and the results are joined with spaces.
-- The signature lives above the CPP conditional.
showAnns (NodeInfo anns _ _) = unwords (show . unNodeAnnotation <$> Set.toList anns)
  where
    -- Expose the two fields of an annotation as a tuple.
    unNodeAnnotation (NodeAnnotation a b) = (a, b)

#else

-- GHC < 9.2: annotations are plain pairs. A node counts as an instance node
-- if it carries either of the two pairs below (checked in this order).
isInstanceNode (NodeInfo annotations _ _) =
  Set.member ("ClsInstD", "InstDecl") annotations
    || Set.member ("DerivDecl", "DerivDecl") annotations

-- GHC < 9.2: membership test for the ("TypeSig", "Sig") pair.
isTypeSignatureNode (NodeInfo annotations _ _) =
  ("TypeSig", "Sig") `Set.member` annotations

-- GHC < 9.2: membership test for the ("InlineSig", "Sig") pair.
isInlineNode (NodeInfo annotations _ _) =
  ("InlineSig", "Sig") `Set.member` annotations

-- GHC < 9.2: membership test for the ("MinimalSig", "Sig") pair.
isMinimalNode (NodeInfo annotations _ _) =
  ("MinimalSig", "Sig") `Set.member` annotations

-- GHC < 9.2: membership test for the
-- ("HsDerivingClause", "HsDerivingClause") pair.
isDerivingNode (NodeInfo annotations _ _) =
  ("HsDerivingClause", "HsDerivingClause") `Set.member` annotations

-- GHC < 9.2: annotations are already pairs, so each is shown directly and
-- the results are joined with spaces.
showAnns (NodeInfo annotations _ _) =
  unwords (map show (Set.toList annotations))

#endif

-- | Combine two spans into one that runs from the earlier of the two start
-- locations to the later of the two end locations.
mergeSpans :: Span -> Span -> Span
mergeSpans sp1 sp2 =
  mkRealSrcSpan
    ( min
        (realSrcSpanStart sp1)
        (realSrcSpanStart sp2)
    )
    ( max
        (realSrcSpanEnd sp1)
        (realSrcSpanEnd sp2)
    )

-- | True when the span's end location is not after its start location
-- (that is, @end <= start@).
isPointSpan :: Span -> Bool
isPointSpan sp = realSrcSpanEnd sp <= realSrcSpanStart sp