{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Strict              #-}
module Language.Cimple.Analysis.Refined.Inference.Types
    ( RefinedResult (..)
    , TaggedUnionInfo (..)
    , TranslatorState (..)
    , emptyTranslatorState
    , addConstraint
    , addConstraintCoerced
    , addNode
    , addFunction
    , addVar
    , addTaggedUnion
    ) where

import           Data.Aeson                                       (ToJSON (..),
                                                                   object, (.=))
import           Data.Map.Strict                                  (Map)
import qualified Data.Map.Strict                                  as Map
import           Data.Text                                        (Text)
import           Data.Word                                        (Word32)
import           GHC.Generics                                     (Generic)

import           Language.Cimple.Analysis.Refined.Context
import           Language.Cimple.Analysis.Refined.LatticeOp
import           Language.Cimple.Analysis.Refined.PathContext
import           Language.Cimple.Analysis.Refined.Registry
import           Language.Cimple.Analysis.Refined.Solver          (Constraint (..))
import           Language.Cimple.Analysis.Refined.State
import           Language.Cimple.Analysis.Refined.Transition
import           Language.Cimple.Analysis.Refined.Types
import qualified Language.Cimple.Analysis.TypeSystem              as TS

import           Language.Cimple.Analysis.Refined.Inference.Utils

-- | Aggregate result record of the refined inference.
--
-- Only 'rrHotspots', 'rrSolved' and 'rrErrors' are included in this type's
-- 'ToJSON' encoding; 'rrSolverStates' and 'rrRegistry' are not serialised.
-- NOTE(review): the code that fills this record lives outside this module —
-- confirm the exact field semantics there.
data RefinedResult = RefinedResult
    { rrHotspots     :: [Text]
      -- ^ Hotspot descriptions as text (presumably places the analysis
      -- flagged — TODO confirm with the producer).
    , rrSolverStates :: Map Word32 (AnyRigidNodeF TemplateId Word32)
      -- ^ Rigid nodes keyed by a 'Word32' id (presumably the solver's final
      -- state per node id — verify against the producer).
    , rrRegistry     :: Registry Word32
      -- ^ Registry indexed by 'Word32' ids.
    , rrSolved       :: Bool
      -- ^ Solver success flag (presumably 'True' when solving completed
      -- without failure — confirm).
    , rrErrors       :: [Text]
      -- ^ Error messages collected during the run.
    } deriving (Show)

-- | JSON encoding of a 'RefinedResult': an object with the keys
-- @hotspots@, @solved@ and @errors@. Solver states and the registry are
-- deliberately left out of the encoding.
instance ToJSON RefinedResult where
    toJSON result = object
        [ "hotspots" .= rrHotspots result
        , "solved"   .= rrSolved result
        , "errors"   .= rrErrors result
        ]

-- | Description of a tagged union: a tag field, a union field, and the
-- mapping from tag (enum) values to union members.
-- NOTE(review): presumably describes a C struct that pairs an enum tag with a
-- union — confirm against the code that builds these values.
data TaggedUnionInfo = TaggedUnionInfo
    { tuiTagField   :: Text
      -- ^ Name of the field holding the tag (presumably a struct field name —
      -- verify).
    , tuiUnionField :: Text
      -- ^ Name of the field holding the union (presumably a struct field
      -- name — verify).
    , tuiMembers    :: Map Text Text -- ^ EnumVal -> MemberName: maps each enum value name to the union member name it selects.
    } deriving (Show)

-- | State for the refinement translator.
--
-- Threaded through translation and updated with the @add*@ helpers in this
-- module. A fresh value is built with 'emptyTranslatorState', which reserves
-- node ids 0–2 for the bottom, any and conflict terminals.
data TranslatorState = TranslatorState
    { tsNextId         :: Word32
      -- ^ Next node id to allocate; starts at 3, one past the reserved ids.
    , tsNodes          :: Map Word32 (AnyRigidNodeF TemplateId Word32)
      -- ^ All rigid nodes, keyed by node id ('addNode' inserts here).
    , tsCache          :: Map (TS.TypeInfo 'TS.Global) Word32
      -- ^ Maps global TypeSystem types to node ids (presumably a memo so each
      -- type is translated only once — confirm).
    , tsConstraints    :: [Constraint]
      -- ^ Emitted constraints; 'addConstraint' prepends, so newest is first.
    , tsCurrentPath    :: SymbolicPath
      -- ^ Current symbolic path; starts as 'emptyPath'.
    , tsVars           :: Map Text Word32
      -- ^ Variable name to node id ('addVar' inserts here).
    , tsFunctions      :: Map Text Word32
      -- ^ Function name to node id ('addFunction' inserts here).
    , tsTypeSystem     :: TS.TypeSystem
      -- ^ The standard type system the translator was created with.
    , tsTaggedUnions   :: Map Text TaggedUnionInfo
      -- ^ Tagged-union descriptions by name ('addTaggedUnion' inserts here).
    , tsArrayInstances :: Map (Word32, Integer) Word32
      -- ^ Array node ids keyed by a (node id, integer) pair (presumably
      -- (element type, length) — TODO confirm).
    , tsExistentials   :: Map Text Word32
      -- ^ Named existential node ids (NOTE(review): populated elsewhere —
      -- confirm meaning).
    , tsCurrentReturn  :: Maybe Word32
      -- ^ Optional node id (presumably the return type of the function being
      -- translated — verify); starts as 'Nothing'.
    , tsErrors         :: [Text]
      -- ^ Error messages accumulated during translation.
    , tsSubstCache     :: Map Word32 Word32
      -- ^ Node-id to node-id cache (presumably memoised substitution results —
      -- confirm).
    }

-- Helper functions for record updates to assist GHC type inference
-- | Record the subtyping constraint @l <: r@ under the given path context.
--
-- The new constraint is a 'CSubtype' using the 'PMeet' operation and an
-- 'emptyContext', with both trailing numeric arguments set to @0@. It is
-- prepended to 'tsConstraints'. A debug line is emitted through 'dtrace'.
addConstraint :: PathContext -> Word32 -> Word32 -> TranslatorState -> TranslatorState
addConstraint ctx l r s =
    dtrace traceMsg (s { tsConstraints = subtypeOf : tsConstraints s })
  where
    traceMsg  = concat ["addConstraint: ", show l, " <: ", show r]
    subtypeOf = CSubtype l r PMeet emptyContext ctx 0 0

-- | Safe numeric coercion for built-ins.
--
-- Records @l <: r@ via 'addConstraint' unless both sides are numeric
-- built-ins. In that case the state is returned unchanged: the standard
-- TypeSystem is trusted and no refined constraint is emitted.
--
-- A node id counts as numeric when 'tsNodes' maps it to an 'RObject' whose
-- value is a 'VBuiltin' or a 'VSingleton' of any built-in type other than
-- 'NullPtrTy'. Ids that are absent from 'tsNodes', or that map to any other
-- node shape, are treated as non-numeric.
addConstraintCoerced :: PathContext -> Word32 -> Word32 -> TranslatorState -> TranslatorState
addConstraintCoerced ctx l r s
    | all isNumeric [l, r] = s -- Swallow pure numeric constraints
    | otherwise            = addConstraint ctx l r s
  where
    isNumeric :: Word32 -> Bool
    isNumeric nid =
        case Map.lookup nid (tsNodes s) of
            Just (AnyRigidNodeF (RObject (VBuiltin bt) _))     -> notNullPtr bt
            Just (AnyRigidNodeF (RObject (VSingleton bt _) _)) -> notNullPtr bt
            _                                                  -> False

    notNullPtr bt = bt /= NullPtrTy

-- | Store @node@ under id @nid@ in 'tsNodes', replacing any node that was
-- already registered under that id.
addNode :: Word32 -> AnyRigidNodeF TemplateId Word32 -> TranslatorState -> TranslatorState
addNode nid node st@TranslatorState{ tsNodes = nodes } =
    st { tsNodes = Map.insert nid node nodes }

-- | Bind function @name@ to node id @nid@ in 'tsFunctions', overwriting any
-- earlier binding for the same name.
addFunction :: Text -> Word32 -> TranslatorState -> TranslatorState
addFunction name nid st@TranslatorState{ tsFunctions = funcs } =
    st { tsFunctions = Map.insert name nid funcs }

-- | Bind variable @name@ to node id @nid@ in 'tsVars', overwriting any
-- earlier binding for the same name.
addVar :: Text -> Word32 -> TranslatorState -> TranslatorState
addVar name nid st@TranslatorState{ tsVars = vars } =
    st { tsVars = Map.insert name nid vars }

-- | Register the tagged-union description @tu@ under @name@ in
-- 'tsTaggedUnions', replacing any previous entry with that name.
addTaggedUnion :: Text -> TaggedUnionInfo -> TranslatorState -> TranslatorState
addTaggedUnion name tu st@TranslatorState{ tsTaggedUnions = unions } =
    st { tsTaggedUnions = Map.insert name tu unions }

-- | A fresh translator state for the given type system.
--
-- Node ids @0@, @1@ and @2@ are pre-populated with the 'SBottom', 'SAny' and
-- 'SConflict' terminals, in that order. 'tsNextId' therefore starts at @3@,
-- one past the reserved ids. Every other map and list starts empty,
-- 'tsCurrentPath' is 'emptyPath', and 'tsCurrentReturn' is 'Nothing'.
emptyTranslatorState :: TS.TypeSystem -> TranslatorState
emptyTranslatorState typeSystem = TranslatorState
    { tsNextId         = fromIntegral (length reservedNodes)
    , tsNodes          = Map.fromList (zip [0 ..] reservedNodes)
    , tsCache          = Map.empty
    , tsConstraints    = []
    , tsCurrentPath    = emptyPath
    , tsVars           = Map.empty
    , tsFunctions      = Map.empty
    , tsTypeSystem     = typeSystem
    , tsTaggedUnions   = Map.empty
    , tsArrayInstances = Map.empty
    , tsExistentials   = Map.empty
    , tsCurrentReturn  = Nothing
    , tsErrors         = []
    , tsSubstCache     = Map.empty
    }
  where
    -- Order matters: a node's position in this list is its reserved id.
    reservedNodes =
        [ AnyRigidNodeF (RTerminal SBottom)
        , AnyRigidNodeF (RTerminal SAny)
        , AnyRigidNodeF (RTerminal SConflict)
        ]
