-- | This module contains the code for generating "tags" for constraints
--   based on their source, i.e. the top-level binders under which the
--   constraints were generated. Fixpoint uses these tags to prioritize
--   constraints according to the source-level function they come from.

{-# LANGUAGE TupleSections #-}

module Language.Haskell.Liquid.UX.CTags (
    -- * Type for constraint tags
    TagKey, TagEnv

    -- * Default tag value
  , defaultTag

    -- * Constructing @TagEnv@
  , makeTagEnv

    -- * Accessing @TagEnv@
  , getTag
  , memTagEnv

) where

import Var
import CoreSyn
import Prelude hiding (error)

import qualified Data.HashSet           as S
import qualified Data.HashMap.Strict    as M
import qualified Data.Graph             as G

import Language.Fixpoint.Types          (Tag)
import Language.Haskell.Liquid.Types.Visitors (freeVars)
import Language.Haskell.Liquid.Types.PrettyPrint ()
import Language.Fixpoint.Misc     (mapSnd)

-- | The @TagKey@ is the top-level binder under which a constraint was
--   generated; the @Tag@ associated with it is a singleton @Int@ list.
type TagKey = Var

-- | Environment mapping each binder to its @Tag@ (built by 'makeTagEnv').
type TagEnv = M.HashMap TagKey Tag

-- NOTE: tags are numbered by the strongly-connected components of the
--       top-level call graph (see @callGraphRanks@ and @makeTagEnv@).

-- | Tag used for any binder that has no entry in a 'TagEnv'. Its single
--   rank, @0@, is below every rank assigned by 'makeTagEnv', whose
--   call-graph components are numbered starting from @1@.
defaultTag :: Tag
defaultTag = pure 0

-- | @memTagEnv key env@ holds exactly when @key@ has an entry in @env@,
--   i.e. when 'getTag' would not fall back to 'defaultTag'.
memTagEnv :: TagKey -> TagEnv -> Bool
memTagEnv key env = key `M.member` env

-- | Build the tag environment for a list of Core bindings. Every binder
--   bound by @cbs@ is tagged with the rank of its strongly-connected
--   component in the call graph (see 'callGraphRanks'), wrapped as a
--   singleton list.
--
--   (For debugging, the result can be wrapped with @tracepp "TAGENV"@.)
makeTagEnv :: [CoreBind] -> TagEnv
makeTagEnv cbs = M.map asTag ranks
  where
    ranks    = callGraphRanks (makeCallGraph cbs)
    asTag r  = [r]

-- makeTagEnv = M.fromList . (`zip` (map (:[]) [1..])). L.sort . map fst . concatMap bindEqns

-- | Look up the tag of a binder, returning 'defaultTag' when the binder
--   is not present in the environment.
getTag :: TagKey -> TagEnv -> Tag
getTag key env = maybe defaultTag id (M.lookup key env)

------------------------------------------------------------------------------------------------------

-- | Adjacency-list call graph as built by 'makeCallGraph': each binder is
--   paired with the binders (from the same binding list) that its
--   right-hand side references.
type CallGraph = [(Var, [Var])] -- caller-callee pairs

-- | Rank every binder of a call graph by its strongly-connected component.
--   Components are numbered from @1@ in the order returned by
--   'G.stronglyConnComp', which is reverse topological, so callees are
--   ranked before their callers. All binders in one component (e.g. a
--   group of mutually recursive functions) share the same rank.
callGraphRanks :: CallGraph -> M.HashMap Var Int
callGraphRanks cg = M.fromList
  [ (binder, rank)
  | (rank, component) <- zip [1 ..] components
  , binder            <- G.flattenSCC component
  ]
  where
    -- each binder is its own graph key; its callees are the out-edges
    components = G.stronglyConnComp
      [ (caller, caller, callees) | (caller, callees) <- cg ]

-- | Build the call graph of a list of Core bindings. Each binder is paired
--   with those free variables of its right-hand side that are themselves
--   bound somewhere in @cbs@; references to anything else are dropped.
makeCallGraph :: [CoreBind] -> CallGraph
makeCallGraph cbs = map (mapSnd callees) eqns
  where
    eqns    = concatMap bindEqns cbs
    -- binders defined by this binding list, i.e. the graph's vertices
    defined = S.fromList [ x | (x, _) <- eqns ]
    callees = filter (`S.member` defined) . freeVars S.empty

-- | Flatten a Core binding into its (binder, right-hand side) equations:
--   a single pair for 'NonRec', and every member of a 'Rec' group as-is.
bindEqns :: Bind t -> [(t, Expr t)]
bindEqns bind = case bind of
  NonRec x e -> [(x, e)]
  Rec eqns   -> eqns