{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
module Main (main) where

import qualified Control.Exception                                 as E
import           Control.Monad                                     (when)
import           Data.Aeson                                        (encode)
import qualified Data.ByteString                                   as B
import qualified Data.ByteString.Lazy                              as BL
import           Data.Fix                                          (Fix (..),
                                                                    foldFix)
import           Data.Functor.Identity                             (Identity (..),
                                                                    runIdentity)
import           Data.List                                         (find, nub)
import           Data.Map.Strict                                   (Map)
import qualified Data.Map.Strict                                   as Map
import           Data.Maybe                                        (mapMaybe)
import           Data.Set                                          (Set)
import qualified Data.Set                                          as Set
import           Data.Text                                         (Text)
import qualified Data.Text                                         as T
import qualified Data.Text.Encoding                                as T
import qualified Data.Text.Encoding.Error                          as T
import qualified Data.Text.IO                                      as Text
import qualified Language.Cimple                                   as C
import           Language.Cimple.Analysis.ArrayUsageAnalysis       (runArrayUsageAnalysis)
import           Language.Cimple.Analysis.CallGraphAnalysis        (CallGraphResult (..),
                                                                    runCallGraphAnalysis)
import           Language.Cimple.Analysis.ConstraintGeneration     (ConstraintGenResult (..),
                                                                    runConstraintGeneration)
import qualified Language.Cimple.Analysis.ConstraintGeneration     as CG
import           Language.Cimple.Analysis.Errors                   (Context (..),
                                                                    ErrorInfo (..))
import           Language.Cimple.Analysis.GlobalStructuralAnalysis (GlobalAnalysisResult (..),
                                                                    runGlobalStructuralAnalysis)
import           Language.Cimple.Analysis.NullabilityAnalysis      (runNullabilityAnalysis)
import           Language.Cimple.Analysis.OrderedSolver            (OrderedSolverResult (..),
                                                                    runOrderedSolver)
import           Language.Cimple.Analysis.Pretty                   (ppErrorInfo)
import           Language.Cimple.Analysis.Refined.Inference        (RefinedResult (..),
                                                                    inferRefined)
import           Language.Cimple.Analysis.Refined.Registry         (Registry (..))
import           Language.Cimple.Analysis.Refined.Types            (AnyRigidNodeF (..))
import           Language.Cimple.Analysis.TypeCheck                (typeCheckProgram)
import qualified Language.Cimple.Analysis.TypeCheck.Constraints    as TC
import           Language.Cimple.Analysis.TypeCheck.Solver         (solveConstraints)
import           Language.Cimple.Hic.Analyze                       (nodeName)
import           Language.Cimple.Hic.Ast                           (HicNode (..),
                                                                    Node,
                                                                    NodeF (..))
import           Language.Cimple.Hic.Pretty                        (ppNode)
import           Language.Cimple.Hic.Program                       (Program (..),
                                                                    fromCimple,
                                                                    toCimple)
import qualified Language.Cimple.IO                                as CIO
import qualified Language.Cimple.Program                           as Program
import           Options.Applicative
import           Prettyprinter                                     (Doc, defaultLayoutOptions,
                                                                    layoutPretty,
                                                                    unAnnotate)
import qualified Prettyprinter.Render.Terminal                     as Terminal
import qualified Prettyprinter.Render.Text                         as TextRender
import           System.Exit                                       (ExitCode (..),
                                                                    exitFailure,
                                                                    exitSuccess)
import           System.IO                                         (hIsTerminalDevice,
                                                                    stdout)
import           System.Process                                    (callProcess, rawSystem)
import           Text.Groom                                        (groom)

-- | The analysis phases, listed in the order in which 'main' runs them.
--
-- The derived 'Enum' and 'Bounded' instances are what let 'parsePhase'
-- enumerate every phase with @[minBound .. maxBound]@, so adding a
-- constructor here automatically makes it selectable on the command line
-- (once 'phaseName' has a case for it).
data Phase
    -- | @global-structural@: 'runGlobalStructuralAnalysis'.
    = PhaseGlobalStructural
    -- | @array-usage@: 'runArrayUsageAnalysis'.
    | PhaseArrayUsage
    -- | @call-graph@: 'runCallGraphAnalysis'.
    | PhaseCallGraph
    -- | @nullability@: 'runNullabilityAnalysis'.
    | PhaseNullability
    -- | @constraint-gen@: 'runConstraintGeneration'.
    | PhaseConstraintGen
    -- | @solving@: constraint solving with the solver picked by 'SolverType'.
    | PhaseSolving
    -- | @hic-inference@: conversion to Hic via 'fromCimple' and its reporting.
    | PhaseHicInference
    -- | @refined-solver@: 'inferRefined'.
    | PhaseRefinedSolver
    deriving (Show, Eq, Ord, Enum, Bounded)

-- | The command-line spelling of a phase.
--
-- This is the string accepted by 'parsePhase' and the one embedded in the
-- file names produced by @--dump-json@.
phaseName :: Phase -> String
phaseName PhaseGlobalStructural = "global-structural"
phaseName PhaseArrayUsage       = "array-usage"
phaseName PhaseCallGraph        = "call-graph"
phaseName PhaseNullability      = "nullability"
phaseName PhaseConstraintGen    = "constraint-gen"
phaseName PhaseSolving          = "solving"
phaseName PhaseHicInference     = "hic-inference"
phaseName PhaseRefinedSolver    = "refined-solver"

-- | Inverse of 'phaseName': look a phase up by its command-line spelling.
--
-- Returns @Left@ with an @"Unknown phase: ..."@ message when no phase has
-- the given name.
parsePhase :: String -> Either String Phase
parsePhase s = maybe (Left $ "Unknown phase: " ++ s) Right (lookup s byName)
  where
    -- Every phase paired with its name, in constructor order.
    byName = [(phaseName phase, phase) | phase <- [minBound .. maxBound]]

-- | Which constraint solver the solving phase uses (selected with @--solver@).
--
-- In 'main', 'SolverOrdered' runs 'runOrderedSolver' and 'SolverSimple' runs
-- 'solveConstraints' on the translated constraint list.
data SolverType = SolverOrdered | SolverSimple
    deriving (Show, Eq, Ord, Enum, Bounded)

-- | The command-line spelling of a solver, as accepted by 'parseSolver'.
solverName :: SolverType -> String
solverName SolverOrdered = "ordered"
solverName SolverSimple  = "simple"

-- | Inverse of 'solverName': look a solver up by its command-line spelling.
--
-- Returns @Left@ with an @"Unknown solver: ..."@ message when no solver has
-- the given name.
parseSolver :: String -> Either String SolverType
parseSolver s = maybe (Left $ "Unknown solver: " ++ s) Right (lookup s byName)
  where
    -- Every solver paired with its name, in constructor order.
    byName = [(solverName solver, solver) | solver <- [minBound .. maxBound]]

-- | Parsed command-line options. The field order must match the order of the
-- applicative parser in 'options'.
data Options = Options
    { optInputs     :: [FilePath]
      -- ^ Input C files (positional arguments), handed to the parser in 'main'.
    , optExemplars  :: Bool
      -- ^ @--exemplars@: show exemplars of inferred structures instead of
      -- running the round-trip comparison.
    , optDumpJson   :: Maybe FilePath
      -- ^ @--dump-json BASENAME@: write phase results to
      -- @BASENAME-\<phase\>.json@.
    , optStopAfter  :: Phase
      -- ^ @--stop-after PHASE@: the phase after which to stop.
    , optMaxErrors  :: Int
      -- ^ @--max-errors COUNT@: maximum number of errors to display.
    , optSolver     :: SolverType
      -- ^ @--solver SOLVER@: solver used in the solving phase.
    , optNoOwner    :: Bool
      -- ^ @--no-owner@: makes 'filterProgram' strip owner annotations.
    , optNoNullable :: Bool
      -- ^ @--no-nullable@: makes 'filterProgram' strip nullable/nonnull
      -- annotations.
    , optColor      :: Bool
      -- ^ @--color@: emit coloured diagnostics even when stdout is not a
      -- terminal (see 'renderDoc').
    }

-- | Command-line parser for 'Options'.
--
-- At least one input file is required (@some@). Defaults shown in @--help@
-- are rendered with 'phaseName' / 'solverName' so that the displayed value
-- is exactly what the corresponding option accepts, rather than the Haskell
-- constructor name produced by 'Show'.
options :: Parser Options
options = Options
    <$> some (strArgument (metavar "FILE..." <> help "Input C files"))
    <*> switch (long "exemplars" <> help "Show exemplars of inferred structures")
    <*> optional (strOption (long "dump-json" <> metavar "BASENAME" <> help "Dump analysis results for each phase to BASENAME-<phase>.json"))
    <*> option (eitherReader parsePhase)
        (  long "stop-after"
        <> metavar "PHASE"
        -- The default is the final phase, i.e. a plain invocation runs the
        -- whole pipeline.
        <> value PhaseRefinedSolver
        <> showDefaultWith phaseName
        <> help "Stop after a specific phase (global-structural, array-usage, call-graph, nullability, constraint-gen, solving, hic-inference, refined-solver)"
        )
    <*> option auto
        (  long "max-errors"
        <> metavar "COUNT"
        <> value 5
        <> showDefault
        <> help "Maximum number of errors to display"
        )
    <*> option (eitherReader parseSolver)
        (  long "solver"
        <> metavar "SOLVER"
        <> value SolverOrdered
        <> showDefaultWith solverName
        <> help "Solver to use (ordered, simple)"
        )
    <*> switch (long "no-owner" <> help "Disable owner checks")
    <*> switch (long "no-nullable" <> help "Disable nullable/nonnull checks")
    <*> switch (long "color" <> help "Always output color diagnostics")

-- | Print a document to stdout.
--
-- ANSI styling is kept when stdout is a terminal or when the first argument
-- (force colour) is 'True'; otherwise annotations are stripped and the
-- document is rendered as plain text.
renderDoc :: Bool -> Doc Terminal.AnsiStyle -> IO ()
renderDoc forceColor doc = do
    interactive <- hIsTerminalDevice stdout
    emit (interactive || forceColor)
  where
    emit True  = Terminal.renderIO stdout (layoutPretty defaultLayoutOptions doc)
    emit False = TextRender.renderIO stdout (layoutPretty defaultLayoutOptions (unAnnotate doc))

-- | Remove annotations from the parsed program according to the
-- @--no-owner@ and @--no-nullable@ options.
--
-- Every node is visited bottom-up (children are rewritten before the node
-- itself is inspected). Owner / nullability wrapper nodes are replaced by the
-- node they wrap, and array declarator specs get their nullability reset to
-- 'C.NullabilityUnspecified'.
--
-- Calls 'error' if the rewritten translation units cannot be reassembled
-- into a program.
filterProgram :: Options -> Program.Program Text -> Program.Program Text
filterProgram opts prog =
    either (error . ("filterProgram: " ++)) id rebuilt
  where
    rebuilt = Program.fromList . runIdentity . C.mapAst actions $ Program.toList prog

    actions :: C.IdentityActions Identity Text
    actions = C.identityActions
        { C.doNode = \_ _ next -> strip <$> next }

    -- Unwrap or reset a single, already-traversed node.
    strip node = case unFix node of
        C.TyOwner inner        | noOwner    -> inner
        C.TyNonnull inner      | noNullable -> inner
        C.TyNullable inner     | noNullable -> inner
        C.NonNullParam inner   | noNullable -> inner
        C.NullableParam inner  | noNullable -> inner
        C.DeclSpecArray _ size | noNullable -> Fix (C.DeclSpecArray C.NullabilityUnspecified size)
        _                                   -> node

    noOwner = optNoOwner opts
    noNullable = optNoNullable opts

-- | Program entry point.
--
-- Parses the command line and the input files, applies 'filterProgram', and
-- then runs the analysis phases in order. Phases wrapped in @runPhase@ can
-- dump their result as JSON (@--dump-json@) and exit successfully when they
-- are the phase selected by @--stop-after@; the Hic inference phase checks
-- @--stop-after@ explicitly.
--
-- Exit status: failure on a parse error, a failed round-trip ('checkFile'),
-- refined-solver errors, or any unexpected exception; success otherwise.
-- Type-check errors from the solving phase are reported but do not stop the
-- run.
main :: IO ()
main = E.handle handler $ do
    opts <- execParser (info (options <**> helper) fullDesc)
    result <- CIO.parseProgram (optInputs opts)
    case result of
        Left err -> do
            putStrLn $ "Parse error: " ++ err
            exitFailure
        Right program' -> do
            let program = filterProgram opts program'
            -- Run one phase, dump its result as JSON when requested, and exit
            -- successfully if this is the phase selected by --stop-after.
            let runPhase p act = do
                    res <- act
                    case optDumpJson opts of
                        Just base -> do
                            let path = base ++ "-" ++ phaseName p ++ ".json"
                            putStrLn $ "Dumping " ++ show p ++ " to " ++ path ++ "..."
                            BL.writeFile path (encode res)
                        Nothing -> return ()
                    if p == optStopAfter opts
                        then exitSuccess
                        else return res

            -- Phase 1: Global Structural Analysis
            globalAnalysis <- runPhase PhaseGlobalStructural $ do
                putStrLn "Phase 1: Global Structural Analysis..."
                return $ runGlobalStructuralAnalysis program

            -- Phase 2: Array Usage Analysis
            arrayUsage <- runPhase PhaseArrayUsage $ do
                putStrLn "Phase 2: Array Usage Analysis..."
                return $ runArrayUsageAnalysis (garTypeSystem globalAnalysis) program

            -- Phase 3: Call Graph Analysis
            callGraph <- runPhase PhaseCallGraph $ do
                putStrLn "Phase 3: Call Graph Analysis..."
                return $ runCallGraphAnalysis program

            -- Phase 4: Nullability Analysis
            nullability <- runPhase PhaseNullability $ do
                putStrLn "Phase 4: Nullability Analysis..."
                return $ runNullabilityAnalysis program

            -- Phase 5: Constraint Generation
            constraintGen <- runPhase PhaseConstraintGen $ do
                putStrLn "Phase 5: Constraint Generation..."
                return $ runConstraintGeneration (garTypeSystem globalAnalysis) arrayUsage nullability program

            -- Phase 6: Solving (Type Checking)
            _ <- runPhase PhaseSolving $ do
                putStrLn $ "Phase 6: Constraint Solving (using " ++ solverName (optSolver opts) ++ " solver)..."
                let errors = case optSolver opts of
                        SolverOrdered ->
                            let osr = runOrderedSolver (garTypeSystem globalAnalysis) (cgrSccs callGraph) constraintGen
                            in osrErrors osr
                        SolverSimple ->
                            -- Translate generated constraints into the simple
                            -- solver's representation; Lub constraints have no
                            -- counterpart there and are dropped.
                            let mapConstraint = \case
                                    CG.Equality t1 t2 ml ctx r -> Just $ TC.Equality t1 t2 ml ctx r
                                    CG.Subtype t1 t2 ml ctx r  -> Just $ TC.Subtype t1 t2 ml ctx r
                                    CG.Callable t1 atys _rt ml ctx csId sr -> Just $ TC.Callable t1 atys ml ctx csId sr
                                    CG.MemberAccess t1 f mt ml ctx r -> Just $ TC.MemberAccess t1 f mt ml ctx r
                                    CG.CoordinatedPair tr a e ml ctx _mCsId -> Just $ TC.CoordinatedPair tr a e ml ctx
                                    CG.Lub {} -> Nothing
                            in solveConstraints (garTypeSystem globalAnalysis) (mapMaybe mapConstraint $ concat $ Map.elems $ CG.cgrConstraints constraintGen)

                -- The file an error belongs to is the first InFile entry of
                -- its context; "unknown" when there is none.
                let extractPath ei = case find isFile (errContext ei) of
                        Just (InFile p) -> p
                        _               -> "unknown"
                    isFile = \case InFile _ -> True; _ -> False

                if null errors
                    then putStrLn "Type check successful."
                    else do
                        putStrLn "Type check failed with the following errors:"
                        let paths = nub $ map extractPath errors
                        -- Read each referenced file once so that snippets can
                        -- be looked up by line number.
                        let loadLines p
                                | p == "unknown" = return (p, [])
                                | otherwise = do
                                    content <- T.decodeUtf8With T.lenientDecode <$> B.readFile p
                                    return (p, T.lines content)
                        fileCache <- Map.fromList <$> mapM loadLines paths

                        let reportError ei = do
                                let path = extractPath ei
                                -- Line numbers are 1-based; anything outside
                                -- the file yields no snippet.
                                let mSnippet = case errLoc ei of
                                        Just (C.L (C.AlexPn _ lineNum _) _ _)
                                            | lineNum > 0 -> do
                                                ls <- Map.lookup path fileCache
                                                case drop (lineNum - 1) ls of
                                                    (l:_) -> Just l
                                                    []    -> Nothing
                                        _ -> Nothing
                                renderDoc (optColor opts) (ppErrorInfo path ei mSnippet)
                                putStrLn ""

                        -- Clamp so that a negative --max-errors cannot inflate
                        -- the "elided" count below.
                        let limit = max 0 (optMaxErrors opts)
                        mapM_ reportError (take limit errors)
                        when (length errors > limit) $
                            putStrLn $ "... and " ++ show (length errors - limit) ++ " more errors elided."
                return ()

            -- Phase 7: Hic Inference
            putStrLn "Phase 7: Global Inference..."
            let hicProgram = fromCimple program
            let stats = collectStats hicProgram

            if optExemplars opts
                then showExemplars (optColor opts) hicProgram
                else do
                    putStrLn "Comparing round-tripped ASTs..."
                    let loweredProgram = toCimple hicProgram
                    let originalList = Program.toList program
                    let loweredMap = Map.fromList $ Program.toList loweredProgram

                    mapM_ (checkFile loweredMap) originalList

                    putStrLn "\nDiagnostics:"
                    if null (progDiagnostics hicProgram)
                        then putStrLn "  None."
                        else mapM_ (Text.putStrLn . ("  " <>)) (progDiagnostics hicProgram)

                    putStrLn "\nInferred Constructs Statistics:"
                    if Map.null stats
                        then putStrLn $ "  No high-level constructs inferred (baseline only)."
                        else mapM_ (\(name, count) -> putStrLn $ "  " ++ name ++ ": " ++ show count) (Map.toList stats)

            -- This phase is not routed through runPhase (it has no JSON dump),
            -- so --stop-after hic-inference has to be honoured explicitly.
            when (optStopAfter opts == PhaseHicInference) exitSuccess

            -- Phase 8: Refined Solver
            -- NOTE(review): on refined errors the phase exits before runPhase
            -- gets a chance to dump the result as JSON.
            _ <- runPhase PhaseRefinedSolver $ do
                putStrLn "Phase 8: Refined Type Analysis..."
                let refinedResult = inferRefined (garTypeSystem globalAnalysis) hicProgram
                let hasWork = not (Map.null (rrSolverStates refinedResult))
                if not hasWork
                    then putStrLn "  No refined types to analyze."
                    else do
                        putStrLn $ "  Graph size: " ++ show (Map.size (rrSolverStates refinedResult)) ++ " nodes"
                        putStrLn $ "  Registry size: " ++ show (Map.size (regDefinitions (rrRegistry refinedResult))) ++ " types"
                        if null (rrErrors refinedResult)
                            then putStrLn "  Refined check successful."
                            else do
                                putStrLn "  Refined check failed with errors:"
                                mapM_ (Text.putStrLn . ("    " <>)) (rrErrors refinedResult)
                                exitFailure
                return refinedResult

            return ()
  where
    -- Let ExitCode exceptions (raised by exitSuccess / exitFailure) through
    -- unchanged; for anything else print a truncated message (at most 20
    -- lines of 100 characters) and exit with failure.
    handler :: E.SomeException -> IO ()
    handler e = case E.fromException e of
        Just ec -> E.throwIO (ec :: ExitCode)
        Nothing -> do
            putStr . unlines . take 20 . map (take 100) . lines $ show e
            exitFailure

-- | Print one example of every kind of inferred Hic construct.
--
-- All ASTs of the program are searched for 'HicNode's; for each distinct
-- 'nodeName' the first occurrence found is kept and printed together with
-- its source location. The 'Bool' is forwarded to 'renderDoc' to force
-- coloured output.
showExemplars :: Bool -> Program (C.Lexeme Text) -> IO ()
showExemplars forceColor prog =
    mapM_ printExemplar (Map.toList exemplars)
  where
    -- First occurrence wins: on a duplicate key the already-stored entry is
    -- kept.
    exemplars :: Map String (Text, Node (C.Lexeme Text))
    exemplars = Map.fromListWith (flip const) $ do
        (path, nodes) <- Map.toList (progAsts prog)
        node <- nodes
        (name, n) <- collectExemplars node
        return (name, (C.sloc path n, n))

    -- Every Hic node in the tree (pre-order), tagged with its construct name.
    collectExemplars :: Node (C.Lexeme Text) -> [(String, Node (C.Lexeme Text))]
    collectExemplars node = case unFix node of
        HicNode h    -> (nodeName h, node) : concatMap collectExemplars h
        CimpleNode f -> concatMap collectExemplars f

    printExemplar (name, (loc, node)) = do
        putStrLn (concat ["Exemplar for ", name, " at ", T.unpack loc, ":"])
        renderDoc forceColor (ppNode node)
        putStrLn "\n"

-- | Count, over every AST in the program, how many times each kind of Hic
-- construct occurs. The result maps construct names (see 'countNode') to
-- their total number of occurrences.
collectStats :: Program (C.Lexeme Text) -> Map String Int
collectStats prog =
    Map.unionsWith (+)
        [ countNode node
        | nodes <- Map.elems (progAsts prog)
        , node <- nodes
        ]

-- | Count the Hic constructs inside a single tree, keyed by 'nodeName'.
--
-- Plain Cimple nodes contribute only the counts of their children; a Hic
-- node adds one to its own name on top of its children's counts.
countNode :: Node (C.Lexeme Text) -> Map String Int
countNode = foldFix tally
  where
    tally (CimpleNode children) = Map.unionsWith (+) children
    tally (HicNode h)           = Map.insertWith (+) (nodeName h) 1 (Map.unionsWith (+) h)

-- | Check that lowering the Hic program back to Cimple reproduces the
-- original AST of one file.
--
-- The first argument maps file paths to their lowered (round-tripped) ASTs;
-- the second is a file path with its original AST. Both sides are normalised
-- with 'C.elideGroups' and 'C.removeSloc' before being compared.
--
-- Returns normally when the ASTs agree. Otherwise a message is printed and
-- the process exits with failure; when both ASTs are available they are also
-- written to fixed paths under @/tmp@ and shown as a unified diff.
checkFile :: Map FilePath [C.Node (C.Lexeme Text)] -> (FilePath, [C.Node (C.Lexeme Text)]) -> IO ()
checkFile loweredMap (path, nodes) =
    case Map.lookup path loweredMap of
        Nothing -> do
            -- A file that disappeared during lowering is a round-trip
            -- failure, not a reason to crash on a partial map lookup.
            putStrLn $ "  Round-trip failed for " ++ path ++ ": no lowered AST for this file"
            exitFailure
        Just lowered -> do
            let original = map normalise nodes
            let roundtripped = map normalise lowered
            when (original /= roundtripped) $ do
                putStrLn $ "  Round-trip failed for " ++ path
                -- NOTE(review): fixed, shared file names; concurrent runs
                -- would overwrite each other's dumps.
                let origFile = "/tmp/hic-check-original.ast"
                let newFile = "/tmp/hic-check-roundtripped.ast"
                writeFile origFile (groom original)
                writeFile newFile (groom roundtripped)
                putStrLn "Diff:"
                -- diff exits with status 1 whenever its inputs differ, which
                -- is the expected case here, so the exit code is inspected
                -- instead of being turned into an exception. Only a status
                -- above 1 (diff itself had trouble) is worth mentioning.
                diffStatus <- rawSystem "diff" ["-u", "--color=auto", origFile, newFile]
                case diffStatus of
                    ExitFailure code | code > 1 ->
                        putStrLn $ "  diff exited with status " ++ show code
                    _ -> return ()
                exitFailure
  where
    normalise = C.removeSloc . C.elideGroups
