{-# LANGUAGE DeriveFoldable    #-}
{-# LANGUAGE DeriveFunctor     #-}
{-# LANGUAGE DeriveGeneric     #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE LambdaCase        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
module Language.Cimple.Analysis.TypeSystem.Transition
    ( Polarity (..)
    , ProductState (..)
    , RigidNodeF (..)
    , ValueStructure (..)
    , SpecialNode (..)
    , stepTransition
    , toRigid
    , fromRigid
    ) where

import           Data.Fix                                          (Fix (..))
import           Data.Functor                                      (void)
import           Data.Set                                          (Set)
import qualified Data.Set                                          as Set
import           Data.Text                                         (Text)
import qualified Data.Text                                         as Text
import qualified Debug.Trace                                       as Debug
import           GHC.Generics                                      (Generic)
import           Language.Cimple                                   (Lexeme (..))
import           Language.Cimple.Analysis.TypeSystem               (FlatType (..),
                                                                    FullTemplateF (..),
                                                                    Qualifier (..),
                                                                    StdType (..),
                                                                    TemplateId (..),
                                                                    TypeInfo,
                                                                    TypeInfoF (..),
                                                                    TypeRef (..),
                                                                    isInt,
                                                                    toFlat)
import           Language.Cimple.Analysis.TypeSystem.Qualification (Constness (..),
                                                                    Nullability (..),
                                                                    Ownership (..),
                                                                    QualState (..),
                                                                    allowCovariance,
                                                                    fromQuals,
                                                                    stepQual,
                                                                    toQuals)
import           Test.QuickCheck                                   (Arbitrary (..),
                                                                    arbitraryBoundedEnum,
                                                                    genericShrink,
                                                                    oneof)

-- | Master switch for the trace output produced by 'dtrace'.
debugging :: Bool
debugging = False

-- | Returns its second argument unchanged; when 'debugging' is on, the
-- message is additionally emitted through 'Debug.trace'.
dtrace :: String -> a -> a
dtrace msg x
    | debugging = Debug.trace msg x
    | otherwise = x

-- | Polarity of the lattice operation: 'PJoin' computes the join (upper
-- bound), 'PMeet' computes the meet (lower bound).
--
-- The constructor order matters: it determines the derived 'Ord', 'Bounded'
-- and 'Enum' instances.
data Polarity = PJoin | PMeet deriving (Show, Eq, Ord, Generic, Bounded, Enum)

-- | The state of the product automaton.
--
-- One value of this type accompanies every pair of nodes handled by
-- 'stepTransition', and a (possibly updated) copy is attached to each pair of
-- children in the result.
data ProductState = ProductState
    { psPolarity   :: Polarity
      -- ^ Whether a join or a meet is being computed.
    , psQualL      :: QualState
      -- ^ Qualifier state tracked for the left operand.
    , psQualR      :: QualState
      -- ^ Qualifier state tracked for the right operand.
    , psForceConst :: Bool
      -- ^ When set, the constness of the merged node is forced to @QConst'@.
    } deriving (Show, Eq, Ord, Generic)

-- | A canonicalized type node with attributes.
-- Enforces correct-by-construction property: attributes only where valid.
--
-- The type parameter @a@ is the representation of child nodes; @tid@ is the
-- identifier type carried by lexemes.
data RigidNodeF tid a
    -- | Function: return type, parameter types, constness and the optional
    -- lexeme that 'fromRigid' wraps back into a @SizedF@.
    = RFunction a [a] Constness (Maybe (Lexeme tid)) -- Ret type 'a' must not be another RFunction
    -- | Any non-function value: its structure, constness and the optional
    -- lexeme that 'fromRigid' wraps back into a @SizedF@.
    | RValue (ValueStructure tid a) Constness (Maybe (Lexeme tid))
    -- | A terminal node without attributes.
    | RSpecial SpecialNode
    deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable)

-- | The shape of a non-function value node.
--
-- Each constructor mirrors one @TypeInfoF@ constructor (see
-- 'toValueStructure'); only 'VPointer' and 'VTemplate' carry nullability and
-- ownership attributes.
data ValueStructure tid a
    = VBuiltin StdType                                  -- ^ Mirrors @BuiltinTypeF@.
    | VPointer a Nullability Ownership                  -- ^ Mirrors @PointerF@, plus pointer qualifiers.
    | VTemplate (FullTemplateF tid a) Nullability Ownership -- ^ Mirrors @TemplateF@, plus pointer qualifiers.
    | VTypeRef TypeRef (Lexeme tid) [a]                 -- ^ Mirrors @TypeRefF@.
    | VArray (Maybe a) [a]                              -- ^ Mirrors @ArrayF@: optional element type and a list of children (presumably dimensions).
    | VSingleton StdType Integer                        -- ^ Mirrors @SingletonF@.
    | VExternal (Lexeme tid)                            -- ^ Mirrors @ExternalTypeF@.
    | VIntLit (Lexeme tid)                              -- ^ Mirrors @IntLitF@.
    | VNameLit (Lexeme tid)                             -- ^ Mirrors @NameLitF@.
    | VEnumMem (Lexeme tid)                             -- ^ Mirrors @EnumMemF@.
    | VVarArg                                           -- ^ Mirrors @VarArgF@.
    deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable)

-- | Terminal nodes of the lattice.
--
-- In 'step', 'SUnconstrained' is the unit of join and absorbing under meet,
-- while 'SConflict' is absorbing under join and the unit of meet.
data SpecialNode = SUnconstrained | SConflict
    deriving (Show, Eq, Ord, Generic, Bounded, Enum)

-- | Generates a polarity through 'arbitraryBoundedEnum'; shrinking is
-- delegated to 'genericShrink'.
instance Arbitrary Polarity where
    arbitrary = arbitraryBoundedEnum
    shrink = genericShrink

-- | Generates each of the four fields independently with its own 'Arbitrary'
-- instance; shrinking is delegated to 'genericShrink'.
instance Arbitrary ProductState where
    arbitrary = ProductState <$> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary
    shrink = genericShrink

-- | Generates a terminal node through 'arbitraryBoundedEnum'; shrinking is
-- delegated to 'genericShrink'.
instance Arbitrary SpecialNode where
    arbitrary = arbitraryBoundedEnum
    shrink = genericShrink

-- | Chooses one of the three constructors with 'oneof' and fills every field
-- from its own 'Arbitrary' instance.
--
-- The generator places no restriction on the return-type child of
-- 'RFunction'; because @a@ is abstract here, it cannot inspect that child.
instance (Arbitrary tid, Arbitrary a) => Arbitrary (RigidNodeF tid a) where
    arbitrary = oneof
        [ RFunction <$> arbitrary <*> arbitrary <*> arbitrary <*> arbitrary
        , RValue <$> arbitrary <*> arbitrary <*> arbitrary
        , RSpecial <$> arbitrary
        ]
    shrink = genericShrink

-- | Chooses one of the 'ValueStructure' constructors with 'oneof' and fills
-- every field from its own 'Arbitrary' instance.
--
-- Every constructor is listed, including the field-less 'VVarArg', so that
-- properties quantified over 'ValueStructure' exercise all shapes the
-- transition function can encounter.  Shrinking is delegated to
-- 'genericShrink'.
instance (Arbitrary tid, Arbitrary a) => Arbitrary (ValueStructure tid a) where
    arbitrary = oneof
        [ VBuiltin <$> arbitrary
        , VPointer <$> arbitrary <*> arbitrary <*> arbitrary
        , VTemplate <$> arbitrary <*> arbitrary <*> arbitrary
        , VTypeRef <$> arbitrary <*> arbitrary <*> arbitrary
        , VArray <$> arbitrary <*> arbitrary
        , VSingleton <$> arbitrary <*> arbitrary
        , VExternal <$> arbitrary
        , VIntLit <$> arbitrary
        , VNameLit <$> arbitrary
        , VEnumMem <$> arbitrary
        -- Nullary constructor: nothing to generate, but it must still be
        -- reachable by the generator.
        , pure VVarArg
        ]
    shrink = genericShrink


-- | Projects the outermost layer of a 'TypeInfo' into a rigid node.
--
-- The type is flattened with 'toFlat' and its qualifiers are decoded with
-- 'toQuals'.  Unconstrained and conflict types become special nodes,
-- unsupported types are treated as conflicts, functions keep only their
-- constness and size, and everything else goes through 'toValueStructure'
-- (yielding 'Nothing' when that function does not recognise the structure).
toRigid :: TypeInfo p -> Maybe (RigidNodeF (TemplateId p) (TypeInfo p))
toRigid ty = case shape of
    UnconstrainedF       -> special SUnconstrained
    ConflictF            -> special SConflict
    UnsupportedF _       -> special SConflict
    FunctionF ret params -> Just (RFunction ret params constQ sz)
    other                -> (\v -> RValue v constQ sz) <$> toValueStructure other nullQ ownQ
  where
    FlatType shape flatQuals sz = toFlat ty
    (nullQ, ownQ, constQ) = toQuals flatQuals
    special = Just . RSpecial

-- | Converts one 'TypeInfoF' layer into the matching 'ValueStructure'.
--
-- The supplied nullability and ownership are attached to pointers and
-- templates only; every other shape ignores them.  Layers without a value
-- counterpart yield 'Nothing'.
toValueStructure :: TypeInfoF tid a -> Nullability -> Ownership -> Maybe (ValueStructure tid a)
toValueStructure (PointerF target)       n o = Just (VPointer target n o)
toValueStructure (TemplateF tpl)         n o = Just (VTemplate tpl n o)
toValueStructure (TypeRefF ref name as)  _ _ = Just (VTypeRef ref name as)
toValueStructure (BuiltinTypeF std)      _ _ = Just (VBuiltin std)
toValueStructure (ExternalTypeF name)    _ _ = Just (VExternal name)
toValueStructure (ArrayF elemTy dims)    _ _ = Just (VArray elemTy dims)
toValueStructure (SingletonF std val)    _ _ = Just (VSingleton std val)
toValueStructure (IntLitF lit)           _ _ = Just (VIntLit lit)
toValueStructure (NameLitF lit)          _ _ = Just (VNameLit lit)
toValueStructure (EnumMemF member)       _ _ = Just (VEnumMem member)
toValueStructure VarArgF                 _ _ = Just VVarArg
toValueStructure _                       _ _ = Nothing

-- | Rebuilds a 'TypeInfo' from a rigid node, converting every child with the
-- supplied function.
fromRigid :: (a -> TypeInfo p) -> RigidNodeF (TemplateId p) a -> TypeInfo p
fromRigid conv (RFunction ret params constQ sz) = fromValueNode' conv ret params constQ sz
fromRigid conv (RValue structure constQ sz)     = fromValueNode conv structure constQ sz
fromRigid _    (RSpecial terminal)              = fromSpecialNode terminal

-- | Rebuilds a function type from the parts of an 'RFunction' node.
--
-- The function is wrapped in @QualifiedF@ only when the constness produces a
-- non-empty qualifier set (nullability and ownership are fixed to
-- @QUnspecified@ / @QNonOwned'@), and in @SizedF@ only when a size lexeme is
-- present.
fromValueNode' :: (a -> TypeInfo p) -> a -> [a] -> Constness -> Maybe (Lexeme (TemplateId p)) -> TypeInfo p
fromValueNode' f r ps c s = attachSize (attachQuals fnType)
  where
    fnType = Fix (FunctionF (f r) (map f ps))
    quals = fromQuals QUnspecified QNonOwned' c
    attachQuals t
        | Set.null quals = t
        | otherwise      = Fix (QualifiedF quals t)
    attachSize t = case s of
        Nothing -> t
        Just sz -> Fix (SizedF t sz)

-- | Rebuilds a value type from the parts of an 'RValue' node.
--
-- The structure is converted by 'fromValueStructure'; the qualifiers it
-- reports are combined with the node's constness, and the result is wrapped
-- in @QualifiedF@ (if the qualifier set is non-empty) and then in @SizedF@
-- (if a size lexeme is present).
fromValueNode :: (a -> TypeInfo p) -> ValueStructure (TemplateId p) a -> Constness -> Maybe (Lexeme (TemplateId p)) -> TypeInfo p
fromValueNode f v c s = sized
  where
    (core, nullQ, ownQ) = fromValueStructure f v
    quals = fromQuals nullQ ownQ c
    qualified
        | Set.null quals = core
        | otherwise      = Fix (QualifiedF quals core)
    sized = case s of
        Nothing -> qualified
        Just sz -> Fix (SizedF qualified sz)

-- | Converts a 'ValueStructure' back into an unqualified 'TypeInfo' together
-- with the nullability and ownership that belong to it.
--
-- Only pointers and templates carry their own qualifiers; all other shapes
-- report @QUnspecified@ / @QNonOwned'@.
fromValueStructure :: (a -> TypeInfo p) -> ValueStructure (TemplateId p) a -> (TypeInfo p, Nullability, Ownership)
fromValueStructure f structure = case structure of
    VPointer target n o -> (Fix (PointerF (f target)), n, o)
    VTemplate tpl n o   -> (Fix (TemplateF (f <$> tpl)), n, o)
    VBuiltin std        -> plain (BuiltinTypeF std)
    VTypeRef ref nm as  -> plain (TypeRefF ref nm (f <$> as))
    VArray elemTy dims  -> plain (ArrayF (f <$> elemTy) (f <$> dims))
    VSingleton std val  -> plain (SingletonF std val)
    VExternal nm        -> plain (ExternalTypeF nm)
    VIntLit lit         -> plain (IntLitF lit)
    VNameLit lit        -> plain (NameLitF lit)
    VEnumMem member     -> plain (EnumMemF member)
    VVarArg             -> plain VarArgF
  where
    -- Shapes without pointer qualifiers share the same default attributes.
    plain layer = (Fix layer, QUnspecified, QNonOwned')

-- | Maps a terminal node back to the corresponding 'TypeInfo'.
fromSpecialNode :: SpecialNode -> TypeInfo p
fromSpecialNode SUnconstrained = Fix UnconstrainedF
fromSpecialNode SConflict      = Fix ConflictF
-- | Transition function of the product automaton: merges two rigid nodes
-- under the given state and pairs up their children for further processing.
--
-- This is 'step' wrapped in a 'dtrace' call that reports the state, both
-- input nodes and the result (with children erased).
stepTransition :: (Eq a, Show a)
               => ProductState
               -> (a -> Maybe (RigidNodeF (TemplateId p) a)) -- ^ Rigid node lookup
               -> (a -> (Nullability, Ownership, Constness)) -- ^ Lookup quals for children
               -> (a, a) -- ^ (bot, top)
               -> RigidNodeF (TemplateId p) a
               -> RigidNodeF (TemplateId p) a
               -> RigidNodeF (TemplateId p) (a, a, ProductState)
stepTransition ps lookupNode getQuals terminals nL nR = dtrace traceLine outcome
  where
    outcome = step ps lookupNode getQuals terminals nL nR
    traceLine = concat
        [ "stepTransition: ps=", show ps
        , " nL=", show (void nL)
        , " nR=", show (void nR)
        , " -> res=", show (void outcome)
        ]

-- | Merges two rigid nodes for one step under the polarity of the given
-- state.
--
-- Children of the result are @(left, right, state)@ triples: the pair of
-- operand children together with the 'ProductState' attached to that pair.
-- Cases are tried in order: special nodes, value against value, function
-- against function, and finally 'stepMismatched' for everything else.
step :: (Eq a, Show a)
     => ProductState
     -> (a -> Maybe (RigidNodeF (TemplateId p) a))
     -> (a -> (Nullability, Ownership, Constness))
     -> (a, a)
     -> RigidNodeF (TemplateId p) a
     -> RigidNodeF (TemplateId p) a
     -> RigidNodeF (TemplateId p) (a, a, ProductState)
step ps@ProductState{..} lookupNode getQuals terminals nL nR =
    case (nL, nR) of
        -- 1. Atomic Merge (Units and Zeros)
        -- Unconstrained is the unit of join: the other node is kept and each
        -- of its children is paired with the first terminal.  Under meet it
        -- absorbs the other node.
        (RSpecial SUnconstrained, _) -> case psPolarity of
            PJoin -> fmap (\r -> (fst terminals, r, ps { psQualL = QualTop })) nR
            PMeet -> RSpecial SUnconstrained
        (_, RSpecial SUnconstrained) -> case psPolarity of
            PJoin -> fmap (\l -> (l, fst terminals, ps { psQualR = QualTop })) nL
            PMeet -> RSpecial SUnconstrained

        -- Conflict behaves the other way round: absorbing under join, unit
        -- of meet (children are paired with the second terminal).
        (RSpecial SConflict, _) -> case psPolarity of
            PJoin -> RSpecial SConflict
            PMeet -> fmap (\r -> (snd terminals, r, ps { psQualL = QualTop })) nR
        (_, RSpecial SConflict) -> case psPolarity of
            PJoin -> RSpecial SConflict
            PMeet -> fmap (\l -> (l, snd terminals, ps { psQualR = QualTop })) nL

        -- 2. Value vs Value
        (RValue vL cL sL, RValue vR cR sR) ->
            case stepValueStructure ps lookupNode getQuals terminals cL cR vL vR of
                -- The nullability/ownership returned alongside the structure
                -- are not used here.
                Just (resV, _, _) ->
                    -- Constness: 'max' for join, 'min' for meet (by the Ord
                    -- instance of Constness), overridden by psForceConst.
                    let resC = case psPolarity of
                            PJoin -> max cL cR
                            PMeet -> min cL cR
                        resC' = if psForceConst then QConst' else resC
                        -- The size is kept only when both sides agree on it.
                        resS = if sL == sR then sL else Nothing

                        -- Differing constness is rejected in an invariant
                        -- context, except when either side is in a Level1
                        -- qualifier state.
                        invariance = not (psForceConst || (allowCovariance psQualL && allowCovariance psQualR))
                        isLevel1 = case psQualL of { QualLevel1Const -> True; QualLevel1Mutable -> True; _ -> False }
                                || case psQualR of { QualLevel1Const -> True; QualLevel1Mutable -> True; _ -> False }
                        qualConflict = invariance && not isLevel1 && cL /= cR
                    in if qualConflict then zero ps
                       else RValue resV resC' resS
                -- Structures that cannot be merged directly get a second
                -- chance through the cross-structure rules.
                Nothing -> stepMismatched ps lookupNode getQuals terminals cL cR nL nR
        -- 3. Function vs Function
        (RFunction rL pL cL sL, RFunction rR pR cR sR) ->
            -- Functions with different parameter counts never merge.
            if length pL /= length pR then zero ps
            else
                let resC = case psPolarity of
                        PJoin -> max cL cR
                        PMeet -> min cL cR
                    resC' = if psForceConst then QConst' else resC
                    resS = if sL == sR then sL else Nothing

                    -- Unlike the value case there is no Level1 exemption.
                    invariance = not (psForceConst || (allowCovariance psQualL && allowCovariance psQualR))
                    qualConflict = invariance && cL /= cR

                    -- The return type continues with reset qualifier states
                    -- and the same polarity; parameters use the flipped
                    -- polarity.
                    psRes = ps { psQualL = QualTop, psQualR = QualTop, psForceConst = False }
                    psContra = psRes { psPolarity = flipPol psPolarity }
                in if qualConflict then zero ps
                   else RFunction (rL, rR, psRes) (zipWith (\l r -> (l, r, psContra)) pL pR) resC' resS

        -- 4. Mismatched constructors (Cross-joins, etc.)
        -- Both constness arguments are passed as QMutable' here.
        (sL, sR) -> stepMismatched ps lookupNode getQuals terminals QMutable' QMutable' sL sR

-- | Cross-structure merge rules, used by 'step' when two nodes could not be
-- merged structurally.
--
-- Recognised pairs are pointer against array (either order) and a nullptr
-- value against a pointer or an array (either order); every other
-- combination yields 'zero'.  The two 'Constness' arguments are the
-- constness of the left and right node as supplied by the caller; the
-- qualifier-lookup argument is ignored.
stepMismatched :: (Eq a, Show a)
               => ProductState
               -> (a -> Maybe (RigidNodeF (TemplateId p) a))
               -> (a -> (Nullability, Ownership, Constness))
               -> (a, a)
               -> Constness -> Constness
               -> RigidNodeF (TemplateId p) a
               -> RigidNodeF (TemplateId p) a
               -> RigidNodeF (TemplateId p) (a, a, ProductState)
stepMismatched ps@ProductState{..} lookupNode _ terminals@(bot, _) cL cR nL nR =
    -- Same constness-conflict rule as the value/value case in 'step'.
    let invariance = not (psForceConst || (allowCovariance psQualL && allowCovariance psQualR))
        isLevel1 = case psQualL of { QualLevel1Const -> True; QualLevel1Mutable -> True; _ -> False }
                || case psQualR of { QualLevel1Const -> True; QualLevel1Mutable -> True; _ -> False }
        qualConflict = invariance && not isLevel1 && cL /= cR
    in case (nL, nR) of
        -- Pointer vs array: the join is a pointer, the meet is an array that
        -- keeps the array side's remaining children.  The array side counts
        -- as QUnspecified / QNonOwned' when combining pointer qualifiers.
        (RValue (VPointer tL nullL oL) _ sL, RValue (VArray (Just tR) dsR) _ sR) ->
            let (resState, canJoin) = getTargetState ps lookupNode terminals cL cR tL tR
                resN = case psPolarity of { PJoin -> max nullL QUnspecified; PMeet -> min nullL QUnspecified }
                resO = case psPolarity of { PJoin -> max oL QNonOwned'; PMeet -> min oL QNonOwned' }
                resC = case psPolarity of { PJoin -> max cL cR; PMeet -> min cL cR }
                resC' = if psForceConst then QConst' else resC
                resS = if sL == sR then sL else Nothing
            in if canJoin && not qualConflict then case psPolarity of
                PJoin -> RValue (VPointer (tL, tR, resState) resN resO) resC' resS
                -- NOTE(review): the array children are paired with @bot@ even
                -- though the polarity is PMeet, whereas the PMeet branch of
                -- 'stepValueStructure' pairs missing children with @top@.
                -- Confirm which terminal is intended here.
                PMeet -> RValue (VArray (Just (tL, tR, resState)) (map (\r -> (bot, r, ps { psQualL = QualTop, psQualR = QualTop })) dsR)) resC' resS
            else zero ps
        -- Mirror image of the previous case.
        (RValue (VArray (Just tL) dsL) _ sL, RValue (VPointer tR nullR oR) _ sR) ->
            let (resState, canJoin) = getTargetState ps lookupNode terminals cL cR tL tR
                resN = case psPolarity of { PJoin -> max QUnspecified nullR; PMeet -> min QUnspecified nullR }
                resO = case psPolarity of { PJoin -> max QNonOwned' oR; PMeet -> min QNonOwned' oR }
                resC = case psPolarity of { PJoin -> max cL cR; PMeet -> min cL cR }
                resC' = if psForceConst then QConst' else resC
                resS = if sL == sR then sL else Nothing
            in if canJoin && not qualConflict then case psPolarity of
                PJoin -> RValue (VPointer (tL, tR, resState) resN resO) resC' resS
                -- NOTE(review): same @bot@ / @top@ question as above.
                PMeet -> RValue (VArray (Just (tL, tR, resState)) (map (\l -> (l, bot, ps { psQualL = QualTop, psQualR = QualTop })) dsL)) resC' resS
            else zero ps

        -- nullptr_t vs Pointer/Array
        -- Join keeps the pointer/array side (its target is paired with
        -- @bot@), meet keeps the nullptr side.  Both are rejected in an
        -- invariant, non-Level1 context, and the size of the result is
        -- always dropped.  The 'canJoin' flag of 'getTargetState' is not
        -- consulted in these cases.
        (RValue vL _ _, RValue (VPointer tR nullR oR) _ _) | isNull vL ->
            case psPolarity of
                PJoin -> if invariance && not isLevel1 then zero ps
                         else let (resState, _) = getTargetState ps lookupNode terminals cL cR bot tR
                              in RValue (VPointer (bot, tR, resState) nullR oR) cR Nothing
                PMeet -> if invariance && not isLevel1 then zero ps
                         else RValue (fmap (\x -> (x, x, ps)) vL) cL Nothing
        (RValue (VPointer tL nullL oL) _ _, RValue vR _ _) | isNull vR ->
            case psPolarity of
                PJoin -> if invariance && not isLevel1 then zero ps
                         else let (resState, _) = getTargetState ps lookupNode terminals cL cR tL bot
                              in RValue (VPointer (tL, bot, resState) nullL oL) cL Nothing
                PMeet -> if invariance && not isLevel1 then zero ps
                         else RValue (fmap (\x -> (x, x, ps)) vR) cR Nothing

        (RValue vL _ _, RValue (VArray (Just tR) dsR) _ _) | isNull vL ->
            case psPolarity of
                PJoin -> if invariance && not isLevel1 then zero ps
                         else let (resState, _) = getTargetState ps lookupNode terminals cL cR bot tR
                              in RValue (VArray (Just (bot, tR, resState)) (map (\r -> (bot, r, ps { psQualL = QualTop, psQualR = QualTop })) dsR)) cR Nothing
                PMeet -> if invariance && not isLevel1 then zero ps
                         else RValue (fmap (\x -> (x, x, ps)) vL) cL Nothing
        (RValue (VArray (Just tL) dsL) _ _, RValue vR _ _) | isNull vR ->
            case psPolarity of
                PJoin -> if invariance && not isLevel1 then zero ps
                         else let (resState, _) = getTargetState ps lookupNode terminals cL cR tL bot
                              in RValue (VArray (Just (tL, bot, resState)) (map (\l -> (l, bot, ps { psQualL = QualTop, psQualR = QualTop })) dsL)) cL Nothing
                PMeet -> if invariance && not isLevel1 then zero ps
                         else RValue (fmap (\x -> (x, x, ps)) vR) cR Nothing

        -- Anything else cannot be reconciled.
        _ -> zero ps

-- | True for the two value shapes built on 'NullPtrTy': the builtin itself
-- and any singleton of it.
isNull :: ValueStructure tid a -> Bool
isNull = \case
    VBuiltin NullPtrTy     -> True
    VSingleton NullPtrTy _ -> True
    _                      -> False

-- | Merges two value structures of compatible shape.
--
-- Returns 'Nothing' when the pair cannot be merged by these rules; 'step'
-- then falls back to 'stepMismatched'.  The 'Nullability' and 'Ownership'
-- components of a successful result are always @QUnspecified@ and
-- @QNonOwned'@.  The two 'Constness' arguments are only forwarded to
-- 'getTargetState'.
--
-- Note that an alternative whose guards all fail falls through to the
-- alternatives below it, ending at the final catch-all.
stepValueStructure :: (Eq a, Show a)
                  => ProductState
                  -> (a -> Maybe (RigidNodeF (TemplateId p) a))
                  -> (a -> (Nullability, Ownership, Constness))
                  -> (a, a)
                  -> Constness -> Constness
                  -> ValueStructure (TemplateId p) a
                  -> ValueStructure (TemplateId p) a
                  -> Maybe (ValueStructure (TemplateId p) (a, a, ProductState), Nullability, Ownership)
stepValueStructure ps lookupNode getQuals terminals@(_, top) cL cR sL sR =
    case (sL, sR) of
        -- Builtins: identical types merge trivially; two distinct integer
        -- types merge to the greater (join) or lesser (meet) one by the Ord
        -- instance of StdType, but only outside an invariant context.
        (VBuiltin b1, VBuiltin b2)
            | b1 == b2 -> Just (VBuiltin b1, QUnspecified, QNonOwned')
            | isInt b1 && isInt b2 ->
                let m = case psPolarity ps of
                             PJoin -> if b1 > b2 then b1 else b2
                             PMeet -> if b1 < b2 then b1 else b2
                    invariance = not (psForceConst ps || (allowCovariance (psQualL ps) && allowCovariance (psQualR ps)))
                in if invariance && (b1 /= b2) then Nothing
                   else Just (VBuiltin m, QUnspecified, QNonOwned')

        -- Singleton vs builtin is delegated to 'mergeSingleton', which
        -- expects the singleton first.  For the reversed order the qualifier
        -- states are swapped before the call and swapped back in the result.
        (VSingleton b2 v2, VBuiltin b1) -> mergeSingleton ps getQuals b2 v2 b1
        (VBuiltin b1, VSingleton b2 v2) ->
            case mergeSingleton ps { psQualL = psQualR ps, psQualR = psQualL ps } getQuals b2 v2 b1 of
                Just (res, n, o) -> Just (fmap (\(r', l', p) -> (l', r', p { psQualL = psQualR p, psQualR = psQualL p })) res, n, o)
                Nothing -> Nothing

        (VSingleton b1 v1, VSingleton b2 v2)
            | b1 == b2 && v1 == v2 -> Just (VSingleton b1 v1, QUnspecified, QNonOwned')
            | isInt b1 && isInt b2 ->
                let invariance = not (psForceConst ps || (allowCovariance (psQualL ps) && allowCovariance (psQualR ps)))
                in case psPolarity ps of
                    -- Join: equal values stay a singleton of the greater
                    -- type; different values widen to the builtin type.
                    PJoin ->
                        let m = if b1 > b2 then b1 else b2
                        in if invariance && b1 /= b2 then Nothing
                           else if v1 == v2 then Just (VSingleton m v1, QUnspecified, QNonOwned')
                           else if invariance && b1 == b2 then Nothing
                           else Just (VBuiltin m, QUnspecified, QNonOwned')
                    -- Meet: only singletons with the same value have one.
                    PMeet ->
                        if v1 == v2 then
                            let m = if b1 < b2 then b1 else b2
                            in if invariance && b1 /= b2 then Nothing
                               else Just (VSingleton m v1, QUnspecified, QNonOwned')
                        else Nothing
            -- Non-integer singletons of the same type join to that type; in
            -- an invariant context this is allowed for NullPtrTy only.
            | psPolarity ps == PJoin && b1 == b2 ->
                let invariance = not (psForceConst ps || (allowCovariance (psQualL ps) && allowCovariance (psQualR ps)))
                in if invariance && b1 /= NullPtrTy then Nothing
                   else Just (VBuiltin b1, QUnspecified, QNonOwned')
            | otherwise -> Nothing

        -- Pointers: the targets are paired under the state computed by
        -- 'getTargetState'; nullability and ownership use max (join) or min
        -- (meet).
        (VPointer tL nL oL, VPointer tR nR oR) ->
            let (resState, canJoin) = getTargetState ps lookupNode terminals cL cR tL tR
                resN = case psPolarity ps of { PJoin -> max nL nR; PMeet -> min nL nR }
                resO = case psPolarity ps of { PJoin -> max oL oR; PMeet -> min oL oR }
            in if canJoin then Just (VPointer (tL, tR, resState) resN resO, QUnspecified, QNonOwned')
               else Nothing

        -- Arrays with an element type on both sides.
        (VArray (Just tL) dsL, VArray (Just tR) dsR) ->
            let (resState, canJoin) = getTargetState ps lookupNode terminals cL cR tL tR
            in if not canJoin then Nothing
               else case psPolarity ps of
                -- Join: child lists of different length are dropped.
                PJoin ->
                    let resDs = if length dsL == length dsR
                                     then zipWith (\l r -> (l, r, ps { psQualL = QualTop, psQualR = QualTop })) dsL dsR
                                     else []
                    in Just (VArray (Just (tL, tR, resState)) resDs, QUnspecified, QNonOwned')
                -- Meet: an empty child list adopts the other side's children
                -- (paired with @top@); non-empty lists of different length
                -- have no meet.
                PMeet ->
                    let resDs = if null dsL then map (\r -> (top, r, ps { psQualL = QualTop, psQualR = QualTop })) dsR
                                else if null dsR then map (\l -> (l, top, ps { psQualL = QualTop, psQualR = QualTop })) dsL
                                else if length dsL == length dsR
                                then zipWith (\l r -> (l, r, ps { psQualL = QualTop, psQualR = QualTop })) dsL dsR
                                else []
                    in if null dsL || null dsR || length dsL == length dsR
                       then Just (VArray (Just (tL, tR, resState)) resDs, QUnspecified, QNonOwned')
                       else Nothing

        -- Structures that are identical once their children are erased:
        -- children are paired positionally and psForceConst is cleared.
        (l, r) | void l == void r ->
            Just (fmap (\(a, b) -> (a, b, ps { psForceConst = False })) (zipValueStructures l r), QUnspecified, QNonOwned')

        _ -> Nothing

-- | Merges a singleton (type @b1@, value @v1@) with a builtin type @b2@.
--
-- The two types must be equal or both integer types.  A join widens to the
-- greater builtin type; a meet narrows to a singleton of the lesser type.
-- Either can be vetoed by the qualifier states, except when both types are
-- 'NullPtrTy'.  The qualifier-lookup argument is ignored.
mergeSingleton :: ProductState
               -> (a -> (Nullability, Ownership, Constness))
               -> StdType -> Integer -> StdType
               -> Maybe (ValueStructure tid (a, a, ProductState), Nullability, Ownership)
mergeSingleton ProductState{..} _ b1 v1 b2
    | not compatible = Nothing
    | otherwise = case psPolarity of
        PJoin
            | joinBlocked -> Nothing
            | otherwise   -> plain (VBuiltin wider)
        PMeet
            | meetBlocked -> Nothing
            | otherwise   -> plain (VSingleton narrower v1)
  where
    compatible = b1 == b2 || (isInt b1 && isInt b2)
    bothNull = b1 == NullPtrTy && b2 == NullPtrTy

    wider = if b1 > b2 then b1 else b2
    narrower = if b1 < b2 then b1 else b2

    -- Join considers both qualifier states and the force-const flag.
    bothCovariant = allowCovariance psQualL && allowCovariance psQualR
    joinBlocked = not (psForceConst || bothCovariant) && not bothNull

    -- Meet only looks at the right-hand qualifier state, and only objects
    -- when the types actually differ.
    meetBlocked = not (allowCovariance psQualR) && not bothNull && b1 /= b2

    plain v = Just (v, QUnspecified, QNonOwned')

-- | Pairs up the children of two value structures built with the same
-- constructor, keeping the non-child attributes of the first one.
--
-- Calls 'error' when the constructors differ; the caller is expected to have
-- checked that the shapes agree.
zipValueStructures :: ValueStructure tid a -> ValueStructure tid b -> ValueStructure tid (a, b)
zipValueStructures lhs rhs = case (lhs, rhs) of
    (VBuiltin std, VBuiltin _)            -> VBuiltin std
    (VPointer x n o, VPointer y _ _)      -> VPointer (x, y) n o
    (VTemplate tplL n o, VTemplate tplR _ _) -> VTemplate (zipFT tplL tplR) n o
    (VTypeRef ref nm xs, VTypeRef _ _ ys) -> VTypeRef ref nm (zip xs ys)
    (VArray elL dimsL, VArray elR dimsR)  -> VArray (zipWithMaybe (,) elL elR) (zip dimsL dimsR)
    (VSingleton std val, VSingleton _ _)  -> VSingleton std val
    (VExternal nm, VExternal _)           -> VExternal nm
    (VIntLit lit, VIntLit _)              -> VIntLit lit
    (VNameLit lit, VNameLit _)            -> VNameLit lit
    (VEnumMem member, VEnumMem _)         -> VEnumMem member
    (VVarArg, VVarArg)                    -> VVarArg
    _                                     -> error "zipValueStructures: mismatch"

-- | Pairs the optional payloads of two templates, keeping the identifier of
-- the first; the payload is present only when both sides have one.
zipFT :: FullTemplateF tid a -> FullTemplateF tid b -> FullTemplateF tid (a, b)
zipFT left right = case (left, right) of
    (FT tid argL, FT _ argR) -> FT tid (zipWithMaybe (,) argL argR)

-- | Combines two optional values with the given function; the result is
-- 'Nothing' as soon as either input is.
zipWithMaybe :: (a -> b -> c) -> Maybe a -> Maybe b -> Maybe c
zipWithMaybe f ma mb = f <$> ma <*> mb

-- | Swaps join and meet.
flipPol :: Polarity -> Polarity
flipPol = \case
    PJoin -> PMeet
    PMeet -> PJoin


-- | The absorbing element for the polarity of the given state: a conflict
-- for joins, an unconstrained node for meets.
zero :: ProductState -> RigidNodeF tid (a, b, ProductState)
zero ps
    | psPolarity ps == PJoin = RSpecial SConflict
    | otherwise              = RSpecial SUnconstrained

-- | Computes the state under which the targets @tL@ and @tR@ of two
-- pointer-like nodes are merged, plus a flag saying whether that merge is
-- permitted at all.
--
-- @cL@ and @cR@ are the constness of the two enclosing nodes.  The returned
-- state keeps the polarity, advances both qualifier states with 'stepQual',
-- and sets its force-const flag as described below.
getTargetState :: (Eq a, Show a)
               => ProductState
               -> (a -> Maybe (RigidNodeF tid a)) -- ^ Rigid node lookup
               -> (a, a)
               -> Constness -> Constness
               -> a -> a
               -> (ProductState, Bool)
getTargetState ProductState{..} lookupNode (bot, top) cL cR tL tR =
    -- Constness of the enclosing merged node, computed as in 'step'.
    let resC = case psPolarity of { PJoin -> max cL cR; PMeet -> min cL cR }
        resC' = if psForceConst then QConst' else resC

        -- Qualifier states after stepping through that node.
        nextL_base = stepQual psQualL (resC' == QConst')
        nextR_base = stepQual psQualR (resC' == QConst')
        invariance_base = not (allowCovariance nextL_base && allowCovariance nextR_base)

        -- A target counts as the join identity when it is @bot@, resolves to
        -- an unconstrained node or an array without element type, or is a
        -- pointer/array whose target is itself an identity.
        -- NOTE(review): this follows pointer/array links through lookupNode
        -- recursively; presumably such chains are finite — confirm.
        isIdentity t = t == bot || case lookupNode t of
            Just (RValue (VPointer t' _ _) _ _)    -> isIdentity t'
            Just (RValue (VArray (Just t') _) _ _) -> isIdentity t'
            Just (RValue (VArray Nothing _) _ _)   -> True
            Just (RSpecial SUnconstrained)         -> True
            _                                      -> False

        -- The meet identity is @top@ or anything resolving to a conflict.
        isTop t = t == top || case lookupNode t of
            Just (RSpecial SConflict) -> True
            _                         -> False

        isIdL = case psPolarity of { PJoin -> isIdentity tL; PMeet -> isTop tL }
        isIdR = case psPolarity of { PJoin -> isIdentity tR; PMeet -> isTop tR }

        -- Sound LUB discovery: force const only if targets differ and we are in an invariant context.
        -- Do not force if one side is the lattice identity.
        forceConst = psPolarity == PJoin && not (tL == tR) && invariance_base && not (isIdL || isIdR)

        -- When const is forced, the qualifier states are stepped as if the
        -- enclosing node were const.
        nextL = if forceConst then stepQual psQualL True else nextL_base
        nextR = if forceConst then stepQual psQualR True else nextR_base

        -- Meets are always permitted; joins need equal targets, covariance
        -- on at least one side, forced const, or an identity on one side.
        canJoin = psPolarity == PMeet || tL == tR || allowCovariance nextL || allowCovariance nextR || forceConst || isIdL || isIdR
    in (ProductState psPolarity nextL nextR forceConst, canJoin)

