{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Agda.Utils.Warshall where

{- construct a graph from constraints

x + n <= y   becomes   x ---(-n)---> y
x <= n + y   becomes   x ---(+n)---> y

the default edge (= no edge) is labelled with infinity

building the graph involves keeping track of the node names.
We do this in a finite map, assigning consecutive numbers to nodes.
-}

import Control.Applicative
import Control.Monad.State
import Data.Maybe -- fromJust
import Data.Array
import Data.List
import Data.Map (Map)
import qualified Data.Map as Map
import Test.QuickCheck
import Agda.Utils.TestHelpers
import Agda.Syntax.Common
import Agda.Utils.QuickCheck
import Agda.Utils.SemiRing

import Debug.Trace

-- | A two-dimensional array indexed by (row, column) pairs.
type Matrix a = Array (Int,Int) a

-- | Warshall's algorithm over an arbitrary semiring.
--
--   For every pivot index @k@ (in increasing order) each entry @(i,j)@ is
--   replaced by its 'oplus' with the 'otimes' of the entries @(i,k)@ and
--   @(k,j)@.
--
--   The matrix is assumed to be square, i.e. its row and column index
--   ranges coincide; the pivot runs over the row range.
warshall :: SemiRing a => Matrix a -> Matrix a
warshall a0 = foldl' step a0 [rLo .. rHi]
  where
    bnds@((rLo, cLo), (rHi, cHi)) = bounds a0

    -- one round: additionally allow paths running through the pivot
    step prev pivot = array bnds
      [ (ix, (prev ! ix) `oplus` ((prev ! (i, pivot)) `otimes` (prev ! (pivot, j))))
      | i <- [rLo .. rHi]
      , j <- [cLo .. cHi]
      , let ix = (i, j)
      ]

-- Warshall's algorithm on a graph represented as an adjacency list.
-- An adjacency list maps each node to the list of its outgoing edges,
-- given as (target node, edge label) pairs.
type AdjList node edge = Map node [(node, edge)]

-- | Warshall's algorithm on a graph given as an adjacency list.
--
--   Every node occurring in the graph (as a key or as an edge target) is
--   paired with a consecutive matrix index.  The graph is turned into a
--   matrix of optional edge labels ('Nothing' meaning "no edge"), 'warshall'
--   is run on that matrix, and the result is converted back into an
--   adjacency list.  Several edges between the same pair of nodes are
--   combined with 'oplus' (at type @Maybe edge@).  Every node, including
--   those that only occur as edge targets, is a key of the result.
warshallG :: (SemiRing edge, Ord node) => AdjList node edge -> AdjList node edge
warshallG g = fromMatrix $ warshall matrix0
  where
    -- all nodes in order of first occurrence, paired with their matrix index
    nodes = zip (nub $ Map.keys g ++ map fst (concat $ Map.elems g))
                [0..]
    len   = length nodes
    b     = ((0,0), (len - 1,len - 1))

    -- the combined label of all edges from i to j, if there is any
    edge i j = do
      es <- Map.lookup i g
      foldr oplus Nothing [ Just v | (j', v) <- es, j == j' ]

    matrix0 = array b [ ((n, m), edge i j) | (i, n) <- nodes, (j, m) <- nodes ]

    -- Walk the node list directly instead of indexing it with (!!) for
    -- every column: same edges in the same order, but no partial list
    -- lookup and linear instead of quadratic work per row.
    fromMatrix matrix = Map.fromList
      [ (i, [ (j, e) | (j, m) <- nodes, Just e <- [matrix ! (n, m)] ])
      | (i, n) <- nodes
      ]

-- edge weight in the graph, forming a semi ring
-- (see the SemiRing instance: oplus is min, otimes is addition)

-- | A weight is either a finite integer or infinity.
--   Infinity is the weight of the default edge, i.e. of "no edge".
data Weight = Finite Int | Infinite
deriving (Eq)

-- | @inc w n@ adds the offset @n@ to the weight @w@.
--   An infinite weight stays infinite.
inc :: Weight -> Int -> Weight
inc w n = case w of
  Finite k -> Finite (k + n)
  Infinite -> Infinite

-- | Finite weights are shown as plain numbers, infinity as a dot.
instance Show Weight where
  show w = case w of
    Infinite -> "."
    Finite i -> show i

-- | Finite weights are ordered like the integers; 'Infinite' is the
--   greatest element.  Only '<=' is defined, the remaining methods come
--   from the class defaults.
instance Ord Weight where
  lhs <= rhs = case (rhs, lhs) of
    (Infinite, _)        -> True   -- everything is below infinity
    (_, Infinite)        -> False  -- infinity is not below a finite weight
    (Finite b, Finite a) -> a <= b

-- | Weights form a semiring with minimum as 'oplus' and addition as
--   'otimes'; an infinite factor makes the product infinite.
instance SemiRing Weight where
  oplus a b = min a b

  otimes w1 w2 = case (w1, w2) of
    (Finite a, Finite b) -> Finite (a + b)
    _                    -> Infinite

-- constraints ---------------------------------------------------

-- nodes of the graph are either
-- * flexible variables (with identifiers drawn from Int),
-- * rigid variables (also identified by Ints), or
-- * constants (like 0, infinity, or anything between)

-- | A node of the constraint graph: a rigid node (constant or rigid
--   variable) or a flexible variable.
--   The 'Ord' instance is needed because nodes are used as 'Map' keys.
data Node = Rigid Rigid
| Flex  FlexId
deriving (Eq, Ord)

-- | A rigid node: a constant weight (finite or infinite) or a rigid
--   variable.
data Rigid = RConst Weight
| RVar RigidId
deriving (Eq, Ord, Show)

-- node numbers (matrix indices) and variable identifiers are plain Ints
type NodeId  = Int
type RigidId = Int
type FlexId  = Int
type Scope   = RigidId -> Bool
-- which rigid variables a flex may be instantiated to

-- | Flexible variables are shown as @?i@, rigid variables as @vi@, the
--   infinite constant as @#@ and finite constants as their number.
instance Show Node where
  show node = case node of
    Flex i  -> "?" ++ show i
    Rigid r -> case r of
      RVar i            -> "v" ++ show i
      RConst Infinite   -> "#"
      RConst (Finite n) -> show n

-- | Is the rigid node the constant infinity?
infinite :: Rigid -> Bool
infinite r = case r of
  RConst Infinite -> True
  _               -> False

-- | @isBelow r w r'@ checks, in case @r@ and @r'@ are connected by @w@
--   (i.e. @w@ is not infinite), whether @r + w <= r'@ holds.
--
--   An infinite @w@ (no connection) and an infinite upper bound @r'@ are
--   always fine.  Two finite constants are compared numerically.  In all
--   remaining cases the answer is 'False', since rigid variables are not
--   related to anything else.
--
--   Precondition: @r@ and @r'@ are not the same rigid variable.
isBelow :: Rigid -> Weight -> Rigid -> Bool
isBelow lower w upper = case (w, upper, lower) of
  (Infinite, _              , _                ) -> True
  (_       , RConst Infinite, _                ) -> True
  (Finite n, RConst (Finite j), RConst (Finite i)) -> i + n <= j
  _                                              -> False

-- a constraint is either the declaration of a flexible variable
-- (together with its scope) or an edge in the graph
data Constraint = NewFlex FlexId Scope
| Arc Node Int Node
-- Arc v1 k v2  at least one of v1,v2 is a VMeta (Flex),
--              the other a VMeta or a VGen (Rigid)
-- if k <= 0 this means  $^(-k) v1 <= v2
-- otherwise                    v1 <= $^k v2

-- | A flexible variable declaration is shown as @SizeMeta(?i)@ (its scope
--   is not shown); an arc is shown as an inequation with the offset on the
--   left for negative weights and on the right for positive weights.
instance Show Constraint where
  show c = case c of
    NewFlex i _ -> "SizeMeta(?" ++ show i ++ ")"
    Arc v1 k v2 -> case compare k 0 of
      EQ -> show v1 ++ "<=" ++ show v2
      LT -> show v1 ++ "+" ++ show (-k) ++ "<=" ++ show v2
      GT -> show v1 ++ "<=" ++ show v2 ++ "+" ++ show k

-- a constraint set is represented as a list of constraints
type Constraints = [Constraint]

-- the empty list of constraints
emptyConstraints = []

-- graph (matrix) ------------------------------------------------

-- | The constraint graph under construction.  Nodes are numbered
--   consecutively from 0; the edges are kept as a function from pairs of
--   node numbers to weights.
data Graph = Graph
{ flexScope :: Map FlexId Scope        -- scope for each flexible var
, nodeMap :: Map Node NodeId           -- node labels to node numbers
, intMap  :: Map NodeId Node           -- node numbers to node labels
, nextNode :: NodeId                   -- number of nodes (n)
, graph :: NodeId -> NodeId -> Weight  -- the edges (restrict to [0..n[)
}

-- | The empty graph: no flexible variables, no nodes, and every edge
--   undefined, i.e. of infinite weight.
initGraph :: Graph
initGraph = Graph
  { flexScope = Map.empty
  , nodeMap   = Map.empty
  , intMap    = Map.empty
  , nextNode  = 0
  , graph     = \ _ _ -> Infinite
  }

-- the Graph Monad, for constructing a graph iteratively:
-- a state monad threading the Graph built so far
type GM = State Graph

-- | @addFlex x scope@ records @scope@ as the scope of the flexible variable
--   @x@ and makes sure that @x@ has a node in the graph.
addFlex :: FlexId -> Scope -> GM ()
addFlex x scope = do
  modify $ \ st -> st { flexScope = Map.insert x scope (flexScope st) }
  _ <- addNode (Flex x)
  return ()

-- | @addNode n@ returns the number of node @n@.  If @n@ is not yet part
--   of the graph, it is first added under the next free node number.
addNode :: Node -> GM Int
addNode n = maybe register return =<< gets (Map.lookup n . nodeMap)
  where
    -- enter n into both lookup tables and bump the node counter
    register = do
      fresh <- gets nextNode
      modify $ \ st -> st
        { nodeMap  = Map.insert n fresh (nodeMap st)
        , intMap   = Map.insert fresh n (intMap st)
        , nextNode = fresh + 1
        }
      return fresh

-- | @addEdge n1 k n2@ improves the weight of the edge from @n1@ to @n2@
--   to be at most @k@: the new weight is the 'oplus' of @k@ and the
--   previous weight.  Both nodes are added to the graph if not yet present.
addEdge :: Node -> Int -> Node -> GM ()
addEdge n1 k n2 = do
  src <- addNode n1
  tgt <- addNode n2
  modify $ \ st ->
    let old = graph st
        new x y
          | (x, y) == (src, tgt) = Finite k `oplus` old x y
          | otherwise            = old x y
    in  st { graph = new }

-- | Enter a single constraint into the graph: a flexible variable
--   declaration via 'addFlex', an arc via 'addEdge'.
addConstraint :: Constraint -> GM ()
addConstraint c = case c of
  Arc n1 k n2     -> addEdge n1 k n2
  NewFlex x scope -> addFlex x scope

-- | Build the graph for a list of constraints by adding them one after
--   another, in list order, to the empty graph.
buildGraph :: Constraints -> Graph
buildGraph = flip execState initGraph . mapM_ addConstraint

-- | @mkMatrix n g@ tabulates the edge function @g@ on the node numbers
--   @0 .. n-1@ as an @n@ by @n@ matrix.
mkMatrix :: Int -> (Int -> Int -> Weight) -> Matrix Weight
mkMatrix n g = listArray bnds [ g i j | (i, j) <- range bnds ]
  where
    bnds = ((0, 0), (n - 1, n - 1))

-- displaying matrices with row and column labels --------------------

-- a matrix with row descriptions in b and column descriptions in c
-- (the descriptions are given as functions from row/column index)
data LegendMatrix a b c = LegendMatrix
{ matrix   :: Matrix a
, rowdescr :: Int -> b
, coldescr :: Int -> c
}

-- | A matrix with legends is shown as a tab-separated table: first a
--   header line with the column descriptions, then one line per row,
--   starting with the row description.  Every line but the first is
--   preceded by a newline; every cell is preceded by a tab.
instance (Show a, Show b, Show c) => Show (LegendMatrix a b c) where
  show (LegendMatrix m rd cd) = header ++ concatMap line [r .. r']
    where
      ((r, c), (r', c')) = bounds m
      cols    = [c .. c']
      cell s  = "\t" ++ s
      -- the column descriptions
      header  = concatMap (\ j -> cell (show (cd j))) cols
      -- one row: its description followed by its entries
      line i  = "\n" ++ show (rd i)
                     ++ concatMap (\ j -> cell (show (m ! (i, j)))) cols

-- solving the constraints -------------------------------------------

-- a solution assigns to each flexible variable a size expression
-- which is either a constant or a v + n for a rigid variable v
-- (the map is keyed by the identifiers of the flexible variables)
type Solution = Map Int SizeExpr

-- | The solution that assigns nothing.
emptySolution = Map.empty

-- | @extendSolution subst k v@ binds @k@ to @v@ in @subst@, replacing a
--   previous binding of @k@ if there is one.
extendSolution subst key val = Map.insert key val subst

-- | A size expression: a rigid variable plus an offset, or a constant.
--   In @SizeVar i n@, @i@ identifies the rigid variable and @n@ is the
--   offset (see 'sizeRigid').
data SizeExpr = SizeVar Int Int   -- e.g. x + 5
| SizeConst Weight  -- a number or infinity

-- | A variable with offset is shown as the rigid variable followed by
--   @+offset@, where a zero offset is omitted; a constant is shown as its
--   weight.
instance Show SizeExpr where
  show expr = case expr of
    SizeConst w -> show w
    SizeVar n k
      | k == 0    -> var
      | otherwise -> var ++ "+" ++ show k
      where var = show (Rigid (RVar n))

-- | @sizeRigid r n@ returns the size expression corresponding to @r + n@:
--   a constant is incremented by @n@, a rigid variable gets @n@ as offset.
sizeRigid :: Rigid -> Int -> SizeExpr
sizeRigid r n = case r of
  RVar i   -> SizeVar i n
  RConst k -> SizeConst (inc k n)

{-
apply :: SizeExpr -> Solution -> SizeExpr
apply e@(SizeExpr (Rigid _) _) phi = e
apply e@(SizeExpr (Flex  x) i) phi = case Map.lookup x phi of
Nothing -> e
Just (SizeExpr v j) -> SizeExpr v (i + j)

after :: Solution -> Solution -> Solution
after psi phi = Map.map (\ e -> e `apply` phi) psi
-}

{-
solve :: Constraints -> Maybe Solution
solve cs = if any (\ x -> x < Finite 0) d then Nothing
else Map.
where gr = buildGraph cs
n  = nextNode gr
m  = mkMatrix n (graph gr)
m' = warshall m
d  = [ m!(i,i) | i <- [0 .. (n-1)] ]
ns = keys (nodeMap gr)
-}

{- compute solution

a solution CANNOT exist if

v < v  for a rigid variable v

v <= v' for rigid variables v,v'

x < v   for a flexible variable x and a rigid variable v

thus, for each flexible x, only one of the following cases is possible

r+n <= x+m <= infty  for a unique rigid r  (meaning r --(m-n)--> x)
x <= r+n             for a unique rigid r  (meaning x --(n)--> r)

we are looking for the least values for flexible variables that solve
the constraints.  Algorithm

while flexible variables and rigid rows left
find a rigid variable row i
for all flexible columns j
if i --n--> j with n<=0 (meaning i+n <= j) then j = i + n

while flexible variables j left
search the row j for entry i
if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

-}

-- | Solve a list of constraints.
--
--   The constraint graph is built and closed under 'warshall'.  If the
--   solvability check fails, the result is 'Nothing'.  Otherwise flexible
--   variables are first assigned from the rows of rigid nodes (loop1);
--   the remaining ones are assigned from entries in their own row or
--   default to infinity (loop2), which may still fail with 'Nothing'.
solve :: Constraints -> Maybe Solution
solve cs = -- trace (show cs) \$
-- trace (show lm0) \$
-- trace (show lm) \$ -- trace (show d) \$
let solution = if solvable then loop1 flexs rigids emptySolution
else Nothing
in -- trace (show solution) \$
solution
where -- compute the graph and its transitive closure m
gr  = buildGraph cs
n   = nextNode gr            -- number of nodes
m0  = mkMatrix n (graph gr)
m   = warshall m0

-- tracing only: build output version of transitive graph
legend i = fromJust \$ Map.lookup i (intMap gr) -- trace only
lm0 = LegendMatrix m0 legend legend            -- trace only
lm  = LegendMatrix m legend legend             -- trace only

-- compute the sets of flexible and rigid node numbers
ns  = Map.keys (nodeMap gr)
-- a set of flexible variables
flexs  = foldl (\ l k -> case k of (Flex i) -> i : l
(Rigid _) -> l) [] ns
-- a set of rigid variables
rigids = foldl (\ l k -> case k of (Flex _) -> l
(Rigid i) -> i : l) [] ns

-- rigid matrix indices
rInds = foldl (\ l r -> let Just i = Map.lookup (Rigid r) (nodeMap gr)
in i : l) [] rigids

-- check whether there is a solution
-- d   = [ m!(i,i) | i <- [0 .. (n-1)] ]  -- diagonal
-- a rigid variable might not be less than it self, so no -.. on the
-- rigid part of the diagonal
solvable = all (\ x -> x >= Finite 0) [ m!(i,i) | i <- rInds ] &&
-- a rigid variable might not be bounded below by infinity or
-- bounded above by a constant
-- it might not be related to another rigid variable
all (\ (r,  r') -> r == r' ||
let Just row = (Map.lookup (Rigid r)  (nodeMap gr))
Just col = (Map.lookup (Rigid r') (nodeMap gr))
edge = m!(row,col)
in  isBelow r edge r' )
[ (r,r') | r <- rigids, r' <- rigids ]
&&
-- a flexible variable might not be strictly below a rigid variable
all (\ (x, v) ->
let Just row = (Map.lookup (Flex x)  (nodeMap gr))
Just col = (Map.lookup (Rigid (RVar v)) (nodeMap gr))
edge = m!(row,col)
in  edge >= Finite 0)
[ (x,v) | x <- flexs, (RVar v) <- rigids ]

-- inScope x r: may the flexible variable x be instantiated to r?
-- constants are always in scope; rigid variables are looked up in the
-- scope of x.
-- NOTE: the lookup below is partial; it fails for a flexible variable
-- that has no entry in flexScope (entries are only made by addFlex).
inScope :: FlexId -> Rigid -> Bool
inScope x (RConst _) = True
inScope x (RVar v)   = scope v
where Just scope = Map.lookup x (flexScope gr)

{- loop1

while flexible variables and rigid rows left
find a rigid variable row i
for all flexible columns j
if i --n--> j with n<=0 (meaning i + n <= j) then j = i + n

note: the code below accepts every finite weight z for a rigid that is
in scope; the offset used is 0 if z >= 0 and -z otherwise

-}
loop1 :: [FlexId] -> [Rigid] -> Solution -> Maybe Solution
loop1 [] rgds subst = Just subst
loop1 flxs [] subst = loop2 flxs subst
loop1 flxs (r:rgds) subst =
let row = fromJust \$ Map.lookup (Rigid r) (nodeMap gr)
(flxs',subst') =
foldl (\ (flx,sub) f ->
let col = fromJust \$ Map.lookup (Flex f) (nodeMap gr)
in  case (inScope f r, m!(row,col)) of
--                                Finite z | z <= 0 ->
(True, Finite z) ->
let trunc z | z >= 0 = 0
| otherwise = -z
in (flx, extendSolution sub f (sizeRigid r (trunc z)))
_ -> (f : flx, sub)
) ([], subst) flxs
in loop1 flxs' rgds subst'

{- loop2

while flexible variables j left
search the row j for entry i
if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

a flexible variable without any finite entry to a non-infinite rigid
defaults to infinity; a finite entry that cannot be used (rigid out of
scope, or negative weight) makes the whole solution fail

-}
loop2 :: [FlexId] -> Solution -> Maybe Solution
loop2 [] subst = Just subst
loop2 (f:flxs) subst = loop3 0 subst
where row = fromJust \$ Map.lookup (Flex f) (nodeMap gr)
loop3 col subst | col >= n =
-- default to infinity
loop2 flxs (extendSolution subst f (SizeConst Infinite))
loop3 col subst =
case Map.lookup col (intMap gr) of
Just (Rigid r) | not (infinite r) ->
case (inScope f r, m!(row,col)) of
(True, Finite z) | z >= 0 ->
loop2 flxs (extendSolution subst f (sizeRigid r z))
(_, Infinite) -> loop3 (col+1) subst
_ -> -- trace ("unusable rigid: " ++ show r ++ " for flex " ++ show f)
Nothing  -- NOT: loop3 (col+1) subst
_ -> loop3 (col+1) subst

-- Testing ----------------------------------------------------------------

-- | @genGraph density edge nodes@ generates a random graph over @nodes@
--   with edge labels drawn from @edge@.
--
--   For every pair of nodes, edges are added repeatedly, each time with
--   probability @density@ (rounded to whole percent), so parallel edges
--   are possible.  A node without outgoing edges is randomly either
--   entered with an empty edge list or left out of the map.
genGraph :: Ord node => Float -> Gen edge -> [node] -> Gen (AdjList node edge)
genGraph density edge nodes =
    Map.fromList . concat <$> mapM neighbours nodes
  where
    -- the density in percent, used as a weight for 'frequency'
    k = round (100 * density)

    -- the map entries (none or one) for source node n
    neighbours n = do
      ns <- concat <$> mapM neighbour nodes
      if null ns
        then elements [[(n, [])], []]
        else return [(n, ns)]

    -- a random number of parallel edges to target node n
    neighbour n = frequency
      [ (k,       do e  <- edge
                     ns <- neighbour n
                     return ((n, e) : ns))
      , (100 - k, return [])
      ]

-- | Edge labels used for testing: distances given as natural numbers.
--   The numeric instances are obtained by newtype deriving.
newtype Distance = Dist Nat
deriving (Eq, Ord, Num, Integral, Show, Enum, Real)

-- | Distances form a semiring with minimum as 'oplus' and addition as
--   'otimes'; both are taken from the derived 'Ord' and 'Num' instances.
instance SemiRing Distance where
  oplus  = min
  otimes = (+)

-- | @genGraph_ n@ generates a graph with density 0.2 over the nodes
--   @0 .. n-1@, labelled with random distances.
genGraph_ :: Nat -> Gen (AdjList Nat Distance)
genGraph_ n = genGraph 0.2 distance nodeIds
  where
    distance = fmap Dist natural
    nodeIds  = [0 .. n - 1]

-- | @lookupEdge i j g@ returns the label of the first edge from @i@ to
--   @j@ listed in @g@, if there is one.
lookupEdge :: Ord n => n -> n -> AdjList n e -> Maybe e
lookupEdge i j g = Map.lookup i g >>= lookup j

-- | All edges of a graph as (source, target, label) triples, ordered by
--   source node and then by position in the source's edge list.
edges :: Ord n => AdjList n e -> [(n,n,e)]
edges g = [ (i, j, e) | (i, ns) <- Map.toList g, (j, e) <- ns ]

-- | Check that no edges get longer when completing a graph: every edge of
--   the original graph must be present in the completed graph, with a
--   label that is not greater.
prop_smaller n' =
  forAll (genGraph_ n) $ \ g ->
    let completed = warshallG g
        -- the completed graph has an edge i -> j that is at most e
        notLonger (i, j, e) = maybe False (<= e) (lookupEdge i j completed)
    in  all notLonger (edges g)
  where
    n = abs (div n' 2)

-- | @newEdge i j e g@ adds an edge from @i@ to @j@ labelled @e@ to @g@,
--   in front of the edges already listed for @i@.
newEdge i j e g = Map.insertWith (++) i [(j, e)] g

-- | @genPath n i j g@ adds to @g@ a random path from @i@ to @j@ whose
--   intermediate nodes are drawn from @0 .. n-1@ and whose edges carry
--   random distances.
genPath :: Nat -> Nat -> Nat -> AdjList Nat Distance -> Gen (AdjList Nat Distance)
genPath n i j g = do
    inner <- listOf $ (,) <$> node <*> edge
    final <- edge
    return $ addPath (inner ++ [(j, final)])
  where
    edge = Dist <$> natural
    node = choose (0, n - 1)

    -- every hop starts where the previous one ended, the first one at i
    addPath hops =
      foldr (\ (src, (tgt, v)) -> newEdge src tgt v) g
            (zip (i : map fst hops) hops)

-- | Check that all transitive edges are added: after inserting a random
--   path between two random nodes, the completed graph must contain a
--   direct edge between them.
prop_path n' =
  forAll (genGraph_ n)                        $ \ g ->
  forAll (replicateM 2 $ choose (0, n - 1))   $ \ [i, j] ->
  forAll (genPath n i j g)                    $ \ withPath ->
    let completed = warshallG withPath
    in  isJust (lookupEdge i j completed)
  where
    -- at least one node, so that there is something to choose from
    n = abs (div n' 2) + 1

-- | Rename the nodes of a graph, both as sources (map keys) and as edge
--   targets.
mapNodes :: (Ord node, Ord node') => (node -> node') -> AdjList node edge -> AdjList node' edge
mapNodes f g = Map.map renameTargets (Map.mapKeys f g)
  where
    renameTargets = map (\ (n, e) -> (f n, e))

-- | Check that no edges are added between components: the disjoint union
--   of two graphs (nodes tagged 'Left' and 'Right') is completed, and no
--   resulting edge may connect a 'Left' node with a 'Right' node.
prop_disjoint n' =
  forAll (replicateM 2 $ genGraph_ n) $ \ [g1, g2] ->
    let both      = Map.union (mapNodes Left g1) (mapNodes Right g2)
        completed = warshallG both
    in  and [ sameSide src tgt
            | (src, es) <- Map.assocs completed
            , (tgt, _)  <- es
            ]
  where
    n = abs (div n' 3)

    -- do both nodes stem from the same original graph?
    sameSide (Left  _) (Left  _) = True
    sameSide (Right _) (Right _) = True
    sameSide _         _         = False

-- | Check that completing a graph is idempotent: completing an already
--   completed graph yields the same edges (compared as sorted lists).
prop_stable n' =
  forAll (genGraph_ n) $ \ g ->
    let completed = warshallG g
    in  sortedEdges completed == sortedEdges (warshallG completed)
  where
    n = abs (div n' 2)
    sortedEdges gr = sort (edges gr)

-- | Run the properties listed below with 'runTests', under the name of
--   this module.
tests :: IO Bool
tests = runTests "Agda.Utils.Warshall"
[ quickCheck' prop_smaller
, quickCheck' prop_path
, quickCheck' prop_disjoint
, quickCheck' prop_stable
]
```