module Agda.Utils.Warshall where
import Control.Applicative
import Control.Monad.State
import Data.Maybe
import Data.Array
import Data.List
import Data.Map (Map)
import qualified Data.Map as Map
import Test.QuickCheck
import Agda.Utils.TestHelpers
import Agda.Syntax.Common
import Agda.Utils.QuickCheck
import Agda.Utils.SemiRing
import Debug.Trace
-- | Two-dimensional arrays indexed by @(row, column)@.
type Matrix a = Array (Int,Int) a

-- | Floyd–Warshall closure over an arbitrary semiring.
--
-- For every pivot @k@ in the row range (in increasing order) each entry is
-- replaced by @a(i,j) `oplus` (a(i,k) `otimes` a(k,j))@, using the matrix
-- produced by the previous pivot.  Pivots are drawn from the row bounds,
-- so the matrix is expected to be square.
warshall :: SemiRing a => Matrix a -> Matrix a
warshall a0 = foldl' relax a0 [rowLo .. rowHi]
  where
    bnds@((rowLo, colLo), (rowHi, colHi)) = bounds a0
    -- One round: additionally allow paths that pass through pivot @k@.
    relax mat k = array bnds
      [ ((i, j), (mat ! (i, j)) `oplus` ((mat ! (i, k)) `otimes` (mat ! (k, j))))
      | i <- [rowLo .. rowHi]
      , j <- [colLo .. colHi]
      ]
-- | Adjacency-list graphs: each node maps to its outgoing edges, given as
-- @(target, label)@ pairs.
type AdjList node edge = Map node [(node, edge)]

