{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{- | Construct a graph from constraints
@
   x + n <= y   becomes   x ---(-n)---> y
   x <= n + y   becomes   x ---(+n)---> y
@
the default edge (= no edge) is labelled with infinity.

Building the graph involves keeping track of the node names.
We do this in a finite map, assigning consecutive numbers to nodes.
-}
module Agda.Utils.Warshall where

import Control.Applicative
import Control.Monad.State

import Data.Maybe
import Data.Array
import Data.List
import Data.Map (Map)
import qualified Data.Map as Map

import Test.QuickCheck

import Agda.Syntax.Common (Nat)
import Agda.Utils.TestHelpers
import Agda.Utils.SemiRing

-- | A matrix as a two-dimensional array indexed by @(row, column)@.
type Matrix a = Array (Int,Int) a

-- | Warshall-style closure of a matrix over a semiring.
--   For each pivot @k@ in turn, every entry @(i,j)@ is combined ('oplus')
--   with the path through @k@, @a!(i,k) `otimes` a!(k,j)@.
--   Assumes a square matrix (row bounds equal column bounds).
warshall :: SemiRing a => Matrix a -> Matrix a
warshall a0 = foldl' relaxVia a0 [rowLo .. rowHi]
  where
    bnds@((rowLo, colLo), (rowHi, colHi)) = bounds a0

    -- One pass: allow paths that go through pivot k.
    relaxVia a k = array bnds
      [ ((i, j), (a ! (i, j)) `oplus` ((a ! (i, k)) `otimes` (a ! (k, j))))
      | i <- [rowLo .. rowHi]
      , j <- [colLo .. colHi]
      ]

-- | A graph as an adjacency list: each node is mapped to its outgoing
--   edges, given as @(target, label)@ pairs.
type AdjList node edge = Map node [(node, edge)]

-- | Warshall's algorithm on a graph represented as an adjacency list.
--
--   Every node occurring in the graph (as a key or as an edge target) gets
--   a consecutive matrix index, in order of first occurrence.  Parallel
--   edges between the same two nodes are combined with 'oplus'.  After
--   closing the matrix with 'warshall', it is converted back into an
--   adjacency list with an entry (possibly empty) for every node.
warshallG :: (SemiRing edge, Ord node) => AdjList node edge -> AdjList node edge
warshallG g = fromMatrix $ warshall mat
  where
    -- All nodes, without duplicates, paired with their matrix index.
    nodes = zip (ordNub $ Map.keys g ++ map fst (concat $ Map.elems g))
                [0..]
    len   = length nodes
    b     = ((0,0), (len - 1,len - 1))

    -- Index-to-node table for O(1) translation back from matrix indices
    -- (indexing the 'nodes' list would cost O(len) per entry).
    names = listArray (0, len - 1) (map fst nodes)

    -- Combined label of all edges from i to j; Nothing if there are none.
    edge i j = do
      es <- Map.lookup i g
      foldr oplus Nothing [ Just v | (j', v) <- es, j == j' ]

    mat = array b [ ((r, c), edge i j) | (i, r) <- nodes, (j, c) <- nodes ]

    fromMatrix closed = Map.fromList $ do
      (i, r) <- nodes
      let es = [ (names ! c, e)
               | c <- [0..len - 1]
               , Just e <- [closed ! (r, c)]
               ]
      return (i, es)

    -- Order-preserving duplicate removal in O(n log n).
    ordNub = go Map.empty
      where
        go _    []       = []
        go seen (x : xs)
          | Map.member x seen = go seen xs
          | otherwise         = x : go (Map.insert x () seen) xs

-- | Edge weight in the graph, forming a semi ring.
--   @Finite n@ is an edge of weight @n@ (which may be negative);
--   'Infinite' stands for the absence of an edge.
data Weight = Finite Int | Infinite
              deriving (Eq)

-- | @inc w n@ adds @n@ to a finite weight; 'Infinite' stays 'Infinite'.
inc :: Weight -> Int -> Weight
inc w n = case w of
  Finite k -> Finite (k + n)
  Infinite -> Infinite

-- | Finite weights print as their number, 'Infinite' as a dot.
instance Show Weight where
  show w = case w of
    Finite i -> show i
    Infinite -> "."

-- | 'Infinite' is the greatest weight; finite weights compare as numbers.
instance Ord Weight where
  x <= y = case y of
    Infinite -> True
    Finite b -> case x of
      Infinite -> False
      Finite a -> a <= b

-- | Weights form a (min, +) semiring: 'oplus' keeps the smaller weight,
--   'otimes' adds weights, and 'Infinite' absorbs under 'otimes'.
instance SemiRing Weight where
  oplus = min

  otimes x y = case (x, y) of
    (Finite a, Finite b) -> Finite (a + b)
    _                    -> Infinite

-- constraints ---------------------------------------------------

-- | Nodes of the graph are either
--
--   * flexible variables ('Flex', identified by a 'FlexId'), or
--
--   * rigid nodes ('Rigid'), which are rigid variables ('RVar',
--     identified by a 'RigidId') or constants ('RConst': a finite
--     number or infinity).

data Node = Rigid Rigid
          | Flex  FlexId
            deriving (Eq, Ord)

-- | Rigid nodes: constants (a 'Weight', possibly infinite) or rigid
--   variables.
data Rigid = RConst Weight
           | RVar RigidId
             deriving (Eq, Ord, Show)

type NodeId  = Int  -- consecutive node number, used as matrix index
type RigidId = Int
type FlexId  = Int
type Scope   = RigidId -> Bool
-- ^ Which rigid variables a flex may be instantiated to.

-- | Flexible variables print as @?i@, rigid variables as @vi@,
--   infinity as @#@ and finite constants as their number.
instance Show Node where
  show node = case node of
    Flex i                    -> "?" ++ show i
    Rigid (RVar i)            -> "v" ++ show i
    Rigid (RConst Infinite)   -> "#"
    Rigid (RConst (Finite k)) -> show k

-- | Is this rigid node the constant infinity?
infinite :: Rigid -> Bool
infinite r = r == RConst Infinite

-- | @isBelow r w r'@
--   checks, if @r@ and @r'@ are connected by @w@ (meaning @w@ not infinite),
--   whether @r + w <= r'@.  An infinite @w@ or an infinite @r'@ always
--   succeeds; two distinct rigid variables are never related.
--   Precondition: not the same rigid variable.
isBelow :: Rigid -> Weight -> Rigid -> Bool
isBelow r w r' = case (r, w, r') of
  (_, Infinite, _)                                 -> True
  (_, _, RConst Infinite)                          -> True
  (RConst (Finite i), Finite n, RConst (Finite j)) -> i + n <= j
  _                                                -> False -- rigid variables are not related

-- | A constraint either introduces a flexible variable or is an edge
--   in the graph.
data Constraint
  = NewFlex FlexId Scope
    -- ^ Introduce a flexible variable together with its 'Scope'.
  | Arc Node Int Node
    -- ^ For @Arc v1 k v2@  at least one of @v1@ or @v2@ is a @MetaV@ (Flex),
    --                      the other a @MetaV@ or a @Var@ (Rigid).
    --   If @k <= 0@ this means  @suc^(-k) v1 <= v2@
    --   otherwise               @v1 <= suc^k v2@.
    --   NOTE(review): the "at least one Flex" precondition is not checked,
    --   and the 2012-09-19 note in 'solve' says rigid-rigid constraints are
    --   now allowed, so it may be outdated.

-- | Print arcs as inequalities, putting the offset on the side where
--   it is positive.
instance Show Constraint where
  show c = case c of
    NewFlex i _ -> "SizeMeta(?" ++ show i ++ ")"
    Arc v1 k v2 -> case compare k 0 of
      EQ -> show v1 ++ "<=" ++ show v2
      LT -> show v1 ++ "+" ++ show (-k) ++ "<=" ++ show v2
      GT -> show v1 ++ "<=" ++ show v2 ++ "+" ++ show k

-- | A list of constraints; 'buildGraph' adds them in list order.
type Constraints = [Constraint]

-- | No constraints.
emptyConstraints :: Constraints
emptyConstraints = []

-- graph (matrix) ------------------------------------------------

-- | The constraint graph under construction.
--   Nodes are numbered consecutively from 0 ('nextNode' is the next free
--   number); 'graph' yields 'Infinite' for pairs without an edge
--   (see 'initGraph').
data Graph = Graph
  { flexScope :: Map FlexId Scope           -- ^ Scope for each flexible var.
  , nodeMap   :: Map Node NodeId            -- ^ Node labels to node numbers.
  , intMap    :: Map NodeId Node            -- ^ Node numbers to node labels.
  , nextNode  :: NodeId                     -- ^ Number of nodes @n@.
  , graph     :: NodeId -> NodeId -> Weight -- ^ The edges (restrict to @[0..n[@).
  }

-- | The empty graph: no nodes, and every edge has infinite weight
--   (i.e. there are no edges).
initGraph :: Graph
initGraph = Graph
  { flexScope = Map.empty
  , nodeMap   = Map.empty
  , intMap    = Map.empty
  , nextNode  = 0
  , graph     = \ _ _ -> Infinite
  }

-- | The Graph Monad, for constructing a graph iteratively:
--   a state monad threading the 'Graph' being built.
type GM = State Graph

-- | Add a size meta node: record its scope and make sure the flexible
--   variable has a node in the graph.
addFlex :: FlexId -> Scope -> GM ()
addFlex x scope = recordScope >> addNode (Flex x) >> return ()
  where
    recordScope = modify $ \ st ->
      st { flexScope = Map.insert x scope (flexScope st) }


-- | Lookup identifier of a node.
--   If not present, it is added first and receives the next free number.
addNode :: Node -> GM Int
addNode n = gets (Map.lookup n . nodeMap) >>= maybe register return
  where
    -- Allocate a fresh number for n and record it in both directions.
    register = do
      i <- gets nextNode
      modify $ \ st -> st { nodeMap  = Map.insert n i (nodeMap st)
                          , intMap   = Map.insert i n (intMap st)
                          , nextNode = i + 1
                          }
      return i

-- | @addEdge n1 k n2@
--   improves the weight of edge @n1->n2@ to be at most @k@
--   (the new weight is combined with the old one using 'oplus').
--   Also adds nodes if not yet present.
addEdge :: Node -> Int -> Node -> GM ()
addEdge n1 k n2 = do
  src <- addNode n1
  tgt <- addNode n2
  modify $ \ st ->
    let old = graph st
        new x y
          | x == src && y == tgt = Finite k `oplus` old x y
          | otherwise            = old x y
    in  st { graph = new }

-- | Add a single constraint to the graph being built.
addConstraint :: Constraint -> GM ()
addConstraint c = case c of
  NewFlex x scope -> addFlex x scope
  Arc n1 k n2     -> addEdge n1 k n2

-- | Build the constraint graph by adding the constraints in order,
--   starting from 'initGraph'.
buildGraph :: Constraints -> Graph
buildGraph = flip execState initGraph . mapM_ addConstraint

-- | @mkMatrix n g@ tabulates the edge function @g@ on all index pairs
--   in @[0..n-1]@.
mkMatrix :: Int -> (Int -> Int -> Weight) -> Matrix Weight
mkMatrix n g = array bnds [ (ix, uncurry g ix) | ix <- range bnds ]
  where
    bnds = ((0, 0), (n - 1, n - 1))

-- displaying matrices with row and column labels --------------------

-- | A matrix with row descriptions in @b@ and column descriptions in @c@.
--   'rowdescr' and 'coldescr' map a row or column index to its label;
--   the 'Show' instance prints the labels alongside the entries.
data LegendMatrix a b c = LegendMatrix
  { matrix   :: Matrix a
  , rowdescr :: Int -> b
  , coldescr :: Int -> c
  }

instance (Show a, Show b, Show c) => Show (LegendMatrix a b c) where
  -- A header of tab-separated column descriptions, then one line per row:
  -- the row description followed by the row's tab-separated entries.
  show (LegendMatrix m rd cd) = header ++ concatMap row [r .. r']
    where
      ((r, c), (r', c')) = bounds m
      columns = [c .. c']
      cell x  = "\t" ++ show x
      header  = concatMap (cell . cd) columns
      row i   = "\n" ++ show (rd i) ++ concatMap (\ j -> cell (m ! (i, j))) columns

-- solving the constraints -------------------------------------------

-- | A solution assigns to each flexible variable a size expression
--   which is either a constant or a @v + n@ for a rigid variable @v@.
--   Keys are 'FlexId's.
type Solution = Map Int SizeExpr

-- | The solution that assigns no variable.
emptySolution :: Solution
emptySolution = Map.empty

-- | Record (or overwrite) the value of a flexible variable.
extendSolution :: Solution -> Int -> SizeExpr -> Solution
extendSolution sol flex expr = Map.insert flex expr sol

-- | The value assigned to a flexible variable: a rigid variable plus an
--   offset, or a constant weight.
data SizeExpr = SizeVar RigidId Int   -- ^ e.g. x + 5
              | SizeConst Weight      -- ^ a number or infinity

-- | @SizeVar v k@ prints as the rigid variable followed by @+k@
--   (omitted when @k == 0@); constants print like 'Weight'.
instance Show SizeExpr where
  show e = case e of
    SizeConst w -> show w
    SizeVar v k
      | k == 0    -> show (Rigid (RVar v))
      | otherwise -> show (Rigid (RVar v)) ++ "+" ++ show k

-- | @sizeRigid r n@ returns the size expression corresponding to @r + n@.
--   For a constant, the offset is folded in with 'inc'.
sizeRigid :: Rigid -> Int -> SizeExpr
sizeRigid r n = case r of
  RVar i   -> SizeVar i n
  RConst k -> SizeConst (inc k n)

{-
apply :: SizeExpr -> Solution -> SizeExpr
apply e@(SizeExpr (Rigid _) _) phi = e
apply e@(SizeExpr (Flex  x) i) phi = case Map.lookup x phi of
  Nothing -> e
  Just (SizeExpr v j) -> SizeExpr v (i + j)

after :: Solution -> Solution -> Solution
after psi phi = Map.map (\ e -> e `apply` phi) psi
-}

{- compute solution

a solution CANNOT exist if

  v < v  for a rigid variable v

  -- Andreas, 2012-09-19 OUTDATED are:

  -- v <= v' for rigid variables v,v'

  -- x < v   for a flexible variable x and a rigid variable v


thus, for each flexible x, only one of the following cases is possible

  r+n <= x+m <= infty  for a unique rigid r  (meaning r --(m-n)--> x)
  x <= r+n             for a unique rigid r  (meaning x --(n)--> r)

we are looking for the least values for flexible variables that solve
the constraints.  Algorithm

while flexible variables and rigid rows left
  find a rigid variable row i
    for all flexible columns j
      if i --n--> j with n<=0 (meaning i-n <= j) then j = i - n

while flexible variables j left
  search the row j for entry i
    if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

-}

-- | Solve a list of size constraints.
--
--   The constraint graph is built with 'buildGraph' and closed with
--   'warshall'.  The result is 'Nothing' if the closure has a negative
--   diagonal entry for a rigid node (it would have to be strictly smaller
--   than itself), or if, for a flexible variable left over by 'loop1', the
--   first non-infinite rigid node it has a finite edge to is out of its
--   scope or is reached with negative weight.  Otherwise every flexible
--   variable is assigned a 'SizeExpr': first from edges leaving rigid
--   nodes ('loop1'), then from edges into rigid nodes ('loop2'),
--   defaulting to infinity.
--
--   NOTE(review): as the 2012-09-19 note below says, the solution is not
--   re-checked against constraints between rigid variables here.
solve :: Constraints -> Maybe Solution
solve cs = -- trace (show cs) $
   -- trace (show lm0) $
    -- trace (show lm) $ -- trace (show d) $
     let solution = if solvable then loop1 flexs rigids emptySolution
                    else Nothing
     in -- trace (show solution) $
         solution
   where -- compute the graph and its transitive closure m
         gr  = buildGraph cs
         n   = nextNode gr            -- number of nodes
         m0  = mkMatrix n (graph gr)
         m   = warshall m0

         -- tracing only: build output version of transitive graph
         -- (lm0 and lm are referenced only by the commented-out traces)
         legend i = fromJust $ Map.lookup i (intMap gr) -- trace only
         lm0 = LegendMatrix m0 legend legend            -- trace only
         lm  = LegendMatrix m legend legend             -- trace only

         -- compute the sets of flexible and rigid node numbers
         ns  = Map.keys (nodeMap gr)
         -- a set of flexible variables
         -- (consing in a foldl' reverses the ascending key order)
         flexs  = foldl' (\ l k -> case k of (Flex i) -> i : l
                                             (Rigid _) -> l) [] ns
         -- a set of rigid variables (likewise in descending order)
         rigids = foldl' (\ l k -> case k of (Flex _) -> l
                                             (Rigid i) -> i : l) [] ns

         -- rigid matrix indices
         -- (every element of rigids comes from the keys of nodeMap,
         -- so the lookup below always succeeds)
         rInds = foldl' (\ l r -> let Just i = Map.lookup (Rigid r) (nodeMap gr)
                                  in i : l) [] rigids

         -- check whether there is a solution
         -- d   = [ m!(i,i) | i <- [0 .. (n-1)] ]  -- diagonal
-- a rigid variable cannot be strictly less than itself, so no negative
-- entries (-..) are allowed on the rigid part of the diagonal
         solvable = all (\ x -> x >= Finite 0) [ m!(i,i) | i <- rInds ] && True

{-  Andreas, 2012-09-19
    We now can have constraints between rigid variables, like i < j.
    Thus we skip the following two tests.  However, a solution must be
    checked for consistency with the constraints on rigid vars.

-- a rigid variable might not be bounded below by infinity or
-- bounded above by a constant
-- it might not be related to another rigid variable
           all (\ (r,  r') -> r == r' ||
                let Just row = (Map.lookup (Rigid r)  (nodeMap gr))
                    Just col = (Map.lookup (Rigid r') (nodeMap gr))
                    edge = m!(row,col)
                in  isBelow r edge r' )
             [ (r,r') | r <- rigids, r' <- rigids ]
           &&
-- a flexible variable might not be strictly below a rigid variable
           all (\ (x, v) ->
                let Just row = (Map.lookup (Flex x)  (nodeMap gr))
                    Just col = (Map.lookup (Rigid (RVar v)) (nodeMap gr))
                    edge = m!(row,col)
                in  edge >= Finite 0)
             [ (x,v) | x <- flexs, (RVar v) <- rigids ]
-}

         -- Whether flexible variable x may be instantiated to rigid node r.
         -- Constants are always in scope.
         -- NOTE(review): the 'Just scope' pattern fails at runtime for a
         -- rigid variable if x never got a 'NewFlex' constraint, since only
         -- 'addFlex' fills 'flexScope' (a flex seen only in an 'Arc' has
         -- no scope entry).
         inScope :: FlexId -> Rigid -> Bool
         inScope x (RConst _) = True
         inScope x (RVar v)   = scope v
             where Just scope = Map.lookup x (flexScope gr)

{- loop1

while flexible variables and rigid rows left
  find a rigid variable row i
    for all flexible columns j
      if i --n--> j with n<=0 (meaning i - n <= j) then j = i - n

(the code also accepts n > 0, i.e. i <= j + n, and then sets j = i;
 a flexible j is only assigned if rigid i is in its scope)
-}
         loop1 :: [FlexId] -> [Rigid] -> Solution -> Maybe Solution
         loop1 [] rgds subst = Just subst
         loop1 flxs [] subst = loop2 flxs subst
         loop1 flxs (r:rgds) subst =
            let row = fromJust $ Map.lookup (Rigid r) (nodeMap gr)
                -- assign every remaining flex with a finite entry in row r;
                -- the others are kept for the remaining rigid rows
                (flxs',subst') =
                  foldl' (\ (flx,sub) f ->
                          let col = fromJust $ Map.lookup (Flex f) (nodeMap gr)
                          in  case (inScope f r, m!(row,col)) of
--                                Finite z | z <= 0 ->
                                (True, Finite z) ->
                                   -- offset -z for z < 0, otherwise 0
                                   let trunc z | z >= 0 = 0
                                            | otherwise = -z
                                   in (flx, extendSolution sub f (sizeRigid r (trunc z)))
                                _ -> (f : flx, sub)
                     ) ([], subst) flxs
            in loop1 flxs' rgds subst'

{- loop2

while flexible variables j left
  search the row j for entry i
    if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

(only the first non-infinite rigid column with a finite entry counts:
 if it is out of scope or n < 0 there is no solution; if there is no
 such column, j defaults to infinity)
-}
         loop2 :: [FlexId] -> Solution -> Maybe Solution
         loop2 [] subst = Just subst
         loop2 (f:flxs) subst = loop3 0 subst
           where row = fromJust $ Map.lookup (Flex f) (nodeMap gr)
                 -- scan the columns of f's row in index order
                 loop3 col subst | col >= n =
                   -- default to infinity
                    loop2 flxs (extendSolution subst f (SizeConst Infinite))
                 loop3 col subst =
                   case Map.lookup col (intMap gr) of
                     Just (Rigid r) | not (infinite r) ->
                       case (inScope f r, m!(row,col)) of
                        (True, Finite z) | z >= 0 ->
                            loop2 flxs (extendSolution subst f (sizeRigid r z))
                        (_, Infinite) -> loop3 (col+1) subst
                        _ -> -- trace ("unusable rigid: " ++ show r ++ " for flex " ++ show f)
                              Nothing  -- NOT: loop3 (col+1) subst
                     _ -> loop3 (col+1) subst



-- Testing ----------------------------------------------------------------

-- | @genGraph density edge nodes@ generates a random adjacency list over
--   @nodes@.  For each (source, target) pair, edges are added one at a
--   time: each further edge appears with probability @density@ (as a
--   rounded percentage), so parallel edges can occur.  A source without
--   outgoing edges is randomly either listed with no edges or left out.
genGraph :: Ord node => Float -> Gen edge -> [node] -> Gen (AdjList node edge)
genGraph density edge nodes =
  fmap (Map.fromList . concat) (mapM entryFor nodes)
  where
    percent = round (100 * density)

    -- The adjacency entry (zero or one of them) for one source node.
    entryFor src = do
      out <- fmap concat (mapM edgesTo nodes)
      if null out
        then elements [[(src, [])], []]
        else return [(src, out)]

    -- The edges towards one target node.
    edgesTo tgt = frequency
      [ (percent, do { w <- edge; rest <- edgesTo tgt; return ((tgt, w) : rest) })
      , (100 - percent, return [])
      ]

-- | Edge label used by the tests: a distance wrapping a 'Nat'.
--   The arithmetic classes are derived through the newtype
--   (GeneralizedNewtypeDeriving).
newtype Distance = Dist Nat
  deriving (Eq, Ord, Num, Integral, Show, Enum, Real)

-- | Shortest-path semiring on distances: choice is the minimum,
--   sequencing is addition.
instance SemiRing Distance where
  oplus  = min
  otimes = (+)

-- | Random graph on the nodes @0 .. n-1@ with density 0.2, labelled with
--   distances generated by 'natural'.
genGraph_ :: Nat -> Gen (AdjList Nat Distance)
genGraph_ n = genGraph 0.2 (fmap Dist natural) (enumFromTo 0 (n - 1))

-- | The label of the first edge from @i@ to @j@, if any.
lookupEdge :: Ord n => n -> n -> AdjList n e -> Maybe e
lookupEdge i j g = Map.lookup i g >>= lookup j

-- | All edges of the graph as @(source, target, label)@ triples.
edges :: Ord n => AdjList n e -> [(n,n,e)]
edges g = [ (i, j, e) | (i, out) <- Map.toList g, (j, e) <- out ]

-- | Check that no edges get longer when completing a graph.
prop_smaller :: Nat -> Property
prop_smaller n' =
  forAll (genGraph_ n) $ \ g ->
    let closed = warshallG g
        notLonger (i, j, e) = maybe False (<= e) (lookupEdge i j closed)
    in  all notLonger (edges g)
  where
    n = abs (n' `div` 2)

-- | Prepend an edge @i -> j@ with label @e@ to @i@'s adjacency list.
newEdge :: Nat -> Nat -> Distance -> AdjList Nat Distance ->
           AdjList Nat Distance
newEdge i j e = Map.alter (Just . ((j, e) :) . fromMaybe []) i

-- | @genPath n i j g@ adds to @g@ a random path from @i@ to @j@, via
--   intermediate nodes drawn from @[0..n-1]@ and with random distances.
genPath :: Nat -> Nat -> Nat -> AdjList Nat Distance ->
           Gen (AdjList Nat Distance)
genPath n i j g = do
  hops    <- listOf $ (,) <$> node <*> edge
  lastHop <- edge
  return $ addPath i (hops ++ [(j, lastHop)]) g
  where
    edge :: Gen Distance
    edge = Dist <$> natural

    node :: Gen Nat
    node = choose (0, n - 1)

    -- Add the edges of a path, given its start node and its
    -- (next node, distance) steps.
    addPath :: Nat -> [(Nat, Distance)] -> AdjList Nat Distance ->
               AdjList Nat Distance
    addPath start steps base =
      foldr (\ (src, (dst, w)) acc -> newEdge src dst w acc) base
            (zip (start : map fst steps) steps)

-- | Check that all transitive edges are added: after adding a random
--   path from one node to another, the closure has a direct edge.
prop_path :: Nat -> Property
prop_path n' =
  forAll (genGraph_ n) $ \ g ->
  forAll (two $ choose (0, n - 1)) $ \ (src, dst) ->
  forAll (genPath n src dst g) $ \ g' ->
    hasEdge src dst (warshallG g')
  where
    n = 1 + abs (n' `div` 2)
    hasEdge src dst = isJust . lookupEdge src dst

-- | Rename the nodes of a graph, both the keys and the edge targets.
mapNodes :: (Ord node, Ord node') => (node -> node') -> AdjList node edge -> AdjList node' edge
mapNodes f = Map.map (map (\ (n, e) -> (f n, e))) . Map.mapKeys f

-- | Check that no edges are added between components.
prop_disjoint :: Nat -> Property
prop_disjoint n' =
  forAll (two $ genGraph_ n) $ \ (g1, g2) ->
    let combined = Map.union (mapNodes Left g1) (mapNodes Right g2)
    in  all disjoint (Map.assocs (warshallG combined))
  where
    n = abs (n' `div` 3)
    -- Every edge stays on the same side (Left/Right) as its source.
    disjoint (src, es) = all (sameSide src . fst) es
    sameSide (Left _)  (Left _)  = True
    sameSide (Right _) (Right _) = True
    sameSide _         _         = False

-- | Check that completing an already completed graph changes nothing.
prop_stable :: Nat -> Property
prop_stable n' =
  forAll (genGraph_ n) $ \ g ->
    let once  = warshallG g
        twice = warshallG once
    in  sort (edges once) == sort (edges twice)
  where
    n = abs (n' `div` 2)

-- | Run all QuickCheck properties of this module.
--   Each property reports 'True' iff QuickCheck ends with 'Success'.
tests :: IO Bool
tests = runTests "Agda.Utils.Warshall"
  [ runProperty prop_smaller
  , runProperty prop_path
  , runProperty prop_disjoint
  , runProperty prop_stable
  ]
  where
    -- Check one property and report whether it passed.
    runProperty :: Testable prop => prop -> IO Bool
    runProperty p = do
      result <- quickCheckResult p
      return $ case result of
        Success {} -> True
        _          -> False