{-# LANGUAGE GeneralizedNewtypeDeriving #-} {- | Construct a graph from constraints @ x + n <= y becomes x ---(-n)---> y x <= n + y becomes x ---(+n)---> y @ the default edge (= no edge) is labelled with infinity. Building the graph involves keeping track of the node names. We do this in a finite map, assigning consecutive numbers to nodes. -} module Agda.Utils.Warshall where import Control.Applicative import Control.Monad.State import Data.Maybe import Data.Array import Data.List import Data.Map (Map) import qualified Data.Map as Map import Test.QuickCheck import Agda.Syntax.Common (Nat) import Agda.Utils.TestHelpers import Agda.Utils.SemiRing type Matrix a = Array (Int,Int) a -- assuming a square matrix warshall :: SemiRing a => Matrix a -> Matrix a warshall a0 = loop r a0 where b@((r,c),(r',c')) = bounds a0 -- assuming r == c and r' == c' loop k a | k <= r' = loop (k+1) (array b [ ((i,j), (a!(i,j)) `oplus` ((a!(i,k)) `otimes` (a!(k,j)))) | i <- [r..r'], j <- [c..c'] ]) | otherwise = a type AdjList node edge = Map node [(node, edge)] -- | Warshall's algorithm on a graph represented as an adjacency list. warshallG :: (SemiRing edge, Ord node) => AdjList node edge -> AdjList node edge warshallG g = fromMatrix $ warshall m where nodes = zip (nub $ Map.keys g ++ map fst (concat $ Map.elems g)) [0..] len = length nodes b = ((0,0), (len - 1,len - 1)) edge i j = do es <- Map.lookup i g foldr oplus Nothing [ Just v | (j', v) <- es, j == j' ] m = array b [ ((n, m), edge i j) | (i, n) <- nodes, (j, m) <- nodes ] fromMatrix matrix = Map.fromList $ do (i, n) <- nodes let es = [ (fst (nodes !! m), e) | m <- [0..len - 1] , Just e <- [matrix ! (n, m)] ] return (i, es) -- | Edge weight in the graph, forming a semi ring. data Weight = Finite Int | Infinite deriving (Eq) inc :: Weight -> Int -> Weight inc Infinite n = Infinite inc (Finite k) n = Finite (k + n) instance Show Weight where show (Finite i) = show i show Infinite = "." instance Ord Weight where a <= Infinite = True Infinite <= b = False Finite a <= Finite b = a <= b instance SemiRing Weight where oplus = min otimes Infinite _ = Infinite otimes _ Infinite = Infinite otimes (Finite a) (Finite b) = Finite (a + b) -- constraints --------------------------------------------------- -- | Nodes of the graph are either -- - flexible variables (with identifiers drawn from @Int@), -- - rigid variables (also identified by @Int@s), or -- - constants (like 0, infinity, or anything between). data Node = Rigid Rigid | Flex FlexId deriving (Eq, Ord) data Rigid = RConst Weight | RVar RigidId deriving (Eq, Ord, Show) type NodeId = Int type RigidId = Int type FlexId = Int type Scope = RigidId -> Bool -- ^ Which rigid variables a flex may be instatiated to. instance Show Node where show (Flex i) = "?" ++ show i show (Rigid (RVar i)) = "v" ++ show i show (Rigid (RConst Infinite)) = "#" show (Rigid (RConst (Finite n))) = show n infinite :: Rigid -> Bool infinite (RConst Infinite) = True infinite _ = False -- | @isBelow r w r'@ -- checks, if @r@ and @r'@ are connected by @w@ (meaning @w@ not infinite), -- whether @r + w <= r'@. -- Precondition: not the same rigid variable. isBelow :: Rigid -> Weight -> Rigid -> Bool isBelow _ Infinite _ = True isBelow _ n (RConst Infinite) = True isBelow (RConst (Finite i)) (Finite n) (RConst (Finite j)) = i + n <= j isBelow _ _ _ = False -- rigid variables are not related -- | A constraint is an edge in the graph. data Constraint = NewFlex FlexId Scope | Arc Node Int Node -- ^ For @Arc v1 k v2@ at least one of @v1@ or @v2@ is a @MetaV@ (Flex), -- the other a @MetaV@ or a @Var@ (Rigid). -- If @k <= 0@ this means @suc^(-k) v1 <= v2@ -- otherwise @v1 <= suc^k v3@. instance Show Constraint where show (NewFlex i s) = "SizeMeta(?" ++ show i ++ ")" show (Arc v1 k v2) | k == 0 = show v1 ++ "<=" ++ show v2 | k < 0 = show v1 ++ "+" ++ show (-k) ++ "<=" ++ show v2 | otherwise = show v1 ++ "<=" ++ show v2 ++ "+" ++ show k type Constraints = [Constraint] emptyConstraints :: Constraints emptyConstraints = [] -- graph (matrix) ------------------------------------------------ data Graph = Graph { flexScope :: Map FlexId Scope -- ^ Scope for each flexible var. , nodeMap :: Map Node NodeId -- ^ Node labels to node numbers. , intMap :: Map NodeId Node -- ^ Node numbers to node labels. , nextNode :: NodeId -- ^ Number of nodes @n@. , graph :: NodeId -> NodeId -> Weight -- ^ The edges (restrict to @[0..n[@). } -- | The empty graph: no nodes, edges are all undefined (infinity weight). initGraph :: Graph initGraph = Graph Map.empty Map.empty Map.empty 0 (\ x y -> Infinite) -- | The Graph Monad, for constructing a graph iteratively. type GM = State Graph -- | Add a size meta node. addFlex :: FlexId -> Scope -> GM () addFlex x scope = do modify $ \ st -> st { flexScope = Map.insert x scope (flexScope st) } _ <- addNode (Flex x) return () -- | Lookup identifier of a node. -- If not present, it is added first. addNode :: Node -> GM Int addNode n = do st <- get case Map.lookup n (nodeMap st) of Just i -> return i Nothing -> do let i = nextNode st put $ st { nodeMap = Map.insert n i (nodeMap st) , intMap = Map.insert i n (intMap st) , nextNode = i + 1 } return i -- | @addEdge n1 k n2@ -- improves the weight of egde @n1->n2@ to be at most @k@. -- Also adds nodes if not yet present. addEdge :: Node -> Int -> Node -> GM () addEdge n1 k n2 = do i1 <- addNode n1 i2 <- addNode n2 st <- get let graph' x y = if (x,y) == (i1,i2) then Finite k `oplus` (graph st) x y else graph st x y put $ st { graph = graph' } addConstraint :: Constraint -> GM () addConstraint (NewFlex x scope) = addFlex x scope addConstraint (Arc n1 k n2) = addEdge n1 k n2 buildGraph :: Constraints -> Graph buildGraph cs = execState (mapM_ addConstraint cs) initGraph mkMatrix :: Int -> (Int -> Int -> Weight) -> Matrix Weight mkMatrix n g = array ((0,0),(n-1,n-1)) [ ((i,j), g i j) | i <- [0..n-1], j <- [0..n-1]] -- displaying matrices with row and column labels -------------------- -- | A matrix with row descriptions in @b@ and column descriptions in @c@. data LegendMatrix a b c = LegendMatrix { matrix :: Matrix a , rowdescr :: Int -> b , coldescr :: Int -> c } instance (Show a, Show b, Show c) => Show (LegendMatrix a b c) where show (LegendMatrix m rd cd) = -- first show column description let ((r,c),(r',c')) = bounds m in foldr (\ j s -> "\t" ++ show (cd j) ++ s) "" [c .. c'] ++ -- then output rows foldr (\ i s -> "\n" ++ show (rd i) ++ foldr (\ j t -> "\t" ++ show (m!(i,j)) ++ t) (s) [c .. c']) "" [r .. r'] -- solving the constraints ------------------------------------------- -- | A solution assigns to each flexible variable a size expression -- which is either a constant or a @v + n@ for a rigid variable @v@. type Solution = Map Int SizeExpr emptySolution :: Solution emptySolution = Map.empty extendSolution :: Solution -> Int -> SizeExpr -> Solution extendSolution subst k v = Map.insert k v subst data SizeExpr = SizeVar RigidId Int -- ^ e.g. x + 5 | SizeConst Weight -- ^ a number or infinity instance Show SizeExpr where show (SizeVar n 0) = show (Rigid (RVar n)) show (SizeVar n k) = show (Rigid (RVar n)) ++ "+" ++ show k show (SizeConst w) = show w -- | @sizeRigid r n@ returns the size expression corresponding to @r + n@ sizeRigid :: Rigid -> Int -> SizeExpr sizeRigid (RConst k) n = SizeConst (inc k n) sizeRigid (RVar i) n = SizeVar i n {- apply :: SizeExpr -> Solution -> SizeExpr apply e@(SizeExpr (Rigid _) _) phi = e apply e@(SizeExpr (Flex x) i) phi = case Map.lookup x phi of Nothing -> e Just (SizeExpr v j) -> SizeExpr v (i + j) after :: Solution -> Solution -> Solution after psi phi = Map.map (\ e -> e `apply` phi) psi -} {- compute solution a solution CANNOT exist if v < v for a rigid variable v -- Andreas, 2012-09-19 OUTDATED are: -- v <= v' for rigid variables v,v' -- x < v for a flexible variable x and a rigid variable v thus, for each flexible x, only one of the following cases is possible r+n <= x+m <= infty for a unique rigid r (meaning r --(m-n)--> x) x <= r+n for a unique rigid r (meaning x --(n)--> r) we are looking for the least values for flexible variables that solve the constraints. Algorithm while flexible variables and rigid rows left find a rigid variable row i for all flexible columns j if i --n--> j with n<=0 (meaning i+n <= j) then j = i + n while flexible variables j left search the row j for entry i if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n -} solve :: Constraints -> Maybe Solution solve cs = -- trace (show cs) $ -- trace (show lm0) $ -- trace (show lm) $ -- trace (show d) $ let solution = if solvable then loop1 flexs rigids emptySolution else Nothing in -- trace (show solution) $ solution where -- compute the graph and its transitive closure m gr = buildGraph cs n = nextNode gr -- number of nodes m0 = mkMatrix n (graph gr) m = warshall m0 -- tracing only: build output version of transitive graph legend i = fromJust $ Map.lookup i (intMap gr) -- trace only lm0 = LegendMatrix m0 legend legend -- trace only lm = LegendMatrix m legend legend -- trace only -- compute the sets of flexible and rigid node numbers ns = Map.keys (nodeMap gr) -- a set of flexible variables flexs = foldl' (\ l k -> case k of (Flex i) -> i : l (Rigid _) -> l) [] ns -- a set of rigid variables rigids = foldl' (\ l k -> case k of (Flex _) -> l (Rigid i) -> i : l) [] ns -- rigid matrix indices rInds = foldl' (\ l r -> let Just i = Map.lookup (Rigid r) (nodeMap gr) in i : l) [] rigids -- check whether there is a solution -- d = [ m!(i,i) | i <- [0 .. (n-1)] ] -- diagonal -- a rigid variable might not be less than it self, so no -.. on the -- rigid part of the diagonal solvable = all (\ x -> x >= Finite 0) [ m!(i,i) | i <- rInds ] && True {- Andreas, 2012-09-19 We now can have constraints between rigid variables, like i < j. Thus we skip the following two test. However, a solution must be checked for consistency with the constraints on rigid vars. -- a rigid variable might not be bounded below by infinity or -- bounded above by a constant -- it might not be related to another rigid variable all (\ (r, r') -> r == r' || let Just row = (Map.lookup (Rigid r) (nodeMap gr)) Just col = (Map.lookup (Rigid r') (nodeMap gr)) edge = m!(row,col) in isBelow r edge r' ) [ (r,r') | r <- rigids, r' <- rigids ] && -- a flexible variable might not be strictly below a rigid variable all (\ (x, v) -> let Just row = (Map.lookup (Flex x) (nodeMap gr)) Just col = (Map.lookup (Rigid (RVar v)) (nodeMap gr)) edge = m!(row,col) in edge >= Finite 0) [ (x,v) | x <- flexs, (RVar v) <- rigids ] -} inScope :: FlexId -> Rigid -> Bool inScope x (RConst _) = True inScope x (RVar v) = scope v where Just scope = Map.lookup x (flexScope gr) {- loop1 while flexible variables and rigid rows left find a rigid variable row i for all flexible columns j if i --n--> j with n<=0 (meaning i + n <= j) then j = i + n -} loop1 :: [FlexId] -> [Rigid] -> Solution -> Maybe Solution loop1 [] rgds subst = Just subst loop1 flxs [] subst = loop2 flxs subst loop1 flxs (r:rgds) subst = let row = fromJust $ Map.lookup (Rigid r) (nodeMap gr) (flxs',subst') = foldl' (\ (flx,sub) f -> let col = fromJust $ Map.lookup (Flex f) (nodeMap gr) in case (inScope f r, m!(row,col)) of -- Finite z | z <= 0 -> (True, Finite z) -> let trunc z | z >= 0 = 0 | otherwise = -z in (flx, extendSolution sub f (sizeRigid r (trunc z))) _ -> (f : flx, sub) ) ([], subst) flxs in loop1 flxs' rgds subst' {- loop2 while flexible variables j left search the row j for entry i if j --n--> i with n >= 0 (meaning j <= i + n) then j = i -} loop2 :: [FlexId] -> Solution -> Maybe Solution loop2 [] subst = Just subst loop2 (f:flxs) subst = loop3 0 subst where row = fromJust $ Map.lookup (Flex f) (nodeMap gr) loop3 col subst | col >= n = -- default to infinity loop2 flxs (extendSolution subst f (SizeConst Infinite)) loop3 col subst = case Map.lookup col (intMap gr) of Just (Rigid r) | not (infinite r) -> case (inScope f r, m!(row,col)) of (True, Finite z) | z >= 0 -> loop2 flxs (extendSolution subst f (sizeRigid r z)) (_, Infinite) -> loop3 (col+1) subst _ -> -- trace ("unusable rigid: " ++ show r ++ " for flex " ++ show f) Nothing -- NOT: loop3 (col+1) subst _ -> loop3 (col+1) subst -- Testing ---------------------------------------------------------------- genGraph :: Ord node => Float -> Gen edge -> [node] -> Gen (AdjList node edge) genGraph density edge nodes = do Map.fromList . concat <$> mapM neighbours nodes where k = round (100 * density) neighbours n = do ns <- concat <$> mapM neighbour nodes case ns of [] -> elements [[(n, [])], []] _ -> return [(n, ns)] neighbour n = frequency [ (k, do e <- edge ns <- neighbour n return ((n, e):ns)) , (100 - k, return []) ] newtype Distance = Dist Nat deriving (Eq, Ord, Num, Integral, Show, Enum, Real) instance SemiRing Distance where oplus (Dist a) (Dist b) = Dist (min a b) otimes (Dist a) (Dist b) = Dist (a + b) genGraph_ :: Nat -> Gen (AdjList Nat Distance) genGraph_ n = genGraph 0.2 (Dist <$> natural) [0..n - 1] lookupEdge :: Ord n => n -> n -> AdjList n e -> Maybe e lookupEdge i j g = lookup j =<< Map.lookup i g edges :: Ord n => AdjList n e -> [(n,n,e)] edges g = do (i, ns) <- Map.toList g (j, e) <- ns return (i, j, e) -- | Check that no edges get longer when completing a graph. prop_smaller :: Nat -> Property prop_smaller n' = forAll (genGraph_ n) $ \g -> let g' = warshallG g in and [ lookupEdge i j g' =< e | (i, j, e) <- edges g ] where n = abs (div n' 2) Nothing =< _ = False Just x =< y = x <= y newEdge :: Nat -> Nat -> Distance -> AdjList Nat Distance -> AdjList Nat Distance newEdge i j e = Map.insertWith (++) i [(j, e)] genPath :: Nat -> Nat -> Nat -> AdjList Nat Distance -> Gen (AdjList Nat Distance) genPath n i j g = do es <- listOf $ (,) <$> node <*> edge v <- edge return $ addPath i (es ++ [(j, v)]) g where edge :: Gen Distance edge = Dist <$> natural node :: Gen Nat node = choose (0, n - 1) addPath :: Nat -> [(Nat, Distance)] -> AdjList Nat Distance -> AdjList Nat Distance addPath _ [] g = g addPath i ((j, v):es) g = newEdge i j v $ addPath j es g -- | Check that all transitive edges are added. prop_path :: Nat -> Property prop_path n' = forAll (genGraph_ n) $ \g -> forAll (two $ choose (0, n - 1)) $ \(i, j) -> forAll (genPath n i j g) $ \g' -> isJust (lookupEdge i j $ warshallG g') where n = abs (div n' 2) + 1 mapNodes :: (Ord node, Ord node') => (node -> node') -> AdjList node edge -> AdjList node' edge mapNodes f = Map.map f' . Map.mapKeys f where f' es = [ (f n, e) | (n,e) <- es ] -- | Check that no edges are added between components. prop_disjoint :: Nat -> Property prop_disjoint n' = forAll (two $ genGraph_ n) $ \(g1, g2) -> let g = Map.union (mapNodes Left g1) (mapNodes Right g2) g' = warshallG g in all disjoint (Map.assocs g') where n = abs (div n' 3) disjoint (Left i, es) = all (isLeft . fst) es disjoint (Right i, es) = all (isRight . fst) es isLeft = either (const True) (const False) isRight = not . isLeft prop_stable :: Nat -> Property prop_stable n' = forAll (genGraph_ n) $ \g -> let g' = warshallG g in g' =~= warshallG g' where n = abs (div n' 2) g =~= g' = sort (edges g) == sort (edges g') tests :: IO Bool tests = runTests "Agda.Utils.Warshall" []