{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Agda.Utils.Warshall where

{- construct a graph from constraints

   x + n <= y   becomes   x ---(-n)---> y
   x <= n + y   becomes   x ---(+n)---> y

the default edge (= no edge) is labelled with infinity.

building the graph involves keeping track of the node names.
We do this in a finite map, assigning consecutive numbers to nodes.
-}

import Control.Applicative
import Control.Monad.State
import Data.Maybe -- fromJust
import Data.Array
import Data.List
import Data.Map (Map)
import qualified Data.Map as Map

import Test.QuickCheck
import Agda.Utils.TestHelpers
import Agda.Syntax.Common
import Agda.Utils.QuickCheck
import Agda.Utils.SemiRing

import Debug.Trace

type Matrix a = Array (Int,Int) a

-- | Warshall's algorithm (generalised transitive closure) over a
--   'SemiRing'.  For every pivot @k@ in turn (lowest index first), entry
--   @(i,j)@ is replaced by
--
--   > a!(i,j) `oplus` (a!(i,k) `otimes` a!(k,j))
--
--   Assumes a square matrix whose row and column index ranges coincide
--   (@r == c@ and @r' == c'@).
--
--   Every intermediate matrix is forced (each entry to weak head normal
--   form) before the next pivot is processed.  Without this, each entry
--   of the result is a thunk referring to the previous matrix, so all
--   intermediate matrices stay alive until the result is inspected.
warshall :: SemiRing a => Matrix a -> Matrix a
warshall a0 = loop r a0
  where
    b@((r,c),(r',c')) = bounds a0 -- assuming r == c and r' == c'
    loop k a
      | k <= r'   = let a' = step k a
                    in  foldr seq () (elems a') `seq` loop (k+1) a'
      | otherwise = a
    -- one relaxation round using node k as the intermediate node
    step k a = array b
      [ ((i,j), (a!(i,j)) `oplus` ((a!(i,k)) `otimes` (a!(k,j))))
      | i <- [r..r'], j <- [c..c'] ]

-- | A graph represented as an adjacency list.
type AdjList node edge = Map node [(node, edge)]

-- | Warshall's algorithm on a graph represented as an adjacency list.
--
--   Nodes are numbered in order of first appearance (map keys, then edge
--   targets).  Parallel edges @i -> j@ are combined with 'oplus'.  The
--   closure is computed on the matrix of @Maybe edge@ (no edge = 'Nothing').
--   The result lists every reachable target of each node, in node-number
--   order.
warshallG :: (SemiRing edge, Ord node) => AdjList node edge -> AdjList node edge
warshallG g = fromMatrix $ warshall m
  where
    nodes = zip (nub $ Map.keys g ++ map fst (concat $ Map.elems g))
                [0..]
    len   = length nodes
    b     = ((0,0), (len - 1,len - 1))
    -- index -> node label in O(1) (was a linear `nodes !! ix` per entry)
    nodeAt = listArray (0, len - 1) (map fst nodes)
    edge i j = do
      es <- Map.lookup i g
      foldr oplus Nothing [ Just v | (j', v) <- es, j == j' ]
    m = array b [ ((n, n'), edge i j) | (i, n) <- nodes, (j, n') <- nodes ]
    fromMatrix matrix = Map.fromList $ do
      (i, n) <- nodes
      let es = [ (nodeAt ! n', e)
               | n' <- [0..len - 1]
               , Just e <- [matrix ! (n, n')]
               ]
      return (i, es)

-- | Edge weight in the graph, forming a semi ring.
--   'Infinite' stands for "no edge".
data Weight = Finite Int | Infinite
  deriving (Eq)

-- | @inc w n@ adds @n@ to a finite weight; 'Infinite' is absorbing.
inc :: Weight -> Int -> Weight
inc Infinite   _ = Infinite
inc (Finite k) n = Finite (k + n)

instance Show Weight where
  show (Finite i) = show i
  show Infinite   = "."
-- | 'Infinite' is the greatest weight; finite weights compare by value.
instance Ord Weight where
  _        <= Infinite = True
  Infinite <= _        = False
  Finite a <= Finite b = a <= b

-- | (min, +) semiring: alternatives take the minimum, paths add up,
--   and 'Infinite' (no edge) is absorbing for 'otimes'.
instance SemiRing Weight where
  oplus = min

  otimes Infinite _ = Infinite
  otimes _ Infinite = Infinite
  otimes (Finite a) (Finite b) = Finite (a + b)

-- constraints ---------------------------------------------------

-- nodes of the graph are either
-- * flexible variables (with identifiers drawn from Int),
-- * rigid variables (also identified by Ints), or
-- * constants (like 0, infinity, or anything between)

data Node = Rigid Rigid
          | Flex  FlexId
            deriving (Eq, Ord)

data Rigid = RConst Weight
           | RVar RigidId
             deriving (Eq, Ord, Show)

type NodeId  = Int
type RigidId = Int
type FlexId  = Int
type Scope   = RigidId -> Bool -- which rigid variables a flex may be instantiated to

instance Show Node where
  show (Flex i)                    = "?" ++ show i
  show (Rigid (RVar i))            = "v" ++ show i
  show (Rigid (RConst Infinite))   = "#"
  show (Rigid (RConst (Finite n))) = show n

-- | Is this rigid node the constant infinity?
infinite :: Rigid -> Bool
infinite (RConst Infinite) = True
infinite _                 = False

-- | @isBelow r w r'@ checks whether @r@ and @r'@ are consistent given
--   the (closed) edge @r --(w)--> r'@.  If @w@ is infinite they are not
--   connected.  Otherwise it checks @r <= r' + w@, which is the edge
--   convention used throughout this module (@x --(n)--> y@ encodes
--   @x <= y + n@).
--   Precondition: not the same rigid variable.
isBelow :: Rigid -> Weight -> Rigid -> Bool
isBelow _ Infinite _ = True
isBelow _ _ (RConst Infinite) = True   -- everything is below infinity
-- isBelow (RConst Infinite) n (RConst (Finite _)) = False
isBelow (RConst (Finite i)) (Finite n) (RConst (Finite j)) = i <= j + n
isBelow _ _ _ = False -- rigid variables are not related

-- a constraint is an edge in the graph
data Constraint
  = NewFlex FlexId Scope
  | Arc Node Int Node
    -- Arc v1 k v2  at least one of v1,v2 is a VMeta (Flex),
    --              the other a VMeta or a VGen (Rigid)
    -- if k <= 0 this means  $^(-k) v1 <= v2
    -- otherwise                   v1 <= $^k v2

instance Show Constraint where
  show (NewFlex i _) = "SizeMeta(?" ++ show i ++ ")"
  show (Arc v1 k v2)
    | k == 0    = show v1 ++ "<=" ++ show v2
    | k < 0     = show v1 ++ "+" ++ show (-k) ++ "<=" ++ show v2
    | otherwise = show v1 ++ "<=" ++ show v2 ++ "+" ++ show k

type Constraints = [Constraint]

emptyConstraints :: Constraints
emptyConstraints = []

-- graph (matrix) ------------------------------------------------

data Graph = Graph
  { flexScope :: Map FlexId Scope            -- scope for each flexible var
  , nodeMap   :: Map Node NodeId             -- node labels to node numbers
  , intMap    :: Map NodeId Node             -- node numbers to node labels
  , nextNode  :: NodeId                      -- number of nodes (n)
  , graph     :: NodeId -> NodeId -> Weight  -- the edges (restrict to [0..n[)
  }

-- | The empty graph: no nodes, all edges absent (infinite weight).
initGraph :: Graph
initGraph = Graph Map.empty Map.empty Map.empty 0 (\ _ _ -> Infinite)

-- the Graph Monad, for constructing a graph iteratively
type GM = State Graph

-- | Record the scope of a flexible variable and make sure its node exists.
addFlex :: FlexId -> Scope -> GM ()
addFlex x scope = do
  modify $ \ st -> st { flexScope = Map.insert x scope (flexScope st) }
  _ <- addNode (Flex x)
  return ()

-- i <- addNode n   returns number of node n.
-- if not present, it is added first
addNode :: Node -> GM Int
addNode n = do
  st <- get
  case Map.lookup n (nodeMap st) of
    Just i  -> return i
    Nothing -> do
      let i = nextNode st
      put $ st { nodeMap  = Map.insert n i (nodeMap st)
               , intMap   = Map.insert i n (intMap st)
               , nextNode = i + 1
               }
      return i

-- | @addEdge n1 k n2@ improves the weight of edge @n1 -> n2@ to be at
--   most @k@ (combining with the existing weight via 'oplus' = min).
--   Also adds the nodes if not yet present.
--
--   The edge function is a chain of closures, so a lookup costs time
--   linear in the number of 'addEdge' calls so far.
addEdge :: Node -> Int -> Node -> GM ()
addEdge n1 k n2 = do
  i1 <- addNode n1
  i2 <- addNode n2
  st@Graph { graph = old } <- get
  -- Capture only the previous edge function.  Referring to `graph st`
  -- inside the closure would keep the whole previous state (and, through
  -- it, every earlier state) alive.
  let graph' x y
        | (x,y) == (i1,i2) = Finite k `oplus` old x y
        | otherwise        = old x y
  put $ st { graph = graph' }

-- | Add one constraint to the graph under construction.
addConstraint :: Constraint -> GM ()
addConstraint (NewFlex x scope) = addFlex x scope
addConstraint (Arc n1 k n2)     = addEdge n1 k n2

-- | Build the constraint graph, starting from 'initGraph'.
buildGraph :: Constraints -> Graph
buildGraph cs = execState (mapM_ addConstraint cs) initGraph

-- | Tabulate an edge function over nodes @0..n-1@ as an @n x n@ matrix.
mkMatrix :: Int -> (Int -> Int -> Weight) -> Matrix Weight
mkMatrix n g = array ((0,0),(n-1,n-1))
                 [ ((i,j), g i j) | i <- [0..n-1], j <- [0..n-1]]

-- displaying matrices with row and column labels --------------------

-- a matrix with row descriptions in b and column descriptions in c
data LegendMatrix a b c = LegendMatrix
  { matrix   :: Matrix a
  , rowdescr :: Int -> b
  , coldescr :: Int -> c
  }

-- | Header line of tab-separated column descriptions, then one line per
--   row: the row description followed by the tab-separated entries.
instance (Show a, Show b, Show c) => Show (LegendMatrix a b c) where
  show (LegendMatrix m rd cd) =
    -- first show column description
    let ((r,c),(r',c')) = bounds m
    in  foldr (\ j s -> "\t" ++ show (cd j) ++ s) "" [c .. c'] ++
        -- then output rows
        foldr (\ i s -> "\n" ++ show (rd i) ++
                 foldr (\ j t -> "\t" ++ show (m!(i,j)) ++ t)
                       s
                       [c .. c'])
              "" [r .. r']

-- solving the constraints -------------------------------------------

-- a solution assigns to each flexible variable a size expression
-- which is either a constant or a v + n for a rigid variable v

type Solution = Map Int SizeExpr

emptySolution :: Solution
emptySolution = Map.empty

-- | Bind flexible variable @k@ to @v@ (overwriting any previous binding).
extendSolution :: Solution -> Int -> SizeExpr -> Solution
extendSolution subst k v = Map.insert k v subst

data SizeExpr = SizeVar Int Int    -- e.g. x + 5
              | SizeConst Weight   -- a number or infinity

instance Show SizeExpr where
  show (SizeVar n 0) = show (Rigid (RVar n))
  show (SizeVar n k) = show (Rigid (RVar n)) ++ "+" ++ show k
  show (SizeConst w) = show w

-- | @sizeRigid r n@ returns the size expression corresponding to @r + n@.
sizeRigid :: Rigid -> Int -> SizeExpr
sizeRigid (RConst k) n = SizeConst (inc k n)
sizeRigid (RVar i)   n = SizeVar i n

{-
apply :: SizeExpr -> Solution -> SizeExpr
apply e@(SizeExpr (Rigid _) _) phi = e
apply e@(SizeExpr (Flex x) i) phi = case Map.lookup x phi of
  Nothing -> e
  Just (SizeExpr v j) -> SizeExpr v (i + j)

after :: Solution -> Solution -> Solution
after psi phi = Map.map (\ e -> e `apply` phi) psi
-}

{-
solve :: Constraints -> Maybe Solution
solve cs = if any (\ x -> x < Finite 0) d then Nothing
 else Map.
  where gr = buildGraph cs
        n  = nextNode gr
        m  = mkMatrix n (graph gr)
        m' = warshall m
        d  = [ m!(i,i) | i <- [0 .. (n-1)] ]
        ns = keys (nodeMap gr)
-}

{- compute solution

a solution CANNOT exist if

  v < v  for a rigid variable v

  v <= v' for rigid variables v,v'

  x < v   for a flexible variable x and a rigid variable v

thus, for each flexible x, only one of the following cases is possible

  r+n <= x+m <= infty  for a unique rigid r  (meaning r --(m-n)--> x)
  x <= r+n             for a unique rigid r  (meaning x --(n)--> r)

we are looking for the least values for flexible variables that solve
the constraints.
-}
{- Algorithm

while flexible variables and rigid rows left
  find a rigid variable row i
    for all flexible columns j
      if i --n--> j with n<=0 (meaning i+n <= j) then j = i + n

while flexible variables j left
  search the row j for entry i
    if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

(In the first loop "i --n--> j" names the bound i+n <= j, not the raw
matrix entry: the code reads an entry z and sets j = i + max 0 (-z).)
-}

-- | Solve a list of size constraints.
--
--   Builds the constraint graph ('buildGraph') and closes it with
--   'warshall'.  Returns 'Nothing' if the closed graph fails the
--   @solvable@ checks below.  Otherwise every flexible variable is
--   assigned a 'SizeExpr' in two phases:
--
--   * @loop1@ uses a rigid lower bound.
--   * @loop2@ uses a rigid upper bound, defaulting to infinity.
--
--   @loop2@ can still yield 'Nothing' when the first bound it finds is
--   unusable.
solve :: Constraints -> Maybe Solution
solve cs = -- trace (show cs) $
           -- trace (show lm0) $
           -- trace (show lm) $
           -- trace (show d) $
  let solution = if solvable then loop1 flexs rigids emptySolution
                 else Nothing
  in  -- trace (show solution) $
      solution
  where
    -- compute the graph and its transitive closure m
    gr  = buildGraph cs
    n   = nextNode gr            -- number of nodes
    m0  = mkMatrix n (graph gr)
    m   = warshall m0

    -- tracing only: build output version of transitive graph
    -- (lm0 and lm are referenced only by the commented-out traces above)
    legend i = fromJust $ Map.lookup i (intMap gr) -- trace only
    lm0 = LegendMatrix m0 legend legend            -- trace only
    lm  = LegendMatrix m legend legend             -- trace only

    -- compute the flexible variable ids and the rigid labels of all nodes.
    -- Both lists come out in reverse order of 'ns', because foldl conses
    -- onto the accumulator; loop1 visits the rigids in that order.
    ns  = Map.keys (nodeMap gr)
    -- a set of flexible variables
    flexs  = foldl (\ l k -> case k of
                               (Flex i)  -> i : l
                               (Rigid _) -> l) [] ns
    -- a set of rigid variables
    rigids = foldl (\ l k -> case k of
                               (Flex _)  -> l
                               (Rigid i) -> i : l) [] ns

    -- rigid matrix indices
    rInds = foldl (\ l r -> let Just i = Map.lookup (Rigid r) (nodeMap gr) in i : l) [] rigids

    -- check whether there is a solution
    -- d = [ m!(i,i) | i <- [0 .. (n-1)] ]  -- diagonal
    -- a rigid variable might not be less than itself, so no negative
    -- weights on the rigid part of the diagonal
    -- (a cycle r --(w)--> r encodes r <= r + w, which needs w >= 0)
    solvable = all (\ x -> x >= Finite 0) [ m!(i,i) | i <- rInds ] &&
      -- a rigid variable might not be bounded below by infinity or
      -- bounded above by a constant
      -- it might not be related to another rigid variable
      all (\ (r, r') -> r == r' ||
             let Just row = (Map.lookup (Rigid r)  (nodeMap gr))
                 Just col = (Map.lookup (Rigid r') (nodeMap gr))
                 edge = m!(row,col)
             in  isBelow r edge r' )
          [ (r,r') | r <- rigids, r' <- rigids ]
      &&
      -- a flexible variable might not be strictly below a rigid variable
      -- (x --(w)--> v encodes x <= v + w; w < 0 would force x < v)
      all (\ (x, v) ->
             let Just row = (Map.lookup (Flex x) (nodeMap gr))
                 Just col = (Map.lookup (Rigid (RVar v)) (nodeMap gr))
                 edge = m!(row,col)
             in  edge >= Finite 0)
          [ (x,v) | x <- flexs, (RVar v) <- rigids ]

    -- | May flexible variable x be instantiated to rigid r?
    --   Constants are always in scope.
    --   NOTE(review): the scope comes only from a NewFlex constraint
    --   (addFlex).  For a flex that appears only in Arc constraints, the
    --   lazy `Just scope` pattern fails at runtime when a rigid variable
    --   is checked -- confirm that callers always emit NewFlex first.
    inScope :: FlexId -> Rigid -> Bool
    inScope x (RConst _) = True
    inScope x (RVar v)   = scope v
        where Just scope = Map.lookup x (flexScope gr)

    {- loop1

       while flexible variables and rigid rows left
         find a rigid variable row i
           for all flexible columns j
             if i --n--> j with n<=0 (meaning i + n <= j) then j = i + n

       Each in-scope flex with a finite entry z in row r is assigned
       r + max 0 (-z).  Flexes without one are passed on to the next rigid,
       and finally to loop2.
       NOTE(review): the `z <= 0` guard is commented out, so positive
       entries are accepted too.  For z > 0 the edge only requires
       j >= r - z, yet j is set to r.  Presumably this is deliberate,
       since a SizeExpr cannot express r - z; confirm.
    -}
    loop1 :: [FlexId] -> [Rigid] -> Solution -> Maybe Solution
    loop1 [] rgds subst = Just subst
    loop1 flxs [] subst = loop2 flxs subst
    loop1 flxs (r:rgds) subst =
      let row = fromJust $ Map.lookup (Rigid r) (nodeMap gr)
          (flxs',subst') =
            foldl (\ (flx,sub) f ->
                     let col = fromJust $ Map.lookup (Flex f) (nodeMap gr)
                     in  case (inScope f r, m!(row,col)) of
                           -- Finite z | z <= 0 ->
                           (True, Finite z) ->
                             let trunc z | z >= 0    = 0
                                         | otherwise = -z
                             in  (flx, extendSolution sub f (sizeRigid r (trunc z)))
                           _ -> (f : flx, sub)
                  ) ([], subst) flxs
      in  loop1 flxs' rgds subst'

    {- loop2

       while flexible variables j left
         search the row j for entry i
           if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

       Columns are scanned from 0 upwards.  Non-rigid columns, the
       infinity constant and infinite entries are skipped.  The first
       finite entry to a rigid decides:
       * If it is in scope with n >= 0, j = i + n.
       * Otherwise the whole solve fails with Nothing.
       If no such entry exists, j defaults to infinity.
    -}
    loop2 :: [FlexId] -> Solution -> Maybe Solution
    loop2 [] subst = Just subst
    loop2 (f:flxs) subst = loop3 0 subst
      where row = fromJust $ Map.lookup (Flex f) (nodeMap gr)
            loop3 col subst | col >= n =
              -- default to infinity
              loop2 flxs (extendSolution subst f (SizeConst Infinite))
            loop3 col subst =
              case Map.lookup col (intMap gr) of
                Just (Rigid r) | not (infinite r) ->
                  case (inScope f r, m!(row,col)) of
                    (True, Finite z) | z >= 0 ->
                      loop2 flxs (extendSolution subst f (sizeRigid r z))
                    (_, Infinite) -> loop3 (col+1) subst
                    _ -> -- trace ("unusable rigid: " ++ show r ++ " for flex " ++ show f)
                      Nothing -- NOT: loop3 (col+1) subst
                _ -> loop3 (col+1) subst

-- Testing ----------------------------------------------------------------

-- | Random adjacency list over @nodes@.  For each source node and each
--   potential target, further edges are added with weight @k@ (out of
--   100), where @k = round (100 * density)@, so parallel edges may occur.
--   A node without outgoing edges is randomly kept (with @[]@) or omitted.
genGraph :: Ord node => Float -> Gen edge -> [node] -> Gen (AdjList node edge)
genGraph density edge nodes = do
    Map.fromList . concat <$> mapM neighbours nodes
  where
    k = round (100 * density)
    neighbours n = do
      ns <- concat <$> mapM neighbour nodes
      case ns of
        [] -> elements [[(n, [])], []]
        _  -> return [(n, ns)]
    neighbour n = frequency
      [ (k, do e <- edge
               ns <- neighbour n
               return ((n, e):ns))
      , (100 - k, return [])
      ]

-- | Natural-number distances, for testing 'warshallG'.
newtype Distance = Dist Nat
  deriving (Eq, Ord, Num, Integral, Show, Enum, Real)

-- | Shortest-distance semiring: alternatives take the minimum,
--   path segments add up.
instance SemiRing Distance where
  oplus  (Dist a) (Dist b) = Dist (min a b)
  otimes (Dist a) (Dist b) = Dist (a + b)

-- | Random graph over nodes @0 .. n-1@ with density 0.2.
genGraph_ :: Nat -> Gen (AdjList Nat Distance)
genGraph_ n = genGraph 0.2 (Dist <$> natural) [0..n - 1]

-- | The first edge @i -> j@ in the adjacency list, if any.
lookupEdge :: Ord n => n -> n -> AdjList n e -> Maybe e
lookupEdge i j g = lookup j =<< Map.lookup i g

-- | All edges as @(source, target, weight)@ triples.
edges :: Ord n => AdjList n e -> [(n,n,e)]
edges g = do
  (i, ns) <- Map.toList g
  (j, e)  <- ns
  return (i, j, e)

-- | Check that no edges get longer when completing a graph.
prop_smaller :: Nat -> Property
prop_smaller n' =
  forAll (genGraph_ n) $ \g ->
    let g' = warshallG g
    in  and [ lookupEdge i j g' =< e | (i, j, e) <- edges g ]
  where
    n = abs (div n' 2)
    -- an edge that disappeared counts as "longer"
    Nothing =< _ = False
    Just x  =< y = x <= y

-- | Add the edge @i --e--> j@ in front of the existing edges out of @i@.
newEdge :: Ord n => n -> m -> e -> Map n [(m, e)] -> Map n [(m, e)]
newEdge i j e = Map.insertWith (++) i [(j, e)]

-- | Add a random path from @i@ to @j@ (through random intermediate nodes
--   in @0 .. n-1@) to the graph @g@.
genPath :: Nat -> Nat -> Nat -> AdjList Nat Distance -> Gen (AdjList Nat Distance)
genPath n i j g = do
    hops <- listOf $ (,) <$> node <*> edge
    v    <- edge
    return $ addPath i (hops ++ [(j, v)]) g
  where
    edge = Dist <$> natural
    node = choose (0, n - 1)
    addPath _    []               acc = acc
    addPath from ((to, w) : rest) acc = newEdge from to w $ addPath to rest acc

-- | Check that all transitive edges are added.
prop_path :: Nat -> Property
prop_path n' =
  forAll (genGraph_ n) $ \g ->
  -- generate a pair rather than a 2-element list: total pattern
  forAll ((,) <$> node <*> node) $ \(i, j) ->
  forAll (genPath n i j g) $ \g' ->
    isJust (lookupEdge i j $ warshallG g')
  where
    n    = abs (div n' 2) + 1
    node = choose (0, n - 1)

-- | Rename the nodes of a graph (sources and targets) with @f@.
mapNodes :: (Ord node, Ord node') => (node -> node') -> AdjList node edge -> AdjList node' edge
mapNodes f = Map.map f' . Map.mapKeys f
  where
    f' es = [ (f n, e) | (n,e) <- es ]

-- | Check that no edges are added between components.
prop_disjoint :: Nat -> Property
prop_disjoint n' =
  forAll ((,) <$> genGraph_ n <*> genGraph_ n) $ \(g1, g2) ->
    let g  = Map.union (mapNodes Left g1) (mapNodes Right g2)
        g' = warshallG g
    in  all disjoint (Map.assocs g')
  where
    n = abs (div n' 3)
    disjoint (Left _,  es) = all (isLeft  . fst) es
    disjoint (Right _, es) = all (isRight . fst) es
    isLeft  = either (const True) (const False)
    isRight = not . isLeft

-- | Completing an already complete graph changes nothing.
prop_stable :: Nat -> Property
prop_stable n' =
  forAll (genGraph_ n) $ \g ->
    let g' = warshallG g
    in  g' =~= warshallG g'
  where
    n = abs (div n' 2)
    g =~= g' = sort (edges g) == sort (edges g')

tests :: IO Bool
tests = runTests "Agda.Utils.Warshall"
  [ quickCheck' prop_smaller
  , quickCheck' prop_path
  , quickCheck' prop_disjoint
  , quickCheck' prop_stable
  ]