-- | Closure of an adjacency-list graph.
--
-- The graph is converted to a 'Matrix' of @Maybe edge@, closed with
-- 'warshall', and converted back.  Every node that occurs in @g@ (as a key
-- or as an edge target) gets an entry in the result, possibly with an empty
-- edge list.  Parallel edges between the same two nodes are merged with
-- 'oplus'; an absent edge is 'Nothing' in the intermediate matrix.
warshallG :: (SemiRing edge, Ord node) => AdjList node edge -> AdjList node edge
warshallG g = fromMatrix $ warshall m
  where
    -- Dense indices 0 .. len-1, assigned in order of first occurrence.
    nodes = zip (nub $ Map.keys g ++ map fst (concat $ Map.elems g)) [0..]
    len   = length nodes
    -- Inclusive bounds; an empty graph yields the empty range ((0,0),(-1,-1)).
    b     = ((0,0), (len - 1, len - 1))
    -- Index -> node table, so converting back does not walk @nodes@ each time.
    nodeAt = listArray (0, len - 1) (map fst nodes)
    edge i j = do
      es <- Map.lookup i g
      foldr oplus Nothing [ Just v | (j', v) <- es, j == j' ]
    m = array b [ ((r, c), edge i j) | (i, r) <- nodes, (j, c) <- nodes ]
    fromMatrix matrix = Map.fromList $ do
      (i, r) <- nodes
      let es = [ (nodeAt ! c, e)
               | c <- [0 .. len - 1]
               , Just e <- [matrix ! (r, c)]
               ]
      return (i, es)
-- | Edge weights: a finite integer, or 'Infinite'.
data Weight = Finite Int | Infinite
  deriving (Eq)

-- | Add an integer offset to a weight; 'Infinite' stays 'Infinite'.
inc :: Weight -> Int -> Weight
inc w n = case w of
  Finite k -> Finite (k + n)
  Infinite -> Infinite
-- | Finite weights print as their number, 'Infinite' as a dot.
instance Show Weight where
  show w = case w of
    Finite i -> show i
    Infinite -> "."

-- | 'Infinite' is the greatest weight; finite weights compare as integers.
-- The right operand is inspected first.
instance Ord Weight where
  x <= y = case y of
    Infinite -> True
    Finite b -> case x of
      Infinite -> False
      Finite a -> a <= b

-- | (min, +) semiring on weights: 'oplus' keeps the smaller weight and
-- 'otimes' adds, with 'Infinite' absorbing.
instance SemiRing Weight where
  oplus = min
  otimes x y = case (x, y) of
    (Infinite, _)        -> Infinite
    (_, Infinite)        -> Infinite
    (Finite a, Finite b) -> Finite (a + b)
-- | Nodes of the constraint graph: a rigid node (see 'Rigid') or a
-- flexible variable identified by its 'FlexId'.
--
-- Constructor order matters: the derived 'Ord' places every 'Rigid' node
-- before every 'Flex' node, which fixes the order in which maps keyed by
-- 'Node' list their keys.
data Node = Rigid Rigid
          | Flex FlexId
  deriving (Eq, Ord)

-- | Rigid nodes: a constant weight ('RConst'; @RConst Infinite@ is the
-- infinite constant) or a rigid variable ('RVar').
data Rigid = RConst Weight
           | RVar RigidId
  deriving (Eq, Ord, Show)

-- | Dense index of a node in the constraint graph.
type NodeId = Int
-- | Identifier of a rigid variable.
type RigidId = Int
-- | Identifier of a flexible variable.
type FlexId = Int
-- | Scope of a flexible variable: which rigid variables it may be solved in
-- terms of.
type Scope = RigidId -> Bool
-- | Flexible nodes print as @?i@, rigid variables as @vi@, the infinite
-- constant as @#@, and finite constants as their number.
instance Show Node where
  show node = case node of
    Flex i                    -> "?" ++ show i
    Rigid (RVar i)            -> "v" ++ show i
    Rigid (RConst Infinite)   -> "#"
    Rigid (RConst (Finite n)) -> show n

-- | Is this rigid node the infinite constant?
infinite :: Rigid -> Bool
infinite r = r == RConst Infinite
-- | @isBelow lo w hi@ is 'True' when @w@ is 'Infinite', when @hi@ is the
-- infinite constant, or when @lo@, @w@ and @hi@ are all finite (@i@, @n@,
-- @j@) with @i + n <= j@.  In every other case it is 'False'; in particular,
-- rigid variables are never related.
isBelow :: Rigid -> Weight -> Rigid -> Bool
isBelow lo w hi = case (lo, w, hi) of
  (_, Infinite, _)                                  -> True
  (_, _, RConst Infinite)                           -> True
  (RConst (Finite i), Finite n, RConst (Finite j)) -> i + n <= j
  _                                                 -> False
-- | Input constraints: introduce a flexible variable together with its
-- scope, or relate two nodes by an integer offset.
data Constraint = NewFlex FlexId Scope
                | Arc Node Int Node

-- | @Arc v1 k v2@ is printed with the offset attached to the side that keeps
-- it non-negative:
--
--   * @v1<=v2@ when @k == 0@;
--   * @v1+|k|<=v2@ when @k < 0@;
--   * @v1<=v2+k@ when @k > 0@.
instance Show Constraint where
  show (NewFlex i _) = "SizeMeta(?" ++ show i ++ ")"
  show (Arc v1 k v2)
    | k == 0    = show v1 ++ "<=" ++ show v2
    -- Negate so the printed offset is positive (v1+3<=v2, not v1+-3<=v2).
    | k < 0     = show v1 ++ "+" ++ show (negate k) ++ "<=" ++ show v2
    | otherwise = show v1 ++ "<=" ++ show v2 ++ "+" ++ show k

type Constraints = [Constraint]

-- | No constraints.
emptyConstraints = []
-- | Constraint graph built up by the 'GM' actions.
data Graph = Graph
  { flexScope :: Map FlexId Scope
    -- ^ scope of each flexible variable registered through 'addFlex'
  , nodeMap :: Map Node NodeId
    -- ^ node -> dense index
  , intMap :: Map NodeId Node
    -- ^ dense index -> node (the inverse of 'nodeMap')
  , nextNode :: NodeId
    -- ^ next unused index, i.e. the number of nodes allocated so far
  , graph :: NodeId -> NodeId -> Weight
    -- ^ edge weight between two indices; 'Infinite' if no edge was added
  }

-- | The empty graph: no nodes, no scopes, and no edges.
initGraph :: Graph
initGraph = Graph
  { flexScope = Map.empty
  , nodeMap   = Map.empty
  , intMap    = Map.empty
  , nextNode  = 0
  , graph     = \ _ _ -> Infinite
  }

-- | State monad in which the graph is built.
type GM = State Graph
-- | Record the scope of a flexible variable and make sure it has a node.
addFlex :: FlexId -> Scope -> GM ()
addFlex x scope = do
  modify $ \ st -> st { flexScope = Map.insert x scope (flexScope st) }
  addNode (Flex x) >> return ()
-- | Index of a node.  A node seen for the first time receives the next free
-- index, which is recorded in both 'nodeMap' and 'intMap'.
addNode :: Node -> GM Int
addNode n = do
  st <- get
  maybe (allocate st) return (Map.lookup n (nodeMap st))
  where
    allocate st = do
      let i = nextNode st
      put st { nodeMap  = Map.insert n i (nodeMap st)
             , intMap   = Map.insert i n (intMap st)
             , nextNode = i + 1
             }
      return i
-- | Add an edge of weight @k@ from @n1@ to @n2@, allocating both nodes if
-- necessary.  The new weight is combined with any existing weight between
-- the same two nodes via 'oplus'.
addEdge :: Node -> Int -> Node -> GM ()
addEdge n1 k n2 = do
  src <- addNode n1
  tgt <- addNode n2
  modify $ \ st ->
    let old = graph st
        new x y
          | (x, y) == (src, tgt) = Finite k `oplus` old x y
          | otherwise            = old x y
    in  st { graph = new }
-- | Enter a single constraint into the graph.
addConstraint :: Constraint -> GM ()
addConstraint c = case c of
  NewFlex x scope -> addFlex x scope
  Arc n1 k n2     -> addEdge n1 k n2

-- | Build the graph by entering the constraints, in order, into 'initGraph'.
buildGraph :: Constraints -> Graph
buildGraph = flip execState initGraph . mapM_ addConstraint
-- | Tabulate an @n x n@ weight matrix from an edge-weight function.  Rows and
-- columns range over @0 .. n-1@; for @n == 0@ the matrix is empty.
mkMatrix :: Int -> (Int -> Int -> Weight) -> Matrix Weight
mkMatrix n g = array ((0,0), (n - 1, n - 1))
  [ ((i,j), g i j) | i <- [0 .. n - 1], j <- [0 .. n - 1] ]
-- | A matrix together with functions that label its rows and columns, for
-- pretty-printing.
data LegendMatrix a b c = LegendMatrix
  { matrix   :: Matrix a
  , rowdescr :: Int -> b
  , coldescr :: Int -> c
  }

-- | Tab-separated table: a header line of column labels, then one line per
-- row, consisting of the row label followed by that row's entries.
instance (Show a, Show b, Show c) => Show (LegendMatrix a b c) where
  show (LegendMatrix m rd cd) = header ++ concatMap row [r .. r']
    where
      ((r, c), (r', c')) = bounds m
      cols   = [c .. c']
      header = concatMap (\ j -> "\t" ++ show (cd j)) cols
      row i  = "\n" ++ show (rd i)
                 ++ concatMap (\ j -> "\t" ++ show (m ! (i, j))) cols
-- | Assignment of size expressions to flexible variables.
type Solution = Map Int SizeExpr

-- | The solution that assigns nothing.
emptySolution = Map.empty

-- | Assign @val@ to @key@, replacing any earlier assignment.
extendSolution :: Ord k => Map k v -> k -> v -> Map k v
extendSolution subst key val = Map.insert key val subst

-- | Values of flexible variables: a rigid variable plus an offset, or a
-- constant weight.
data SizeExpr = SizeVar RigidId Int
              | SizeConst Weight

instance Show SizeExpr where
  show e = case e of
    SizeConst w -> show w
    SizeVar v k
      | k == 0    -> var
      | otherwise -> var ++ "+" ++ show k
      where var = show (Rigid (RVar v))

-- | The size expression "rigid node plus offset @n@".
sizeRigid :: Rigid -> Int -> SizeExpr
sizeRigid r n = case r of
  RVar i   -> SizeVar i n
  RConst k -> SizeConst (k `inc` n)
-- | Solve a list of size constraints.
--
-- The constraints are turned into a weighted graph ('buildGraph') and its
-- matrix is closed with 'warshall'.  The result is 'Nothing' unless the
-- closed matrix @m@ passes three checks:
--
--   * every rigid node's diagonal entry is at least @Finite 0@;
--
--   * @isBelow r (m!(r,r')) r'@ holds for all distinct rigid nodes @r@, @r'@;
--
--   * @m!(x,v) >= Finite 0@ for every flexible @x@ and rigid variable @v@.
--
-- The flexible variables are then assigned in two phases, 'loop1' and
-- 'loop2'.  'loop2' can still fail with 'Nothing'.
solve :: Constraints -> Maybe Solution
solve cs
  | solvable  = loop1 flexs rigids emptySolution
  | otherwise = Nothing
  where
    gr = buildGraph cs
    n  = nextNode gr
    -- Closed weight matrix of the constraint graph.
    m  = warshall (mkMatrix n (graph gr))

    -- Matrix index of a node.  All nodes used below come from the keys of
    -- 'nodeMap', so the lookup cannot fail.
    nodeIndex v = fromJust $ Map.lookup v (nodeMap gr)

    -- Flexible and rigid variables in descending node order.  The order
    -- matters: 'loop1' tries the rigids in this order, and the first rigid
    -- that fits determines a flex's value.
    ns     = Map.keys (nodeMap gr)
    flexs  = reverse [ i | Flex i  <- ns ]
    rigids = reverse [ r | Rigid r <- ns ]

    solvable = rigidDiagonalsOk && rigidPairsOk && flexRigidOk
    rigidDiagonalsOk =
      all (\ r -> let i = nodeIndex (Rigid r) in m ! (i, i) >= Finite 0) rigids
    rigidPairsOk = and
      [ r == r' || isBelow r (m ! (nodeIndex (Rigid r), nodeIndex (Rigid r'))) r'
      | r <- rigids, r' <- rigids ]
    flexRigidOk = and
      [ m ! (nodeIndex (Flex x), nodeIndex (Rigid (RVar v))) >= Finite 0
      | x <- flexs, RVar v <- rigids ]

    -- Constants are always in scope.  A rigid variable is in scope of @x@
    -- when @x@'s registered 'Scope' admits it.  This raises a pattern-match
    -- error if @x@ was never registered by a 'NewFlex' constraint.
    inScope :: FlexId -> Rigid -> Bool
    inScope _ (RConst _) = True
    inScope x (RVar v)   = scope v
      where Just scope = Map.lookup x (flexScope gr)

    -- Phase 1: process the rigids one at a time.  For rigid @r@, every
    -- still-unsolved flex @f@ that has @r@ in scope and a finite entry
    -- @z = m!(r,f)@ is set to @r@ plus the offset @trunc z@.  The remaining
    -- flexes are carried on to the next rigid, and whatever is left after
    -- the last rigid goes to phase 2.
    loop1 :: [FlexId] -> [Rigid] -> Solution -> Maybe Solution
    loop1 [] _ subst = Just subst
    loop1 flxs [] subst = loop2 flxs subst
    loop1 flxs (r:rgds) subst =
      let row = nodeIndex (Rigid r)
          (flxs', subst') = foldl (assignFrom r row) ([], subst) flxs
      in  loop1 flxs' rgds subst'

    assignFrom r row (flx, sub) f =
      case (inScope f r, m ! (row, nodeIndex (Flex f))) of
        (True, Finite z) -> (flx, extendSolution sub f (sizeRigid r (trunc z)))
        _                -> (f : flx, sub)

    -- Offset of a flex above the rigid it is solved from: @-z@ for a
    -- negative closure weight @z@, otherwise 0, so it is never negative.
    trunc :: Int -> Int
    trunc z | z >= 0    = 0
            | otherwise = negate z

    -- Phase 2: scan the columns in index order for each remaining flex @f@,
    -- skipping columns that are not non-infinite rigids and columns whose
    -- entry is 'Infinite'.  The first finite entry @z = m!(f,r)@ decides:
    -- if @r@ is in scope and @z >= 0@, then @f@ becomes @r + z@; otherwise
    -- the whole problem fails.  A flex with no such entry becomes 'Infinite'.
    loop2 :: [FlexId] -> Solution -> Maybe Solution
    loop2 [] subst = Just subst
    loop2 (f:flxs) subst = loop3 0 subst
      where
        row = nodeIndex (Flex f)
        loop3 col subst'
          | col >= n =
              loop2 flxs (extendSolution subst' f (SizeConst Infinite))
        loop3 col subst' =
          case Map.lookup col (intMap gr) of
            Just (Rigid r) | not (infinite r) ->
              case (inScope f r, m ! (row, col)) of
                (True, Finite z) | z >= 0 ->
                  loop2 flxs (extendSolution subst' f (sizeRigid r z))
                (_, Infinite) -> loop3 (col + 1) subst'
                _             -> Nothing
            _ -> loop3 (col + 1) subst'
-- | Random graph over the given nodes.
--
-- For every node, the outgoing edges to each target form a run whose length
-- is geometric: with weight @k = round (100 * density)@ one more edge is
-- emitted, and with weight @100 - k@ the run stops.  A node that ends up
-- with no edges is, at random, either omitted or mapped to an empty list.
--
-- NOTE(review): @density@ is expected to lie in @[0, 1)@.  At 1 the runs
-- never terminate, and above 1 the stopping weight @100 - k@ is negative.
genGraph :: Ord node => Float -> Gen edge -> [node] -> Gen (AdjList node edge)
genGraph density edge nodes =
  Map.fromList . concat <$> mapM neighbours nodes
  where
    k = round (100 * density)
    neighbours n = do
      ns <- concat <$> mapM neighbour nodes
      case ns of
        [] -> elements [[(n, [])], []]
        _  -> return [(n, ns)]
    -- A run of edges targeting @n@.
    neighbour n = frequency
      [ (k, (\ e rest -> (n, e) : rest) <$> edge <*> neighbour n)
      , (100 - k, return [])
      ]
-- | Path lengths used to test 'warshallG'.
newtype Distance = Dist Nat
  deriving (Eq, Ord, Num, Integral, Show, Enum, Real)

-- | Shortest-path semiring: 'oplus' keeps the smaller distance and
-- 'otimes' adds distances.
instance SemiRing Distance where
  oplus  = min
  otimes = (+)
-- | Random 'Distance'-labelled graph on the nodes @0 .. n-1@ with density
-- 0.2.  For @n == 0@ the node list is empty.
genGraph_ :: Nat -> Gen (AdjList Nat Distance)
genGraph_ n =
  genGraph 0.2 (Dist <$> natural) [0 .. n - 1]
-- | Label of the first listed edge from @i@ to @j@, if there is one.
lookupEdge :: Ord n => n -> n -> AdjList n e -> Maybe e
lookupEdge i j g = Map.lookup i g >>= lookup j

-- | All edges as @(source, target, label)@ triples, grouped by source in
-- ascending key order.
edges :: Ord n => AdjList n e -> [(n,n,e)]
edges g = [ (i, j, e) | (i, ns) <- Map.toList g, (j, e) <- ns ]
-- | The closure keeps every edge: for each edge @i -> j@ labelled @e@ in
-- @g@, @warshallG g@ has an edge @i -> j@ whose label is @<= e@.
prop_smaller :: Nat -> Property
prop_smaller n' =
  forAll (genGraph_ n) $ \ g ->
    let closed = warshallG g
    in  all (\ (i, j, e) -> lookupEdge i j closed `atMost` e) (edges g)
  where
    n = abs (div n' 2)
    atMost found e = maybe False (<= e) found
-- | Prepend an edge from @i@ to @j@ labelled @e@ to @i@'s edge list.
newEdge :: Ord k => k -> a -> b -> Map k [(a, b)] -> Map k [(a, b)]
newEdge i j e = Map.insertWith (++) i [(j, e)]

-- | Add to @g@ a random path that starts at @i@, passes through a
-- random-length list of intermediate nodes drawn from @0 .. n-1@, and ends
-- with an edge into @j@.  Expects @n >= 1@, so the node range is non-empty.
genPath :: Nat -> Nat -> Nat -> AdjList Nat Distance -> Gen (AdjList Nat Distance)
genPath n i j g = do
  es <- listOf $ (,) <$> node <*> edge
  v  <- edge
  return $ addPath i (es ++ [(j, v)]) g
  where
    edge = Dist <$> natural
    node = choose (0, n - 1)
    addPath _ [] acc = acc
    addPath from ((to, v) : rest) acc =
      newEdge from to v $ addPath to rest acc
-- | After a path from @i@ to @j@ is added to a random graph, the closure
-- contains a direct edge from @i@ to @j@.
prop_path :: Nat -> Property
prop_path n' =
  forAll (genGraph_ n) $ \ g ->
  forAll (two $ choose (0, n - 1)) $ \ (i, j) ->
  forAll (genPath n i j g) $ \ g' ->
    isJust (lookupEdge i j $ warshallG g')
  where
    -- At least one node, so that the range 0 .. n-1 is non-empty.
    n = abs (div n' 2) + 1
-- | Rename every node of a graph, in both the keys and the edge targets.
-- If @f@ is not injective, colliding keys are merged by 'Map.mapKeys',
-- which keeps the value of the greatest original key.
mapNodes :: (Ord node, Ord node') => (node -> node') -> AdjList node edge -> AdjList node' edge
mapNodes f = Map.map (\ es -> [ (f n, e) | (n, e) <- es ]) . Map.mapKeys f
-- | Closing the disjoint union of two graphs adds no edges between the two
-- halves: every edge leaving a 'Left' node targets a 'Left' node, and every
-- edge leaving a 'Right' node targets a 'Right' node.
prop_disjoint :: Nat -> Property
prop_disjoint n' =
  forAll (two $ genGraph_ n) $ \ (g1, g2) ->
    let combined = mapNodes Left g1 `Map.union` mapNodes Right g2
    in  all sameSide (Map.assocs (warshallG combined))
  where
    n = abs (div n' 3)
    sameSide (src, es) = all ((== side src) . side . fst) es
    -- True for 'Left', False for 'Right'.
    side = either (const True) (const False)
-- | 'warshallG' is idempotent: closing an already closed graph gives the
-- same edges, compared as sorted edge lists.
prop_stable :: Nat -> Property
prop_stable n' =
  forAll (genGraph_ n) $ \ g ->
    let once  = warshallG g
        twice = warshallG once
    in  sameEdges once twice
  where
    n = abs (div n' 2)
    sameEdges a b = sort (edges a) == sort (edges b)
-- | Run every QuickCheck property in this module.
tests :: IO Bool
tests = runTests moduleName properties
  where
    moduleName = "Agda.Utils.Warshall"
    properties =
      [ quickCheck' prop_smaller
      , quickCheck' prop_path
      , quickCheck' prop_disjoint
      , quickCheck' prop_stable
      ]