{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Util.UnionSolve(
    C(),
    solve,
    Fixable(..),
    Topped(..),
    Result(..),
    cAnnotate,
    islte,isgte,equals,
    (@<=),(@>=),(@=),(@<=@),(@>=@),(@=@)
    ) where

import Control.Monad(unless, forM_)
import Data.List(intersperse)
import Data.Monoid
import qualified Data.Foldable as S
import qualified Data.Map as Map
import qualified Data.Sequence as S
import qualified Data.Set as Set

import Util.UnionFind as UF

-- simple constraint solver based on ideas from 'Once upon a polymorphic type' paper.

-- A lattice whose elements can be used as bounds by the solver.
-- 'join', 'meet' and 'lte' have no defaults and must be supplied by every
-- instance; everything else falls back to the defaults given below.
class Fixable a where
    -- lattice operators
    join, meet :: a -> a -> a
    -- ordering and equality on lattice elements
    lte, eq :: a -> a -> Bool
    -- tell whether a value sits at the bottom or the top of the lattice;
    -- bounds can be solidified once an endpoint is reached.
    isBottom, isTop :: a -> Bool
    -- used for debugging
    showFixable :: a -> String

    -- defaults: no known endpoints, equality derived from the ordering,
    -- and a one-letter rendering ("B", "T" or "*").
    isBottom = const False
    isTop = const False
    eq a b = a `lte` b && b `lte` a
    showFixable v
        | isBottom v = "B"
        | isTop v = "T"
        | otherwise = "*"

-- A set of constraints. The type arguments are the lattice 'l' and the
-- variable type 'v': a value of (C l v) holds a sequence of constraints that
-- confine variables of type 'v' to particular values of 'l'.
-- Constraint sets are mappended together, e.g. when used in a writer monad.
-- NOTE(review): only Monoid is derived; compilers where Monoid has a
-- Semigroup superclass may also need a Semigroup instance -- confirm against
-- the compiler version this module targets.
newtype C l v
    = C (S.Seq (CL l v))
    deriving (Monoid)

-- The relation a constraint asserts between its left and right operand.
data Op
    = OpLte
    | OpEq
    | OpGte

{-
flipOp OpLte = OpGte
flipOp OpGte = OpLte
flipOp OpEq = OpEq
-}

-- Operators are rendered with surrounding spaces so they can be placed
-- directly between two shown operands.
instance Show Op where
    show op = case op of
        OpLte -> " <= "
        OpEq  -> " = "
        OpGte -> " >= "

-- A single constraint over lattice 'l' and variables 'v':
--   CV          relates two variables,
--   CL          relates a variable (left) to a lattice value (right),
--   CLAnnotate  wraps another constraint together with an annotation string.
data CL l v
    = CV v Op v
    | CL v Op l
    | CLAnnotate String (CL l v)

-- Wrap every constraint of the set in a 'CLAnnotate' carrying the given
-- string.
cAnnotate :: String -> C l v -> C l v
cAnnotate note (C constraints) = C $ fmap tag constraints
  where
    tag = CLAnnotate note

-- A constraint set is shown one constraint per line, followed by a final
-- newline.
instance (Show e,Show l) => Show (C l e) where
    showsPrec _ (C xs) = body (S.toList xs) . showChar '\n'
      where
        -- constraints separated (not terminated) by newlines
        body [] = id
        body [c] = shows c
        body (c:cs) = shows c . showChar '\n' . body cs

-- A constraint is shown as "<left><op><right>"; annotations are not shown,
-- only the constraint they wrap.
instance (Show e,Show l) => Show (CL l e) where
    showsPrec _ (CV lhs op rhs) = shows lhs . shows op . shows rhs
    showsPrec _ (CL var op lit) = shows var . shows op . shows lit
    showsPrec _ (CLAnnotate _ inner) = shows inner

-- Pick the first argument when the condition is True, the second otherwise.
-- Note the argument order: the True branch comes first.
bool :: a -> a -> Bool -> a
bool onTrue onFalse cond
    | cond = onTrue
    | otherwise = onFalse

-- operator constraints, the @ is on the side that takes a variable.
-- Operators with a single @ relate a variable (left) to a lattice value
-- (right); operators with an @ on both sides relate two variables.
-- Each one builds a constraint set holding exactly that one constraint.
(@<=), (@>=), (@=) :: v -> l -> C l v
var @<= lit = cL var OpLte lit
var @>= lit = cL var OpGte lit
var @=  lit = cL var OpEq lit

(@<=@), (@>=@), (@=@) :: v -> v -> C l v
lhs @<=@ rhs = cV lhs OpLte rhs
lhs @>=@ rhs = cV lhs OpGte rhs
lhs @=@  rhs = cV lhs OpEq rhs

-- One-element constraint set relating a variable to a lattice value.
cL :: v -> Op -> l -> C l v
cL var op lit = C . S.singleton $ CL var op lit

-- One-element constraint set relating two variables.
cV :: v -> Op -> v -> C l v
cV lhs op rhs = C . S.singleton $ CV lhs op rhs

-- basic constraints
-- Each operand is either a variable (Left) or a lattice value (Right).
-- With at least one variable a single constraint is produced, always with a
-- variable on the left (so a value-on-the-left comparison is flipped).
-- With two values the relation is checked on the spot: it yields the empty
-- constraint set when it holds and calls 'error' when it does not.
islte,isgte,equals :: (Fixable l,Ord v) => Either v l -> Either v l -> C l v
islte lhs rhs = case (lhs, rhs) of
    (Left a,  Left b)  -> C . S.singleton $ CV a OpLte b
    (Left a,  Right k) -> C . S.singleton $ CL a OpLte k
    (Right k, Left b)  -> C . S.singleton $ CL b OpGte k
    (Right j, Right k)
        | j `lte` k -> mempty
        | otherwise -> error $ "invalid constraint: " ++ showFixable j ++ " <= " ++ showFixable k

isgte lhs rhs = case (lhs, rhs) of
    (Left a,  Left b)  -> C . S.singleton $ CV a OpGte b
    (Left a,  Right k) -> C . S.singleton $ CL a OpGte k
    (Right k, Left b)  -> C . S.singleton $ CL b OpLte k
    (Right j, Right k)
        | k `lte` j -> mempty
        | otherwise -> error $ "invalid constraint: " ++ showFixable j ++ " >= " ++ showFixable k

equals lhs rhs = case (lhs, rhs) of
    (Left a,  Left b)  -> C . S.singleton $ CV a OpEq b
    (Left a,  Right k) -> C . S.singleton $ CL a OpEq k
    (Right k, Left b)  -> C . S.singleton $ CL b OpEq k
    (Right j, Right k)
        | j `eq` k -> mempty
        | otherwise -> error $ "invalid constraint: " ++ showFixable j ++ " = " ++ showFixable k

-- The solver's state for a variable: it is either set to a value ('R') or
-- still bounded by other values ('Ri').
data R l a
    = R l
    | Ri
        (Maybe l)            -- lower bound value, if any
        (Set.Set (RS l a))   -- variables that are lower bounds
        (Maybe l)            -- upper bound value, if any
        (Set.Set (RS l a))   -- variables that are upper bounds
    deriving (Show)

-- A union-find element carrying a variable of type 'a' and its solver state.
type RS l a = Element (R l a) a

-- replace variables with UnionFind elements
-- Every distinct variable gets one element, created on first sight with no
-- bounds; the returned map records which element belongs to which variable.
-- Annotations are kept around the converted constraint.
-- The converted constraints come back in the reverse of their order in the
-- input sequence, since each one is consed onto the accumulator.
prepareConstraints :: Ord v => C l v -> IO ([CL l (RS l v)], Map.Map v (RS l v))
prepareConstraints (C cseq) = go Map.empty [] (S.toList cseq) where
    go env acc [] = return (acc, env)
    go env acc (c:rest) = do
        (c', env') <- convert env c
        go env' (c':acc) rest
    -- convert one constraint, threading the variable map through
    convert env (CLAnnotate s inner) = do
        (inner', env') <- convert env inner
        return (CLAnnotate s inner', env')
    convert env (CL x op l) = do
        (x', env') <- intern x env
        return (CL x' op l, env')
    convert env (CV x op y) = do
        (x', env1) <- intern x env
        (y', env2) <- intern y env1
        return (CV x' op y', env2)
    -- look a variable up, allocating a fresh element when it is new
    intern x env = case Map.lookup x env of
        Just e -> return (e, env)
        Nothing -> do
            e <- UF.new (Ri Nothing mempty Nothing mempty) x
            return (e, Map.insert x e env)

-- Test whether two lattice values stand in the relation named by the
-- operator (left operand first).
check :: Fixable l => Op -> l -> l -> Bool
check OpEq  a b = a `eq` b
check OpLte a b = a `lte` b
check OpGte a b = b `lte` a

-- Solve a set of constraints.
--
-- 'putLog' receives the constraint annotations and progress messages;
-- 'csp' is the constraint set.
--
-- The result is a pair of maps. The first sends every variable that occurs
-- in the constraints to the variable stored in the union-find element that
-- 'find' returns for it (presumably the representative of its equivalence
-- class -- depends on Util.UnionFind). The second sends each of those
-- variables to the 'Result' computed for it.
--
-- Constraints that cannot be satisfied are reported with 'fail' in IO.
--
-- The prepared constraints are walked twice: first only the
-- variable-to-variable constraints are applied (procVar), then only the
-- variable-to-value constraints (procLit).
{-# NOINLINE solve #-}
solve :: (Fixable l, Show l, Show v, Ord v)
    => (String -> IO ())
    -> C l v
    -> IO (Map.Map v v,Map.Map v (Result l v))
solve putLog csp = do
    (pcs,varMap) <- prepareConstraints csp
    -- pass 1: variable-to-variable constraints. CL constraints are skipped
    -- here; an annotation directly around a CL is skipped without being
    -- logged, any other annotation is logged before descending.
    let procVar (CV x op y) = do
            xe <- UF.find x
            ye <- UF.find y
            doVar "" xe op ye
        procVar (CLAnnotate s CL {}) =  return ()
        procVar CL {} = return ()
        procVar (CLAnnotate s cr) =  putLog s >>  procVar cr
        -- relate two elements; 'lvl' is only a prefix for the log line.
        -- nothing to do when both sides are the same element.
        doVar _ xe _ ye | xe == ye = return ()
        doVar lvl xe op ye = do
            putLog $ lvl ++ "Constraining: " ++ show (fromElement xe) ++ show op ++ show (fromElement ye)
            xw <- UF.getW xe
            yw <- UF.getW ye
            case (xw,yw) of
                (Ri xml xlb xmu xub,Ri yml ylb ymu yub) -> do
                    -- pass every member of the bound sets through UF.find
                    -- again before using them
                    xub <- finds xub
                    xlb <- finds xlb
                    yub <- finds yub
                    ylb <- finds ylb
                    case op of
                        OpEq  ->  doEq lvl xe (Ri xml xlb xmu xub) ye (Ri yml ylb ymu yub)
                        OpLte -> doLte lvl xe (Ri xml xlb xmu xub) ye (Ri yml ylb ymu yub)
                        -- x >= y is handled as y <= x
                        OpGte -> doLte lvl ye (Ri yml ylb ymu yub) xe (Ri xml xlb xmu xub)
                -- anything other than two Ri states is reported as an error
                _ -> fail $ "UnionSolve: bad " ++  show (xw,yw)
        -- make two variables equal: union their elements, pool their lower
        -- and upper variable bounds, and store the result on the element
        -- 'find' returns afterwards. Both value bounds are written as Nothing.
        doEq lvl xe ~(Ri _ xlb _ xub) ye ~(Ri _ ylb _ yub) = do
            union const xe ye
            ne <- find xe
            nlb <- finds (xlb `Set.union` ylb)
            nub <- finds (yub `Set.union` xub)
            UF.putW ne (Ri Nothing nlb Nothing nub)
            checkRS lvl ne
        -- record xe <= ye.
        doLte lvl xe ~xw@(Ri xml xlb xmu xub) ye ~yw@(Ri yml ylb ymu yub) = do
            -- 'done' only writes the states it was handed back to both elements
            let done = UF.putW xe (Ri xml xlb xmu xub) >> UF.putW ye (Ri yml ylb ymu yub)
            -- the relation is already recorded on one of the two sides
            if ye `Set.member` xub then done else do
            if xe `Set.member` ylb then done else do
            -- the opposite relation is already recorded, so the two are equal
            if ye `Set.member` xlb then doEq lvl xe xw ye yw else do
            if xe `Set.member` yub then doEq lvl xe xw ye yw else do
            -- otherwise ye and ye's upper bounds go above xe, and xe and
            -- xe's lower bounds go below ye
            UF.putW xe (Ri xml xlb xmu (Set.insert ye (xub `Set.union` yub)))
            UF.putW ye (Ri yml (Set.insert xe (ylb `Set.union` xlb)) ymu yub)
            checkRS lvl xe
            ye <- find ye
            checkRS lvl ye
        -- tidy the bound sets of a variable: pass them through 'finds',
        -- store them back without the variable itself, and equate the
        -- variable with everything that is both a lower and an upper bound.
        checkRS lvl ve = do
            Ri l lb h ub <- UF.getW ve
            lb <- finds lb
            ub <- finds ub
            UF.putW ve (Ri l (Set.delete ve lb) h (Set.delete ve ub))
            let equiv = lb `Set.intersection` ub
            forM_ (Set.toList equiv) $ doVar ('#':lvl) ve OpEq
        -- apply UF.find to every member of a set of elements
        finds set = fmap Set.fromList $ mapM UF.find (Set.toList set)
    mapM_ procVar pcs

    -- pass 2: variable-to-value constraints. CV constraints are skipped,
    -- mirroring procVar above.
    let procLit (CL x op y) = do
            xe <- UF.find x
            doOp "" xe op y
        procLit (CLAnnotate s CV {}) =  return ()
        procLit CV {} = return ()
        procLit (CLAnnotate s cr) =  putLog s >>  procLit cr

        -- apply the constraint (ve op l) against the lattice value l.
        -- 'lvl' is a prefix for the log line; calls made through doOp'
        -- extend it with a '-'.
        doOp lvl ve op l = do
            let doOp' ve op l = doOp ('-':lvl) ve op l
            putLog $ lvl ++ "Constraining: " ++ show (fromElement ve) ++ show op ++ show l
            vw <- getW ve
            case (op,vw) of
                -- the variable already has a value: the constraint is only checked
                (_,R c) | check op c l -> return ()
                        | otherwise -> fail $ "UnionSolve: constraint doesn't match (" ++ show c ++ show op ++ show l ++ ") when setting " ++ show (fromElement ve)
                -- equality with a value inside the current value bounds:
                -- fix the variable, then require its lower-bound variables
                -- to be <= l and its upper-bound variables to be >= l
                (OpEq,Ri ml lb mu ub) | testBoundLT ml l && testBoundGT mu l -> do
                    updateW (const (R l)) ve
                    mapM_ (\v -> doOp' v OpLte l) (Set.toList lb)
                    mapM_ (\v -> doOp' v OpGte l) (Set.toList ub)
                (OpEq,_) | otherwise -> fail $ "UnionSolve: setValue " ++ show (fromElement ve,vw,l)
                -- the existing bound on the same side is already at least as tight
                (OpLte,Ri _ _ (Just n) _) | n `lte` l -> return ()
                (OpGte,Ri (Just n) _ _ _) | l `lte` n -> return ()
                -- the new bound equals the bound on the opposite side: treat as equality
                (OpLte,Ri (Just n) _ _ _) | n `eq` l -> doOp' ve OpEq l
                (OpGte,Ri _ _ (Just n) _) | n `eq` l -> doOp' ve OpEq l
                -- the new bound lies beyond the bound on the opposite side
                -- (the equal case was taken above): report an error
                (OpLte,Ri (Just n) _ _ _) | l `lte` n -> fail $ "UnionSolve: lower than lower bound  " ++ show (fromElement ve,vw,l,n)
                (OpGte,Ri _ _ (Just n) _) | n `lte` l -> fail $ "UnionSolve: higher than higher bound  " ++ show (fromElement ve,vw,l,n)
                -- tighten the upper value bound to the meet of l and the old
                -- bound; when that changed it, pass the new bound on to the
                -- lower-bound variables
                (OpLte,Ri ml lb mu ub) -> do
                    let nv@(Just l') = mmeet (Just l) mu
                    doUpdate (Ri ml lb nv ub) ve
                    unless (nv `eq` mu) $
                        mapM_ (\v -> doOp' v OpLte l') (Set.toList lb)
                -- tighten the lower value bound to the join of l and the old
                -- bound; when that changed it, pass the new bound on to the
                -- upper-bound variables
                (OpGte,Ri ml lb mu ub) -> do
                    let nv@(Just l') = (mjoin (Just l) ml)
                    doUpdate (Ri nv lb mu ub) ve
                    unless (nv `eq` ml) $
                        mapM_ (\v -> doOp' v OpGte l') (Set.toList ub)
                -- _ -> fail $ "UnionSolve: bad " ++  show (fromElement ve,vw,op,l)
        -- is y at or above the optional lower bound? (no bound: always)
        testBoundLT Nothing _ = True
        testBoundLT (Just x) y = x `lte` y
        -- is y at or below the optional upper bound? (no bound: always)
        testBoundGT Nothing _ = True
        testBoundGT (Just x) y = y `lte` x
        -- inspect a freshly written state. This checkRS is a different
        -- function from the one used in pass 1 (it shadows it) and takes
        -- the state first, the element second.
        -- lower and upper value bound coincide: the value is determined
        checkRS (Ri (Just l) _ (Just u) _) xe | l `eq` u = do
            putLog $ "Boxed in value of " ++ show (fromElement xe) ++ " being set to " ++ show l
            doOp "&" xe OpEq l
        -- upper bound at or below the lower bound without being equal: error
        checkRS (Ri (Just l) _ (Just u) _) xe | u `lte` l = fail "checkRS: you crossed the streams"
        -- the lower bound is the top of the lattice: the variable is set to it
        checkRS (Ri (Just l) _ _ _) xe  | isTop l = do
            putLog $ "Going up:   " ++ show (fromElement xe)
            doOp "&" xe OpEq l
        -- the upper bound is the bottom of the lattice: the variable is set to it
        checkRS (Ri  _ _ (Just u) _) xe | isBottom u = do
            putLog $ "Going down: " ++ show (fromElement xe)
            doOp "&" xe OpEq u
        checkRS r xe = return ()
        -- store a new state on the element, then inspect it with checkRS
        doUpdate r xe = do
            updateW (const r) xe
            checkRS r xe
        -- join / meet on optional bounds, where Nothing means "no bound"
        mjoin Nothing b = b
        mjoin x Nothing = x
        mjoin (Just x) (Just y) = Just (join x y)
        mmeet Nothing b = b
        mmeet x Nothing = x
        mmeet (Just x) (Just y) = Just (meet x y)
    mapM_ procLit pcs
    -- collect the results: for each original variable 'a', look up its
    -- element via 'find' and turn the state found there into a 'Result'
    rs <- flip mapM (Map.toList varMap) $ \ (a,e) -> do
        e <- find e
        w <- getW e
        rr <- case w of
            R v -> return (ResultJust (fromElement e) v)
            Ri ml lb mu ub -> do
                -- report the bounding variables by the variable stored in
                -- the element 'finds' yields for each of them
                ub <- fmap (map fromElement . Set.toList) $ finds ub
                lb <- fmap (map fromElement . Set.toList) $ finds lb
                return (ResultBounded { resultRep = fromElement e, resultUB = mu, resultLB = ml, resultLBV = lb, resultUBV = ub })
        let aa = fromElement e
        -- first component feeds the variable map, second the result map
        return ((a,aa),(aa,rr))
    let (ma,mb) = unzip rs
    return (Map.fromList ma,Map.fromList mb)

-----------------------------------------------------------
-- The data type the results of the analysis are placed in.
-----------------------------------------------------------
-- Outcome of the analysis for one variable: either a definite value, or the
-- bounds that remain on it.
data Result l a
    = ResultJust
        { resultRep :: a       -- variable the result is reported under
        , resultValue :: l     -- the value the variable was set to
        }
    | ResultBounded
        { resultRep :: a       -- variable the result is reported under
        , resultLB :: Maybe l  -- lower bound value, if any
        , resultUB :: Maybe l  -- upper bound value, if any
        , resultLBV :: [a]     -- variables that are lower bounds
        , resultUBV :: [a]     -- variables that are upper bounds
        }

-- Results are shown through 'showResult'; the precedence is ignored.
instance (Show l, Show a) => Show (Result l a) where
    showsPrec _ res rest = showResult res ++ rest

-- Render a result: "var = value" for a definite value, otherwise
-- "lower <= var <= upper". Each side shows the bound value (if any)
-- followed by the list of bounding variables (if non-empty), or "_" when
-- there is neither.
showResult :: (Show l, Show a) => Result l a -> String
showResult (ResultJust rep val) = show rep ++ " = " ++ show val
showResult ResultBounded { resultRep = rep, resultLB = lo, resultUB = hi, resultLBV = loVars, resultUBV = hiVars } =
    side lo loVars ++ " <= " ++ show rep ++ " <= " ++ side hi hiVars
  where
    side bound vars = case (bound, vars) of
        (Nothing, []) -> "_"
        (Just x, [])  -> show x
        (Nothing, _)  -> show vars
        (Just x, _)   -> show x ++ show vars

-------------------------------
-- useful instances for Fixable
-------------------------------

-- Sets ordered by inclusion: the empty set is the bottom, union is the join
-- and intersection is the meet. No top element is declared.
instance Ord n => Fixable (Set.Set n)  where
    lte = Set.isSubsetOf
    eq a b = a == b
    join = Set.union
    meet = Set.intersection
    isBottom s = Set.null s

-- Booleans with False below True: False is the bottom, True is the top,
-- disjunction is the join and conjunction is the meet.
instance Fixable Bool where
    lte a b = a <= b
    eq a b = a == b
    join = (||)
    meet = (&&)
    isBottom = not
    isTop = id

-- Integers under their usual ordering: join is the maximum and meet is the
-- minimum, i.e. this is the lattice of maximum, not the additive one.
-- No bottom or top is declared.
instance Fixable Int where
    lte a b = a <= b
    eq a b = a == b
    join = max
    meet = min

-- Pairs are combined and compared component by component; a pair is at an
-- endpoint only when both of its components are.
instance (Fixable a,Fixable b) => Fixable (a,b) where
    join (a1,b1) (a2,b2) = (a1 `join` a2, b1 `join` b2)
    meet (a1,b1) (a2,b2) = (a1 `meet` a2, b1 `meet` b2)
    lte (a1,b1) (a2,b2) = a1 `lte` a2 && b1 `lte` b2
    eq (a1,b1) (a2,b2) = a1 `eq` a2 && b1 `eq` b2
    isBottom (a,b) = and [isBottom a, isBottom b]
    isTop (a,b) = and [isTop a, isTop b]

-- the Maybe instance adds Nothing as a new bottom below every (Just x).
-- note that (Just bottom) is a distinct point from Nothing.
instance Fixable a => Fixable (Maybe a) where
    isBottom = maybe True (const False)
    isTop = maybe False isTop
    -- Nothing is the identity for join ...
    join x y = case (x, y) of
        (Nothing, _) -> y
        (_, Nothing) -> x
        (Just a, Just b) -> Just (a `join` b)
    -- ... and absorbing for meet
    meet (Just a) (Just b) = Just (a `meet` b)
    meet _ _ = Nothing
    lte x y = case x of
        Nothing -> True
        Just a -> maybe False (a `lte`) y

-- 'Topped' adds a new top element ('Top') above every (Only x).
-- this is the opposite of the 'Maybe' instance, which adds a new bottom.
data Topped a
    = Only a
    | Top
    deriving (Eq, Ord, Show)

-- the Topped instance puts 'Top' above everything: it is the identity for
-- meet and absorbing for join. note that (Only top) is a distinct point
-- from 'Top'.
instance Fixable a => Fixable (Topped a) where
    isBottom t = case t of
        Only x -> isBottom x
        Top -> False
    isTop t = case t of
        Top -> True
        Only _ -> False
    meet x y = case (x, y) of
        (Top, _) -> y
        (_, Top) -> x
        (Only a, Only b) -> Only (a `meet` b)
    join (Only a) (Only b) = Only (a `join` b)
    join _ _ = Top
    eq x y = case (x, y) of
        (Top, Top) -> True
        (Only a, Only b) -> a `eq` b
        _ -> False
    lte _ Top = True
    lte x (Only b) = case x of
        Top -> False
        Only a -> a `lte` b