{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Strict              #-}
{-# LANGUAGE TupleSections       #-}
module Language.Cimple.Analysis.Refined.Inference.Translator
    ( translateRegistry
    , translateDescr
    , translateMember
    , translateType
    , translateType'
    , translateReturnType
    , translateTemplateIdGlobal
    , nodeToTypeInfo
    , translateStdType
    ) where

import           Control.Monad.State.Strict                              (State,
                                                                          get,
                                                                          gets,
                                                                          modify)
import           Data.Fix                                                (Fix (..),
                                                                          foldFix)
import qualified Data.Map.Strict                                         as Map
import           Data.Maybe                                              (fromMaybe)
import qualified Data.Set                                                as Set
import           Data.Text                                               (Text)
import qualified Data.Text                                               as T
import qualified Data.Text.Read                                          as TR
import           Data.Word                                               (Word32)

import           Language.Cimple                                         (Lexeme (..))
import qualified Language.Cimple                                         as C
import           Language.Cimple.Analysis.Refined.Inference.Substitution
import           Language.Cimple.Analysis.Refined.Inference.Types
import           Language.Cimple.Analysis.Refined.Inference.Utils
import           Language.Cimple.Analysis.Refined.LatticeOp
import           Language.Cimple.Analysis.Refined.Registry
import           Language.Cimple.Analysis.Refined.Types
import qualified Language.Cimple.Analysis.TypeSystem                     as TS

-- | Translates every type description in the type system into a refined
-- 'Registry'. The result is keyed exactly like the input map.
--
-- Descriptions are translated in ascending key order, which determines the
-- order in which node IDs are allocated in the translator state.
translateRegistry :: TS.TypeSystem -> State TranslatorState (Registry Word32)
translateRegistry ts = Registry <$> traverse translateDescr ts

-- | Translates one global type description into a refined 'TypeDefinition'.
--
-- * Structs and unions keep their members, translated in declaration order.
-- * Enums and integer descriptions become 'EnumDef's with no entries.
-- * Function and alias descriptions become member-less 'StructDef's that only
--   carry their template parameters.
--
-- Every template parameter is marked 'Invariant'.
translateDescr :: TS.TypeDescr 'TS.Global -> State TranslatorState (TypeDefinition Word32)
translateDescr descr = case descr of
    TS.StructDescr name params members ->
        StructDef name (invariantParams params) <$> mapM translateMember members
    TS.UnionDescr name params members ->
        UnionDef name (invariantParams params) <$> mapM translateMember members
    TS.EnumDescr name _ ->
        pure (EnumDef name [])
    TS.IntDescr name _ ->
        pure (EnumDef name [])
    TS.FuncDescr name params _ _ ->
        pure (StructDef name (invariantParams params) [])
    TS.AliasDescr name params _ ->
        pure (StructDef name (invariantParams params) [])
  where
    -- Pairs each translated template parameter with the 'Invariant' variance.
    invariantParams = map (\p -> (translateTemplateIdGlobal p, Invariant))


-- | Converts a Cimple AST type node into a global-phase 'TS.TypeInfo'.
--
-- A user-defined name that refers to an alias in the given type system is
-- replaced by the alias target. Any other user-defined name becomes an
-- unresolved reference.
--
-- Integer literals (as found, for example, in array sizes) become
-- 'TS.IntLit'. Nodes with no type meaning are wrapped in 'TS.Unsupported'
-- together with their 'show' text.
nodeToTypeInfo :: TS.TypeSystem -> C.Node (Lexeme Text) -> TS.TypeInfo 'TS.Global
nodeToTypeInfo ts root = go root
  where
    go (Fix node) = case node of
        C.TyStd l                        -> TS.builtin l
        C.TyPointer t                    -> TS.Pointer (go t)
        C.FunctionPrototype ret _ params -> TS.Function (go ret) (map go params)
        C.VarDecl ty _ dims
            | null dims                  -> go ty
            | otherwise                  -> TS.Array (Just (go ty)) (map go dims)
        C.DeclSpecArray _ mSize          -> maybe TS.Unconstrained go mSize
        C.TyConst t                      -> TS.Const (go t)
        C.TyNonnull t                    -> TS.Nonnull (go t)
        C.TyNullable t                   -> TS.Nullable (go t)
        C.TyOwner t                      -> TS.Owner (go t)
        C.TyUserDefined (L _ _ name)     -> resolveUserDefined name
        C.TyStruct l                     -> TS.TypeRef TS.StructRef (fmap TS.TIdName l) []
        C.TyUnion l                      -> TS.TypeRef TS.UnionRef (fmap TS.TIdName l) []
        C.TyFunc l                       -> TS.TypeRef TS.FuncRef (fmap TS.TIdName l) []
        C.VarExpr l                      -> TS.TypeRef TS.UnresolvedRef (fmap TS.TIdName l) []
        C.LiteralExpr C.Int l            -> TS.IntLit (fmap TS.TIdName l)
        other                            -> TS.Unsupported (T.pack (show (Fix other)))

    -- Aliases are expanded eagerly; everything else stays a by-name reference
    -- with a placeholder source position.
    resolveUserDefined name = case TS.lookupType name ts of
        Just (TS.AliasDescr _ _ target) -> target
        _ -> TS.TypeRef TS.UnresolvedRef (L (C.AlexPn 0 0 0) C.IdVar (TS.TIdName name)) []

-- | Translates a named struct/union member. The member keeps its name and
-- refers to its type by the node ID returned from 'translateType'.
translateMember :: (Lexeme Text, TS.TypeInfo 'TS.Global) -> State TranslatorState (Member Word32)
translateMember (name, ty) = Member name <$> translateType ty

-- | Translates a standard Cimple type to a Refined RigidNode and returns the
-- node's ID.
--
-- The type is first resolved against 'tsTypeSystem'. Then:
--
-- * A nominal type ('TS.TypeRefF') whose base name has an entry in
--   'tsExistentials' maps straight to that existential ID. This happens when
--   the type has no parameters or when every parameter is a template
--   variable.
-- * Otherwise, a type that mentions @void@, a 'TS.TIdParam' template or a
--   'TS.TIdAnonymous' template anywhere inside it always gets a fresh ID and
--   is never cached.
-- * Every other type is looked up in, and on a miss added to, 'tsCache'.
--
-- A new ID is reserved, and cached where eligible, /before/ the node is built
-- by 'translateType''. This ensures that any nested 'translateType' call for
-- the same type during that build finds the cached ID.
translateType :: TS.TypeInfo 'TS.Global -> State TranslatorState Word32
translateType ty = do
    st <- get
    let resolved = TS.resolveRef (tsTypeSystem st) ty
        TS.FlatType {..} = TS.toFlat resolved

    existential <- case ftStructure of
        TS.TypeRefF _ name params -> do
            let baseName = TS.templateIdBaseName (C.lexemeText name)
            dtraceM ("translateType: checking nominal " ++ show baseName ++ " params=" ++ show (length params))
            case Map.lookup baseName (tsExistentials st) of
                Nothing -> return Nothing
                Just existId -> do
                    let isGeneric = all isTemplateParam params
                    dtraceM ("translateType: found existential " ++ show existId ++ " for " ++ show baseName ++ " isGeneric=" ++ show isGeneric)
                    return $ if isGeneric || null params then Just existId else Nothing
        _ -> return Nothing

    case existential of
        Just existId -> return existId
        Nothing
            | isFreshCandidate resolved -> allocateId False resolved
            | otherwise ->
                gets (Map.lookup resolved . tsCache) >>= maybe (allocateId True resolved) return
  where
    -- Reserves the next ID and optionally records it in the cache. It then
    -- builds the node for the type and registers it under that ID.
    allocateId :: Bool -> TS.TypeInfo 'TS.Global -> State TranslatorState Word32
    allocateId cacheable target = do
        nid <- gets tsNextId
        modify $ \s -> s { tsNextId = nid + 1 }
        if cacheable
            then modify $ \s -> s { tsCache = Map.insert target nid (tsCache s) }
            else return ()
        node <- translateType' target
        dtraceM ("Registering ID " ++ show nid ++ ": " ++ show node)
        modify (addNode nid node)
        return nid

    -- True only for a bare template variable at the top of a parameter.
    isTemplateParam (Fix (TS.TemplateF _)) = True
    isTemplateParam _                      = False

    -- True if @void@, or a parameter/anonymous template, occurs anywhere in
    -- the type. Named and recursive templates do not count.
    isFreshCandidate = foldFix (\case
        TS.BuiltinTypeF TS.VoidTy -> True
        TS.TemplateF (TS.FT tid _) -> case tid of
            TS.TIdParam {}     -> True
            TS.TIdAnonymous {} -> True
            _                  -> False
        layer -> or layer)

-- | Builds the refined node for a single type. The cache and the existential
-- table are not consulted for the type itself.
--
-- Child types (pointees, array elements and dimensions, function arguments
-- and return types, nominal type parameters) are translated with
-- 'translateType' and referenced by ID.
--
-- The outermost qualifiers are mapped as follows:
--
-- * @const@ becomes 'Quals'.
-- * @nonnull@/@nullable@ become the nullability of reference nodes (pointers
--   and arrays).
-- * @owner@ becomes the ownership of reference nodes.
--
-- Structures with no refined counterpart become @RTerminal SConflict@.
translateType' :: TS.TypeInfo 'TS.Global -> State TranslatorState (AnyRigidNodeF TemplateId Word32)
translateType' ty = do
    let TS.FlatType {..} = TS.toFlat ty
    dtraceM ("translateType': ftStructure=" ++ show (fmap (const ()) ftStructure))
    let quals = Quals (TS.QConst `Set.member` ftQuals)
        nullability
            | TS.QNonnull `Set.member` ftQuals  = QNonnull'
            | TS.QNullable `Set.member` ftQuals = QNullable'
            | otherwise                         = QUnspecified
        ownership = if TS.QOwner `Set.member` ftQuals then QOwned' else QNonOwned'
        conflictNode = AnyRigidNodeF (RTerminal SConflict)
        -- A pointer node carrying this type's own qualifiers.
        pointerTo target = AnyRigidNodeF (RReference (Ptr target) nullability ownership quals)
        -- A pointer to an ordinary object whose type is translated by ID.
        objectPointer pointee = pointerTo . TargetObject <$> translateType pointee
    case ftStructure of
        -- A bare @void@ object becomes a fresh local type variable. Its node
        -- is registered under its own newly reserved ID, and the same node is
        -- returned. When called from 'translateType', that node is therefore
        -- also registered under the ID 'translateType' allocated.
        TS.BuiltinTypeF TS.VoidTy -> do
            nid <- gets tsNextId
            let tid = TIdParam PLocal nid (Just "T")
                varNode = AnyRigidNodeF (RObject (VVar tid Nothing) quals)
            modify $ \s -> s { tsNextId = nid + 1 }
            modify (addNode nid varNode)
            return varNode

        TS.BuiltinTypeF bt -> return $ case translateStdType bt of
            Just sbt -> AnyRigidNodeF (RObject (VBuiltin sbt) quals)
            Nothing  -> conflictNode

        TS.PointerF inner@(Fix innerF) -> case innerF of
            TS.FunctionF ret args ->
                pointerTo <$> functionPtrTarget ret args
            -- A pointer to a named function type is expanded to its signature
            -- when the name resolves to a function description.
            TS.TypeRefF TS.FuncRef name _ -> do
                types <- gets tsTypeSystem
                case TS.lookupType (TS.templateIdBaseName (C.lexemeText name)) types of
                    Just (TS.FuncDescr _ _ ret args) -> pointerTo <$> functionPtrTarget ret args
                    _                                -> objectPointer inner
            -- @void *@ becomes an opaque pointer to a fresh local type
            -- variable. The variable's node is registered under its own ID.
            -- NOTE(review): only an unqualified @void@ pointee matches here. A
            -- qualified pointee such as @const void@ is presumably a separate
            -- wrapper layer ('TS.Const'), so it falls through to the generic
            -- object-pointer case below. Confirm this asymmetry is intended.
            TS.BuiltinTypeF TS.VoidTy -> do
                varNid <- gets tsNextId
                let tid = TIdParam PLocal varNid (Just "T")
                modify $ \s -> s { tsNextId = varNid + 1 }
                modify (addNode varNid (AnyRigidNodeF (RObject (VVar tid Nothing) (Quals False))))
                return $ pointerTo (TargetOpaque tid)
            _ -> objectPointer inner

        TS.FunctionF ret args -> do
            retId <- translateReturnType ret
            argIds <- mapM translateType args
            return $ AnyRigidNodeF (RFunction argIds retId)

        TS.ArrayF (Just inner) dims -> do
            innerId <- translateType inner
            dimIds <- mapM translateType dims
            return $ AnyRigidNodeF (RReference (Arr innerId dimIds) nullability ownership quals)

        TS.TypeRefF _ name params -> do
            paramIds <- mapM translateType params
            return $ AnyRigidNodeF (RObject (VNominal (fmap translateTemplateIdGlobal name) paramIds) quals)

        TS.TemplateF (TS.FT tid _) ->
            return $ AnyRigidNodeF (RObject (VVar (translateTemplateIdGlobal tid) Nothing) quals)

        TS.SingletonF st val -> return $ case translateStdType st of
            Just sbt -> AnyRigidNodeF (RObject (VSingleton sbt val) quals)
            Nothing  -> conflictNode

        -- Integer literals (such as array sizes produced by 'nodeToTypeInfo')
        -- become const S32 singletons. The whole spelling must parse; a
        -- partially parsed literal is a conflict, not a truncated value.
        TS.IntLitF l -> return $
            case parseCIntLiteral (TS.templateIdToText (C.lexemeText l)) of
                Just i  -> AnyRigidNodeF (RObject (VSingleton S32Ty i) (Quals True))
                Nothing -> conflictNode

        _ -> return conflictNode
  where
    -- Translates a function signature into a pointer target. The return type
    -- is translated before the arguments, which fixes the order in which
    -- node IDs are allocated.
    functionPtrTarget ret args = do
        retId <- translateReturnType ret
        argIds <- mapM translateType args
        return (TargetFunction argIds retId)

-- | Parses the spelling of a C integer literal. Accepted forms:
--
-- * decimal;
-- * hexadecimal, with a @0x@ or @0X@ prefix;
-- * octal, with a leading @0@.
--
-- Any trailing @u@/@U@/@l@/@L@ suffix characters are ignored. The result is
-- 'Nothing' unless the entire spelling is consumed. For example, @0x10@ yields
-- 16 (not 0), @010@ yields 8, and @08@ is rejected.
parseCIntLiteral :: Integral a => Text -> Maybe a
parseCIntLiteral lit
    | Just hex <- stripHexPrefix body = wholeParse (TR.hexadecimal hex)
    | Just ('0', octal) <- T.uncons body, not (T.null octal) = octalValue octal
    | otherwise = wholeParse (TR.decimal body)
  where
    -- C integer suffixes (u, l, ul, ll, ull in any case) carry no value.
    body = T.dropWhileEnd (`elem` ("uUlL" :: String)) lit

    stripHexPrefix t = case T.stripPrefix "0x" t of
        Nothing -> T.stripPrefix "0X" t
        found   -> found

    -- Only accept a parse that consumed every character.
    wholeParse (Right (n, rest)) | T.null rest = Just n
    wholeParse _                               = Nothing

    octalValue digits
        | T.all isOctDigit digits = Just (T.foldl' step 0 digits)
        | otherwise               = Nothing
      where
        isOctDigit c = c >= '0' && c <= '7'
        step acc c = acc * 8 + fromIntegral (fromEnum c - fromEnum '0')

-- | Translates a function return type. An unqualified @void@ return becomes
-- 'RetVoid'. Anything else is translated with 'translateType' and wrapped in
-- 'RetVal'.
translateReturnType :: TS.TypeInfo 'TS.Global -> State TranslatorState (ReturnType Word32)
translateReturnType = \case
    Fix (TS.BuiltinTypeF TS.VoidTy) -> pure RetVoid
    other                           -> fmap RetVal (translateType other)

-- | Maps a global-phase template ID from the type system onto the refined
-- 'TemplateId':
--
-- * Names are kept as they are.
-- * Numbered parameters keep their number and hint and are tagged 'PGlobal'.
-- * Anonymous templates collapse to a name: their hint, or @"ANON"@ if there
--   is none.
-- * Recursion markers become names of the form @REC<i>@.
translateTemplateIdGlobal :: TS.TemplateId 'TS.Global -> TemplateId
translateTemplateIdGlobal (TS.TIdName n)      = TIdName n
translateTemplateIdGlobal (TS.TIdParam i h)   = TIdParam PGlobal (fromIntegral i) h
translateTemplateIdGlobal (TS.TIdAnonymous h) = TIdName (fromMaybe "ANON" h)
translateTemplateIdGlobal (TS.TIdRec i)       = TIdName (T.append "REC" (T.pack (show i)))

-- | Maps a type-system builtin onto the refined 'StdType' of the same name.
-- @void@ has no refined builtin counterpart and yields 'Nothing'.
translateStdType :: TS.StdType -> Maybe StdType
translateStdType TS.VoidTy    = Nothing
translateStdType TS.BoolTy    = Just BoolTy
translateStdType TS.CharTy    = Just CharTy
translateStdType TS.U08Ty     = Just U08Ty
translateStdType TS.S08Ty     = Just S08Ty
translateStdType TS.U16Ty     = Just U16Ty
translateStdType TS.S16Ty     = Just S16Ty
translateStdType TS.U32Ty     = Just U32Ty
translateStdType TS.S32Ty     = Just S32Ty
translateStdType TS.U64Ty     = Just U64Ty
translateStdType TS.S64Ty     = Just S64Ty
translateStdType TS.SizeTy    = Just SizeTy
translateStdType TS.F32Ty     = Just F32Ty
translateStdType TS.F64Ty     = Just F64Ty
translateStdType TS.NullPtrTy = Just NullPtrTy
