{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData          #-}
module Language.Cimple.Analysis.Refined.Solver
    ( TypeSummary (..)
    , SolverEnv (..)
    , FilterResult (..)
    , Constraint (..)
    , solve
    , runWorklist
    ) where

import           Data.IntMap.Strict                           (IntMap)
import qualified Data.IntMap.Strict                           as IntMap
import           Data.List                                    (find)
import           Data.Map.Strict                              (Map)
import qualified Data.Map.Strict                              as Map
import           Data.Set                                     (Set)
import qualified Data.Set                                     as Set
import           Data.Text                                    (Text)
import           Data.Word                                    (Word32)
import qualified Debug.Trace                                  as Debug
import           GHC.Generics                                 (Generic)
import           Language.Cimple.Analysis.Refined.Context     (MappingContext, MappingRefinements (..),
                                                               deleteRefinement,
                                                               emptyContext,
                                                               emptyRefinements,
                                                               mrHash,
                                                               setRefinement)
import           Language.Cimple.Analysis.Refined.LatticeOp   (Polarity (..))
import           Language.Cimple.Analysis.Refined.PathContext (PathContext (..),
                                                               emptyPath)
import           Language.Cimple.Analysis.Refined.Registry    (Registry)
import           Language.Cimple.Analysis.Refined.State       (ProductState (..))
import           Language.Cimple.Analysis.Refined.Transition  (TransitionEnv (..),
                                                               isRefinable,
                                                               step,
                                                               variableKey)
import           Language.Cimple.Analysis.Refined.Types       (AnyRigidNodeF (..),
                                                               ObjectStructure (..),
                                                               RigidNodeF (..),
                                                               TemplateId,
                                                               TerminalNode (..))

-- | Compile-time switch for solver tracing. When 'True', 'dtrace' forwards its
-- message to "Debug.Trace"; when 'False' all tracing is a no-op.
debugging :: Bool
debugging = False

-- | @dtrace msg x@ returns @x@, printing @msg@ via 'Debug.trace' first only
-- if 'debugging' is enabled.
dtrace :: String -> a -> a
dtrace msg
    | debugging = Debug.trace msg
    | otherwise = id

-- | Compact record of the refined type information produced for a solved
-- SCC, so that SCCs can be isolated and compiled incrementally.
data TypeSummary = TypeSummary
    { -- | Canonical refined type structure IDs, keyed by exported name.
      tsExportedTypes :: Map Text (AnyRigidNodeF TemplateId Int)
    }
    deriving (Show, Eq, Ord, Generic)

-- | Environment threaded through the project-wide Refined Solver.
data SolverEnv = SolverEnv
    { -- | Summaries cached from SCCs that have already been solved.
      seSummaries :: Map Text TypeSummary
    }
    deriving (Show, Eq, Ord, Generic)

-- | Outcome of the Refinement Filter (a linear symbolic pass), marking which
-- parts of the project need the rigorous graph solver.
data FilterResult = FilterResult
    { -- | True if the code contains Refinement Triggers (Existentials, Tagged Unions).
      frRequiresRigorousSolver :: Bool
      -- | Names of the functions/structs flagged as hotspots.
    , frHotspots               :: [Text]
    }
    deriving (Show, Eq, Ord, Generic)

-- | A constraint to be solved: either a subtyping check between two nodes or
-- a one-way inheritance of refinements. Each constraint seeds exactly one
-- initial worklist state in 'solve'.
data Constraint
    = CSubtype
        Word32         -- ^ left node ID
        Word32         -- ^ right node ID
        Polarity       -- ^ polarity the pair is checked under
        MappingContext -- ^ mapping context (gamma) of the initial state
        PathContext    -- ^ path context handed to the transition environment
        Int            -- ^ left-hand depth of the initial state
        Int            -- ^ right-hand depth of the initial state
    | CInherit
        Word32         -- ^ node that inherits refinements (left)
        Word32         -- ^ node the refinements come from (right); one-way PMeet
    deriving (Show, Eq, Ord, Generic)

-- | Executes the project-wide fixpoint solver on a set of constraints.
--
-- Each constraint contributes one seed 'ProductState':
--
-- * a 'CSubtype' seeds a two-way state carrying its own polarity, mapping
--   context and depths;
-- * a 'CInherit' seeds a one-way 'PMeet' state with an empty mapping context
--   at depth 0 on both sides.
--
-- The search then starts from 'emptyRefinements' with nothing visited.
solve :: Registry Word32
      -> Map Word32 (AnyRigidNodeF TemplateId Word32)
      -> [Constraint]
      -> (Word32, Word32, Word32, Word32) -- ^ (Bottom, Any, Conflict, STerminal) IDs
      -> (Bool, MappingRefinements)
solve registry nodes constraints terminals =
    runWorklist registry nodes constraints terminals emptyRefinements seeds Set.empty
  where
    seeds = Set.fromList (map seedOf constraints)

    seedOf (CSubtype l r pol gamma _ dL dR) = ProductState l r pol False gamma dL dR Nothing
    seedOf (CInherit l r)                   = ProductState l r PMeet True emptyContext 0 0 Nothing

-- | Look up the node ID of a built-in terminal in the
-- @(bottom, any, conflict, sTerminal)@ tuple. 'STerminal' has no fixed ID and
-- yields 'Nothing'; the fourth slot of the tuple is never consulted.
terminalToId :: TerminalNode a -> (Word32, Word32, Word32, Word32) -> Maybe Word32
terminalToId SBottom     (bottomId, _, _, _)   = Just bottomId
terminalToId SAny        (_, anyId, _, _)      = Just anyId
terminalToId SConflict   (_, _, conflictId, _) = Just conflictId
terminalToId STerminal{} (_, _, _, _)          = Nothing

-- | Core worklist loop for the Product Automaton.
-- Only moves UP the lattice. Restarts on refinement changes to ensure consistency.
--
-- Repeatedly pops the smallest 'ProductState' not yet visited, runs one
-- transition 'step' on it and reacts to the result:
--
-- * an 'SConflict' terminal stops the search with @(False, refs)@, where
--   @refs@ is the refinement map from /before/ that step;
-- * if the refinement map changed (compared via 'mrHash'), every top-level
--   constraint state is re-queued and the visited set is cleared;
-- * otherwise the state is marked visited and the step's children are queued.
--
-- An empty worklist ends the search with @(True, refs)@.
runWorklist :: Registry Word32
            -> Map Word32 (AnyRigidNodeF TemplateId Word32)
            -> [Constraint]
            -> (Word32, Word32, Word32, Word32)
            -> MappingRefinements
            -> Set ProductState
            -> Set ProductState
            -> (Bool, MappingRefinements)
runWorklist registry nodes constraints terminals = go
  where
    -- States seeded by the top-level constraints. They depend only on
    -- @constraints@, so they are built once rather than on every restart.
    topLevel :: Set ProductState
    topLevel = Set.fromList (map seedState constraints)

    seedState :: Constraint -> ProductState
    seedState = \case
        CSubtype l r pol gamma _ dL dR -> ProductState l r pol False gamma dL dR Nothing
        CInherit l r                   -> ProductState l r PMeet True emptyContext 0 0 Nothing

    -- Path context of each 'CSubtype', keyed by the fields a two-way state
    -- must agree on to match it. On duplicate keys the earliest constraint is
    -- kept, matching a first-match search over the constraint list. This
    -- replaces a linear scan of all constraints on every step.
    pathIndex = Map.fromListWith (\_ earlier -> earlier)
        [ ((l, r, pol, gamma, dL, dR), c) | CSubtype l r pol gamma c dL dR <- constraints ]

    noPath :: PathContext
    noPath = PathContext Map.empty Map.empty

    -- One-way states can only correspond to a 'CInherit', which carries no
    -- path context. States matching no constraint (e.g. children produced by
    -- 'step') also get the empty context.
    pathContextFor :: ProductState -> PathContext
    pathContextFor ps
        | psOneWay ps = noPath
        | otherwise   = Map.findWithDefault noPath
            (psNodeL ps, psNodeR ps, psPolarity ps, psGamma ps, psDepthL ps, psDepthR ps)
            pathIndex

    -- Refinements changed: re-add all top-level constraints and CLEAR the
    -- visited set, keeping whatever was still pending plus the new children.
    restart :: MappingRefinements -> Set ProductState -> Set ProductState
            -> (Bool, MappingRefinements)
    restart refs' rest children =
        go refs' (Set.unions [rest, topLevel, children]) Set.empty

    go :: MappingRefinements -> Set ProductState -> Set ProductState
       -> (Bool, MappingRefinements)
    go !refs worklist visited = case Set.minView worklist of
        Nothing -> (True, refs)
        Just (ps, rest)
            | ps `Set.member` visited -> go refs rest visited
            | otherwise -> dtrace ("solve step: " ++ show ps) $ visit refs ps rest visited

    -- Process one unvisited state.
    visit :: MappingRefinements -> ProductState -> Set ProductState -> Set ProductState
          -> (Bool, MappingRefinements)
    visit refs ps rest visited =
        dtrace ("solve step: " ++ show ps ++ " -> res: " ++ show result) $
        case result of
            AnyRigidNodeF (RTerminal SConflict) -> (False, refs)
            AnyRigidNodeF (RTerminal term) | Just termId <- terminalToId term terminals ->
                let refsR    = pinToTerminal ps refineL refineR termId newRefs
                    children = foldMap Set.singleton result
                in if mrHash refsR /= mrHash refs
                   then restart refsR rest children
                   else go refsR (Set.union rest children) (Set.insert ps visited)
            AnyRigidNodeF n ->
                let children = foldMap Set.singleton n
                in if mrHash newRefs /= mrHash refs
                   then restart newRefs rest children
                   else go refs (Set.union rest children) (Set.insert ps visited)
      where
        -- Special handling for CInherit: its states are one-way, so psNodeR
        -- must not be refined; the left side is always refinable.
        refineL = True
        refineR = not (psOneWay ps)
        env = TransitionEnv nodes registry (psPolarity ps) (pathContextFor ps)
                            emptyPath terminals refineL refineR
        (result, !newRefs) = step env ps refs

    -- For a terminal result with a concrete ID, under 'PMeet' only, pin the
    -- parent variable (if any) and each refinable variable on an enabled side
    -- to that terminal. Writes are applied in the order parent, left, right.
    pinToTerminal :: ProductState -> Bool -> Bool -> Word32 -> MappingRefinements
                  -> MappingRefinements
    pinToTerminal ps refineL refineR termId newRefs =
        let meet = psPolarity ps == PMeet
            refsParent = case psParentVar ps of
                Just (d, tid) | meet ->
                    dtrace ("Refining Parent " ++ show tid ++ " at depth " ++ show d ++ " to " ++ show termId) $
                    setRefinement (variableKey nodes d tid) termId newRefs
                _ -> newRefs
            refsL = case Map.lookup (psNodeL ps) nodes of
                Just (AnyRigidNodeF (RObject (VVar tid _) _)) | isRefinable tid && meet && refineL ->
                    dtrace ("Refining L " ++ show tid ++ " to " ++ show termId) $
                    setRefinement (variableKey nodes (psDepthL ps) tid) termId refsParent
                _ -> refsParent
        in case Map.lookup (psNodeR ps) nodes of
            Just (AnyRigidNodeF (RObject (VVar tid _) _)) | isRefinable tid && meet && refineR ->
                dtrace ("Refining R " ++ show tid ++ " to " ++ show termId) $
                setRefinement (variableKey nodes (psDepthR ps) tid) termId refsL
            _ -> refsL
