```{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{- | Construct a graph from constraints
@
x + n <= y   becomes   x ---(-n)---> y
x <= n + y   becomes   x ---(+n)---> y
@
the default edge (= no edge) is labelled with infinity.

Building the graph involves keeping track of the node names.
We do this in a finite map, assigning consecutive numbers to nodes.
-}
module Agda.Utils.Warshall where

import Data.Maybe
import Data.Array
import qualified Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map

import Agda.Syntax.Common (Nat)
import Agda.Utils.SemiRing

-- | A two-dimensional array indexed by @(row, column)@ pairs.
type Matrix a = Array (Int,Int) a

-- | Closure of a square matrix over a semiring (Warshall's algorithm).
--
--   For each pivot @k@, from the first index to the last, every entry is
--   replaced by @a!(i,j) `oplus` (a!(i,k) `otimes` a!(k,j))@, where @a@ is
--   the matrix produced by the previous pivot step.
--
--   The matrix is assumed to be square (lower row bound equals lower column
--   bound, upper row bound equals upper column bound); this is not checked.
warshall :: SemiRing a => Matrix a -> Matrix a
warshall a0 = List.foldl' pivot a0 [r .. r']
  where
    bnds@((r, c), (r', c')) = bounds a0
    -- One relaxation round through pivot node k.
    pivot a k = array bnds
      [ ((i, j), (a ! (i, j)) `oplus` ((a ! (i, k)) `otimes` (a ! (k, j))))
      | i <- [r .. r'], j <- [c .. c'] ]

-- | A graph as a map from a source node to its outgoing edges,
--   each given as a @(target, label)@ pair.
type AdjList node edge = Map node [(node, edge)]

-- | Warshall's algorithm on a graph represented as an adjacency list.
--
--   Every node (each key of @g@, and each edge target) is numbered
--   consecutively from 0, in order of first appearance: keys first, then
--   targets that are not keys.  The graph is turned into a matrix of
--   @Maybe edge@ (@Nothing@ = no edge; parallel edges @i -> j@ are combined
--   with 'oplus'), closed with 'warshall', and converted back.  In the
--   result every node has an entry, whose edges are listed in node-number
--   order.
warshallG :: (SemiRing edge, Ord node) => AdjList node edge -> AdjList node edge
warshallG g = fromMatrix $ warshall m
  where
    -- Node labels paired with their matrix index, without duplicates.
    nodes = zip (nubOrd $ Map.keys g ++ map fst (concat $ Map.elems g)) [0..]
    len   = length nodes
    b     = ((0,0), (len - 1, len - 1))

    -- Reverse lookup from matrix index to node label
    -- (total and O(1), unlike indexing the list 'nodes' with '!!').
    names = listArray (0, len - 1) (map fst nodes)

    -- Combined label of all edges i -> j, or Nothing if there is none.
    edge i j = do
      es <- Map.lookup i g
      foldr oplus Nothing [ Just v | (j', v) <- es, j == j' ]

    m = array b [ ((ni, nj), edge i j) | (i, ni) <- nodes, (j, nj) <- nodes ]

    fromMatrix matrix = Map.fromList $ do
      (i, ni) <- nodes
      let es = [ (names ! nj, e)
               | nj <- [0 .. len - 1]
               , Just e <- [matrix ! (ni, nj)]
               ]
      return (i, es)

    -- Order-preserving duplicate removal (keeps first occurrences, like
    -- 'List.nub'), in O(n log n) via the 'Ord' instance.
    nubOrd xs = go Map.empty xs
      where
        go _    []       = []
        go seen (y : ys)
          | Map.member y seen = go seen ys
          | otherwise         = y : go (Map.insert y () seen) ys

-- | Edge weight in the graph, forming a semi ring.
--   'Infinite' is the weight of an absent edge; it compares greater than
--   every 'Finite' weight.
data Weight = Finite Int | Infinite
  deriving (Eq)

-- | @inc w n@ adds @n@ to a finite weight; 'Infinite' absorbs the addition.
inc :: Weight -> Int -> Weight
inc w n = case w of
  Finite k -> Finite (k + n)
  Infinite -> Infinite

-- | Finite weights print as their number, 'Infinite' as a single dot.
instance Show Weight where
  show w = case w of
    Infinite -> "."
    Finite i -> show i

-- | Finite weights are ordered numerically; 'Infinite' is the top element.
instance Ord Weight where
  x <= y = case y of
    Infinite -> True          -- everything is below infinity
    Finite b -> case x of
      Infinite -> False       -- infinity is below no finite weight
      Finite a -> a <= b

-- | The (min, +) semiring on weights: 'oplus' picks the smaller weight,
--   'otimes' adds weights, with 'Infinite' absorbing addition.
instance SemiRing Weight where
  ozero = Infinite
  oone  = Finite 0

  oplus = min

  otimes x y = case (x, y) of
    (Finite a, Finite b) -> Finite (a + b)
    _                    -> Infinite

-- constraints ---------------------------------------------------

-- | Nodes of the graph are either
--
--   * flexible variables (with identifiers drawn from @Int@),
--
--   * rigid variables (also identified by @Int@s), or
--
--   * constants (like 0, infinity, or anything between).
--
--   Rigid variables and constants are both represented by the 'Rigid' type.
data Node = Rigid Rigid
          | Flex  FlexId
  deriving (Eq, Ord)

-- | A rigid node: a constant weight ('RConst', finite or infinity)
--   or a rigid variable ('RVar').
data Rigid = RConst Weight
           | RVar RigidId
  deriving (Eq, Ord, Show)

-- | Number of a node in the graph (also its row/column index in the matrix).
type NodeId  = Int
-- | Identifier of a rigid variable.
type RigidId = Int
-- | Identifier of a flexible variable.
type FlexId  = Int
-- | Which rigid variables a flex may be instantiated to.
type Scope   = RigidId -> Bool

-- | Flexible variables print as @?i@, rigid variables as @vi@,
--   the infinite constant as @#@ and finite constants as their number.
instance Show Node where
  show node = case node of
    Flex i                    -> '?' : show i
    Rigid (RVar i)            -> 'v' : show i
    Rigid (RConst Infinite)   -> "#"
    Rigid (RConst (Finite k)) -> show k

-- | Is this rigid node the constant infinity?
infinite :: Rigid -> Bool
infinite r = r == RConst Infinite

-- | @isBelow r w r'@
--   checks, if @r@ and @r'@ are connected by @w@ (meaning @w@ not infinite),
--   whether @r + w <= r'@.
--   An infinite @w@ or an infinite @r'@ always yields 'True'.  Otherwise
--   the result is 'True' only for two finite constants satisfying
--   @r + w <= r'@; in particular, rigid variables are not related.
--   Precondition: not the same rigid variable.
isBelow :: Rigid -> Weight -> Rigid -> Bool
isBelow r w r' = case w of
  Infinite -> True
  Finite n -> case r' of
    RConst Infinite -> True
    _ -> case (r, r') of
      (RConst (Finite i), RConst (Finite j)) -> i + n <= j
      _                                      -> False -- rigid variables are not related

-- | A constraint is an edge in the graph.
data Constraint
  = NewFlex FlexId Scope
    -- ^ Introduces the flexible variable with the given id, together with
    --   the 'Scope' of rigid variables it may be instantiated to.
  | Arc Node Int Node
    -- ^ For @Arc v1 k v2@  at least one of @v1@ or @v2@ is a @MetaV@ (Flex),
    --                      the other a @MetaV@ or a @Var@ (Rigid).
    --   NOTE(review): the 2012-09-19 note in 'solve' mentions constraints
    --   between rigid variables, so this restriction may be outdated.
    --   If @k <= 0@ this means  @suc^(-k) v1 <= v2@
    --   otherwise               @v1 <= suc^k v2@.

-- | Arcs print as inequalities, moving the offset @k@ to the side where it
--   is positive; new flexible variables print as @SizeMeta(?i)@.
instance Show Constraint where
  show (NewFlex i _) = "SizeMeta(?" ++ show i ++ ")"
  show (Arc v1 k v2) = case compare k 0 of
      EQ -> lhs ++ "<=" ++ rhs
      LT -> lhs ++ "+" ++ show (negate k) ++ "<=" ++ rhs
      GT -> lhs ++ "<=" ++ rhs ++ "+" ++ show k
    where
      lhs = show v1
      rhs = show v2

-- | A list of constraints; 'buildGraph' adds them to the graph in order.
type Constraints = [Constraint]

-- | The empty constraint list.
emptyConstraints :: Constraints
emptyConstraints = mempty

-- graph (matrix) ------------------------------------------------

-- | The constraint graph under construction.
data Graph = Graph
  { flexScope :: Map FlexId Scope           -- ^ Scope for each flexible var.
  , nodeMap   :: Map Node NodeId            -- ^ Node labels to node numbers.
  , intMap    :: Map NodeId Node            -- ^ Node numbers to node labels.
  , nextNode  :: NodeId                     -- ^ Number of nodes @n@
                                            --   (= the next free node number).
  , graph     :: NodeId -> NodeId -> Weight -- ^ The edges (restrict to @[0..n[@);
                                            --   'Infinite' means no edge.
  }

-- | The empty graph: no nodes, and every edge absent (infinity weight).
initGraph :: Graph
initGraph = Graph
  { flexScope = Map.empty
  , nodeMap   = Map.empty
  , intMap    = Map.empty
  , nextNode  = 0
  , graph     = \ _ _ -> Infinite
  }

-- | The Graph Monad, for constructing a graph iteratively.
--   NOTE(review): 'State' is not among the imports visible here; presumably
--   it comes from "Control.Monad.State" (mtl) — confirm.
type GM = State Graph

-- | Add a size meta node: record the scope of the flexible variable @x@.
--   This only updates 'flexScope'; it does not allocate a node number.
addFlex :: FlexId -> Scope -> GM ()
addFlex x scope =
  modify $ \ st -> st { flexScope = Map.insert x scope (flexScope st) }

-- | Lookup identifier of a node.
--   If not present, it is added first: it receives the next free number
--   ('nextNode'), and both the label-to-number and number-to-label maps
--   are extended.
addNode :: Node -> GM Int
addNode n = do
  st <- get
  case Map.lookup n (nodeMap st) of
    Just i  -> return i
    Nothing -> do
      let i = nextNode st
      put $ st { nodeMap  = Map.insert n i (nodeMap st)
               , intMap   = Map.insert i n (intMap st)
               , nextNode = i + 1
               }
      return i

-- | @addEdge n1 k n2@
--   improves the weight of edge @n1->n2@ to be at most @k@
--   (the new weight is the 'oplus', i.e. minimum, of @k@ and the old one).
--   Also adds nodes if not yet present.
addEdge :: Node -> Int -> Node -> GM ()
addEdge n1 k n2 = do
  i1 <- addNode n1
  i2 <- addNode n2
  -- Read the state only after the nodes have been added, so those
  -- additions are not overwritten by the 'put' below.
  st <- get
  let graph' x y
        | (x, y) == (i1, i2) = Finite k `oplus` graph st x y
        | otherwise          = graph st x y
  put $ st { graph = graph' }

-- | Add a single constraint to the graph under construction:
--   'NewFlex' records the scope of a flexible variable, and
--   @'Arc' v1 k v2@ becomes the edge @v1 --k--> v2@.
--   NOTE(review): a flex that occurs in no 'Arc' gets no node number and
--   therefore no entry in the result of 'solve'; confirm whether 'NewFlex'
--   should also add the node.
addConstraint :: Constraint -> GM ()
addConstraint (NewFlex x scope) = addFlex x scope
addConstraint (Arc n1 k n2)     = addEdge n1 k n2

-- | Build the constraint graph by adding the constraints one after the
--   other, starting from 'initGraph'.
buildGraph :: Constraints -> Graph
buildGraph cs = execState addAll initGraph
  where
    addAll = mapM_ addConstraint cs

-- | Tabulate the edge function @g@ on @[0..n-1] x [0..n-1]@
--   as an @n x n@ matrix.
mkMatrix :: Int -> (Int -> Int -> Weight) -> Matrix Weight
mkMatrix n g = listArray bnds (map (uncurry g) (range bnds))
  where
    bnds = ((0, 0), (n - 1, n - 1))

-- displaying matrices with row and column labels --------------------

-- | A matrix with row descriptions in @b@ and column descriptions in @c@.
data LegendMatrix a b c = LegendMatrix
  { matrix   :: Matrix a
  , rowdescr :: Int -> b
  , coldescr :: Int -> c
  }

-- | Tab-separated table: first the column descriptions, then one line per
--   row, starting with the row description followed by the row's cells.
instance (Show a, Show b, Show c) => Show (LegendMatrix a b c) where
  show lm = header ++ concatMap row [r .. r']
    where
      mat = matrix lm
      ((r, c), (r', c')) = bounds mat
      cols = [c .. c']
      -- Every cell (including header cells) is preceded by a tab.
      cell :: Show x => x -> String
      cell x = '\t' : show x
      header = concatMap (cell . coldescr lm) cols
      row i  = '\n' : show (rowdescr lm i)
                   ++ concatMap (\ j -> cell (mat ! (i, j))) cols

-- solving the constraints -------------------------------------------

-- | A solution assigns to each flexible variable a size expression
--   which is either a constant or a @v + n@ for a rigid variable @v@.
--   Keys are the ids of the flexible variables.
type Solution = Map Int SizeExpr

-- | The solution that assigns nothing yet.
emptySolution :: Solution
emptySolution = Map.fromList []

-- | @extendSolution phi x e@ assigns @e@ to the flexible variable @x@,
--   replacing any earlier assignment of @x@.
extendSolution :: Solution -> Int -> SizeExpr -> Solution
extendSolution phi x e = Map.insert x e phi

-- | Size expression assigned to a flexible variable.
data SizeExpr = SizeVar RigidId Int   -- ^ e.g. x + 5
              | SizeConst Weight      -- ^ a number or infinity

-- | Variables print like rigid nodes, with a @+k@ suffix for non-zero
--   offsets; constants print like weights.
instance Show SizeExpr where
  show e = case e of
    SizeConst w -> show w
    SizeVar n k
      | k == 0    -> var
      | otherwise -> var ++ "+" ++ show k
      where var = show (Rigid (RVar n))

-- | @sizeRigid r n@ returns the size expression corresponding to @r + n@:
--   a shifted variable for a rigid variable, a shifted (or still infinite)
--   constant for a constant.
sizeRigid :: Rigid -> Int -> SizeExpr
sizeRigid r n = case r of
  RVar i   -> SizeVar i n
  RConst k -> SizeConst (inc k n)

{-
apply :: SizeExpr -> Solution -> SizeExpr
apply e@(SizeExpr (Rigid _) _) phi = e
apply e@(SizeExpr (Flex  x) i) phi = case Map.lookup x phi of
Nothing -> e
Just (SizeExpr v j) -> SizeExpr v (i + j)

after :: Solution -> Solution -> Solution
after psi phi = Map.map (\ e -> e `apply` phi) psi
-}

{- compute solution

a solution CANNOT exist if

v < v  for a rigid variable v

-- Andreas, 2012-09-19 OUTDATED are:

-- v <= v' for rigid variables v,v'

-- x < v   for a flexible variable x and a rigid variable v

thus, for each flexible x, only one of the following cases is possible

r+n <= x+m <= infty  for a unique rigid r  (meaning r --(m-n)--> x)
x <= r+n             for a unique rigid r  (meaning x --(n)--> r)

we are looking for the least values for flexible variables that solve
the constraints.  Algorithm

while flexible variables and rigid rows left
find a rigid variable row i
for all flexible columns j
if i --n--> j with n<=0 (meaning i-n <= j) then j = i - n

while flexible variables j left
search the row j for entry i
if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

-}

-- | Compute an assignment of size expressions to the flexible variables,
--   or 'Nothing' if the constraints are detected to be unsolvable.
--
--   The constraint graph is closed with 'warshall'.  The constraints are
--   rejected if some rigid node lies on a negative cycle (negative diagonal
--   entry).  Then 'loop1' attaches flexes to rigid lower bounds, and 'loop2'
--   handles the remaining flexes via rigid upper bounds, defaulting to
--   infinity.
solve :: Constraints -> Maybe Solution
solve cs
  -- For debugging, wrap the result in e.g.
  -- trace (show cs), trace (show lm0) or trace (show lm).
  | solvable  = loop1 flexs rigids emptySolution
  | otherwise = Nothing
  where
    -- compute the graph and its transitive closure m
    gr  = buildGraph cs
    n   = nextNode gr            -- number of nodes
    m0  = mkMatrix n (graph gr)
    m   = warshall m0

    -- tracing only: build output version of transitive graph
    legend i = fromJust $ Map.lookup i (intMap gr) -- trace only
    lm0 = LegendMatrix m0 legend legend            -- trace only
    lm  = LegendMatrix m legend legend             -- trace only

    -- node labels, in ascending order
    ns = Map.keys (nodeMap gr)
    -- The flexible and the rigid node labels, both in DESCENDING order
    -- (reverse of 'ns').  This order is significant: 'loop1' tries the
    -- rigid rows in this order, so it decides which rigid bound a flex
    -- gets attached to.
    flexs  = reverse [ i | Flex  i <- ns ]
    rigids = reverse [ r | Rigid r <- ns ]

    -- matrix indices of the rigid nodes
    rInds = [ i | (Rigid _, i) <- Map.toList (nodeMap gr) ]

    -- check whether there is a solution
    -- d   = [ m!(i,i) | i <- [0 .. (n-1)] ]  -- diagonal
    -- a rigid variable might not be less than itself, so no negative entry
    -- on the rigid part of the diagonal
    solvable = all (>= Finite 0) [ m ! (i, i) | i <- rInds ]

    {-  Andreas, 2012-09-19
        We now can have constraints between rigid variables, like i < j.
        Thus we skip the following two test.  However, a solution must be
        checked for consistency with the constraints on rigid vars.

      -- a rigid variable might not be bounded below by infinity or
      -- bounded above by a constant
      -- it might not be related to another rigid variable
      all (\ (r,  r') -> r == r' ||
           let Just row = (Map.lookup (Rigid r)  (nodeMap gr))
               Just col = (Map.lookup (Rigid r') (nodeMap gr))
               edge = m!(row,col)
           in  isBelow r edge r' )
        [ (r,r') | r <- rigids, r' <- rigids ]
      &&
      -- a flexible variable might not be strictly below a rigid variable
      all (\ (x, v) ->
           let Just row = (Map.lookup (Flex x)  (nodeMap gr))
               Just col = (Map.lookup (Rigid (RVar v)) (nodeMap gr))
               edge = m!(row,col)
           in  edge >= Finite 0)
        [ (x,v) | x <- flexs, (RVar v) <- rigids ]
    -}

    -- | May flex @x@ be instantiated to rigid @r@?  Constants always qualify,
    --   variables are checked against the scope recorded for @x@.
    inScope :: FlexId -> Rigid -> Bool
    inScope _ (RConst _) = True
    inScope x (RVar v)   = scope v
      where
        scope = fromMaybe
          (error $ "Agda.Utils.Warshall.solve: no scope recorded for flexible variable ?"
                   ++ show x)
          (Map.lookup x (flexScope gr))

    {- loop1

       while flexible variables and rigid rows left
         find a rigid variable row i
           for all flexible columns j
             if i --n--> j with n<=0 (meaning i - n <= j) then j = i - n
             (for n > 0 the offset is truncated to 0, i.e. j = i)
    -}
    loop1 :: [FlexId] -> [Rigid] -> Solution -> Maybe Solution
    loop1 []   _          subst = Just subst
    loop1 flxs []         subst = loop2 flxs subst
    loop1 flxs (r : rgds) subst = loop1 flxs' rgds subst'
      where
        row = fromJust $ Map.lookup (Rigid r) (nodeMap gr)
        -- Flexes not reachable from r (or out of scope) stay in flxs'.
        (flxs', subst') = List.foldl' step ([], subst) flxs
        step (flx, sub) f =
          let col = fromJust $ Map.lookup (Flex f) (nodeMap gr)
          in  case (inScope f r, m ! (row, col)) of
                (True, Finite z) ->
                  (flx, extendSolution sub f (sizeRigid r (trunc z)))
                _ -> (f : flx, sub)
        trunc z | z >= 0    = 0
                | otherwise = -z

    {- loop2

       while flexible variables j left
         search the row j for entry i
           if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n
    -}
    loop2 :: [FlexId] -> Solution -> Maybe Solution
    loop2 []         subst = Just subst
    loop2 (f : flxs) subst = loop3 0 subst
      where
        row = fromJust $ Map.lookup (Flex f) (nodeMap gr)
        loop3 col sub
          | col >= n  =
              -- default to infinity
              loop2 flxs (extendSolution sub f (SizeConst Infinite))
          | otherwise =
              case Map.lookup col (intMap gr) of
                Just (Rigid r) | not (infinite r) ->
                  case (inScope f r, m ! (row, col)) of
                    (True, Finite z) | z >= 0 ->
                      loop2 flxs (extendSolution sub f (sizeRigid r z))
                    (_, Infinite) -> loop3 (col + 1) sub
                    _ -> -- trace ("unusable rigid: " ++ show r ++ " for flex " ++ show f)
                         Nothing  -- NOT: loop3 (col+1) sub
                _ -> loop3 (col + 1) sub
```