{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{- | Construct a graph from constraints
@
   x + n <= y   becomes   x ---(-n)---> y
   x <= n + y   becomes   x ---(+n)---> y
@
the default edge (= no edge) is labelled with infinity.

Building the graph involves keeping track of the node names.
We do this in a finite map, assigning consecutive numbers to nodes.
-}
module Agda.Utils.Warshall where

import Control.Monad.State

import Data.Maybe
import Data.Array
import qualified Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map

import Agda.Syntax.Common (Nat)
import Agda.Utils.SemiRing

-- | Two-dimensional arrays indexed by @(row, column)@ pairs.
type Matrix = Array (Int, Int)

-- | Closure of a matrix over a semiring, Warshall style.
--
--   The pivots @k@ are taken from the row range in ascending order.  For
--   each pivot, every entry is replaced by
--   @a!(i,j) `oplus` (a!(i,k) `otimes` a!(k,j))@, computed from the matrix
--   of the previous pivot.
--
--   Precondition: the matrix is square with equal row and column bounds,
--   i.e. @bounds a0 == ((r,r),(r',r'))@.
warshall :: SemiRing a => Matrix a -> Matrix a
warshall a0 = List.foldl' relax a0 [rowLo .. rowHi]
  where
    bnds@((rowLo, colLo), (rowHi, colHi)) = bounds a0
    -- One pivot step: route every (i,j) path optionally through node k.
    relax a k = array bnds
      [ ((i, j), (a ! (i, j)) `oplus` ((a ! (i, k)) `otimes` (a ! (k, j))))
      | i <- [rowLo .. rowHi]
      , j <- [colLo .. colHi]
      ]

-- | A graph as a map from each source node to its outgoing edges, each
--   given as a (target node, edge label) pair.
type AdjList n e = Map n [(n, e)]

-- | Warshall's algorithm on a graph represented as an adjacency list.
--
--   The steps are:
--
--   * Number every node that occurs as a source or a target, in order of
--     first occurrence.
--   * Tabulate the graph as a matrix of @Maybe edge@.  @Nothing@ means no
--     edge; several edges between the same pair are combined with 'oplus'.
--   * Close the matrix with 'warshall' and read it back.
--
--   Every node of the input (source or target) is a key of the result,
--   possibly with an empty edge list.
warshallG :: (SemiRing edge, Ord node) => AdjList node edge -> AdjList node edge
warshallG g = fromMatrix $ warshall m
  where
    nodes = zip (List.nub $ Map.keys g ++ map fst (concat $ Map.elems g))
                [0..]
    len   = length nodes
    b     = ((0,0), (len - 1,len - 1))

    -- Node label of each index.  This gives O(1) lookups instead of the
    -- partial, O(len) @nodes !! k@ in the output loop.
    label = listArray (0, len - 1) (map fst nodes)

    -- Combination of all direct edges from i to j, or Nothing if none.
    edge i j = do
      es <- Map.lookup i g
      foldr oplus Nothing [ Just v | (j', v) <- es, j == j' ]

    m = array b [ ((n, k), edge i j) | (i, n) <- nodes, (j, k) <- nodes ]

    -- Convert back, keeping only the entries that denote an edge.
    fromMatrix matrix = Map.fromList $ do
      (i, n) <- nodes
      let es = [ (label ! k, e)
               | k <- [0..len - 1]
               , Just e <- [matrix ! (n, k)]
               ]
      return (i, es)

-- | Edge weight in the graph, forming a semiring.
--   'Infinite' labels the absence of an edge.
data Weight
  = Finite Int
  | Infinite
  deriving (Eq)

-- | @inc w n@ adds @n@ to a finite weight; infinity absorbs the increment.
inc :: Weight -> Int -> Weight
inc w n = case w of
  Finite k -> Finite (k + n)
  Infinite -> Infinite

-- | Finite weights print as their number, infinity as a dot.
instance Show Weight where
  show w = case w of
    Infinite -> "."
    Finite i -> show i

-- | Finite weights are ordered as integers; 'Infinite' is above all of them.
instance Ord Weight where
  x <= y = case y of
    Infinite -> True
    Finite b -> case x of
      Infinite -> False
      Finite a -> a <= b

-- | The (min, +) semiring on weights.  Addition takes the smaller weight,
--   with 'Infinite' as its unit.  Multiplication adds finite weights, with
--   'Finite' 0 as its unit and 'Infinite' absorbing.
instance SemiRing Weight where
  ozero = Infinite
  oone  = Finite 0

  oplus a b = min a b

  otimes (Finite a) (Finite b) = Finite (a + b)
  otimes _          _          = Infinite

-- constraints ---------------------------------------------------

-- | Nodes of the graph are either
--
--   * flexible variables (identified by a 'FlexId'), or
--
--   * rigid nodes ('Rigid'), which are rigid variables (identified by a
--     'RigidId') or constants (a number or infinity).
--
--   With the derived 'Ord', every rigid node sorts before every flexible
--   node, because of the constructor order.
data Node
  = Rigid Rigid
  | Flex FlexId
  deriving (Eq, Ord)

-- | Rigid nodes: a constant weight (possibly infinite) or a rigid variable.
data Rigid
  = RConst Weight
  | RVar RigidId
  deriving (Eq, Ord, Show)

-- | Node numbers in the constraint graph (also the matrix indices).
type NodeId  = Int
-- | Identifiers of rigid variables.
type RigidId = Int
-- | Identifiers of flexible variables.
type FlexId  = Int

-- | Which rigid variables a flex may be instantiated to.
type Scope = RigidId -> Bool

-- | Flexible variables print as @?i@, rigid variables as @vi@, the infinite
--   constant as @#@, and finite constants as their number.
instance Show Node where
  show node = case node of
    Flex i                    -> '?' : show i
    Rigid (RVar i)            -> 'v' : show i
    Rigid (RConst Infinite)   -> "#"
    Rigid (RConst (Finite n)) -> show n

-- | Is the rigid node the constant infinity?
infinite :: Rigid -> Bool
infinite r = case r of
  RConst Infinite -> True
  _               -> False

-- | @isBelow r w r'@
--   checks, for @r@ and @r'@ connected by @w@ (a finite @w@ means they are
--   connected), whether @r + w <= r'@ holds.
--   An infinite @w@ or an infinite @r'@ is always accepted.  For two finite
--   constants the inequality is checked.  Every other combination is
--   rejected, since rigid variables are not related.
--   Precondition: @r@ and @r'@ are not the same rigid variable.
isBelow :: Rigid -> Weight -> Rigid -> Bool
isBelow r w r' = case (w, r, r') of
  (Infinite, _, _)                                 -> True
  (_, _, RConst Infinite)                          -> True
  (Finite n, RConst (Finite i), RConst (Finite j)) -> i + n <= j
  _                                                -> False

-- | A constraint either introduces a flexible variable or is an edge in
--   the graph.
data Constraint
  = NewFlex FlexId Scope
    -- ^ A flexible variable together with its 'Scope'.
  | Arc Node Int Node
    -- ^ For @Arc v1 k v2@ at least one of @v1@ or @v2@ is a @MetaV@ (Flex),
    --   the other a @MetaV@ or a @Var@ (Rigid).
    --   If @k <= 0@ this means @suc^(-k) v1 <= v2@,
    --   otherwise it means @v1 <= suc^k v2@.

-- | Arcs print as inequalities, with the offset on the side it belongs to.
instance Show Constraint where
  show c = case c of
    NewFlex i _ -> "SizeMeta(?" ++ show i ++ ")"
    Arc v1 k v2
      | k < 0     -> show v1 ++ "+" ++ show (negate k) ++ "<=" ++ show v2
      | k > 0     -> show v1 ++ "<=" ++ show v2 ++ "+" ++ show k
      | otherwise -> show v1 ++ "<=" ++ show v2

-- | A list of constraints; they are added to the graph in list order.
type Constraints = [Constraint]

-- | The empty list of constraints.
emptyConstraints :: Constraints
emptyConstraints = []

-- graph (matrix) ------------------------------------------------

-- | The constraint graph under construction.
data Graph = Graph
  { flexScope :: Map FlexId Scope
    -- ^ Scope for each flexible variable.
  , nodeMap   :: Map Node NodeId
    -- ^ Node labels to node numbers.
  , intMap    :: Map NodeId Node
    -- ^ Node numbers to node labels.
  , nextNode  :: NodeId
    -- ^ Number of nodes @n@; also the next free node number.
  , graph     :: NodeId -> NodeId -> Weight
    -- ^ The edges (restrict to @[0..n[@); 'Infinite' means no edge.
  }

-- | The empty graph: no nodes, no scopes, and every edge has infinite
--   weight (i.e. there are no edges).
initGraph :: Graph
initGraph = Graph
  { flexScope = Map.empty
  , nodeMap   = Map.empty
  , intMap    = Map.empty
  , nextNode  = 0
  , graph     = \ _ _ -> Infinite
  }

-- | The graph monad: a state monad over 'Graph', used to build the graph
--   one constraint at a time.
type GM = State Graph

-- | Add a size meta node: record its scope and make sure it has a node
--   number in the graph.
addFlex :: FlexId -> Scope -> GM ()
addFlex x scope =
  modify registerScope >> addNode (Flex x) >> return ()
  where
    registerScope st = st { flexScope = Map.insert x scope (flexScope st) }

-- | Look up the number of a node.
--   A node seen for the first time gets the next free number, and both the
--   label-to-number and number-to-label maps are extended.
addNode :: Node -> GM Int
addNode n = do
  st <- get
  maybe (allocate st) return $ Map.lookup n (nodeMap st)
  where
    allocate st = do
      let fresh = nextNode st
      put st { nodeMap  = Map.insert n fresh (nodeMap st)
             , intMap   = Map.insert fresh n (intMap st)
             , nextNode = fresh + 1
             }
      return fresh

-- | @addEdge n1 k n2@
--   improves the weight of edge @n1->n2@ to be at most @k@, by combining
--   @Finite k@ with the current weight using 'oplus'.
--   Also adds nodes if not yet present.
addEdge :: Node -> Int -> Node -> GM ()
addEdge n1 k n2 = do
  src <- addNode n1
  tgt <- addNode n2
  modify $ \ st ->
    let old = graph st
        improved x y
          | x == src && y == tgt = Finite k `oplus` old x y
          | otherwise            = old x y
    in  st { graph = improved }

-- | Add one constraint to the graph under construction.
addConstraint :: Constraint -> GM ()
addConstraint c = case c of
  NewFlex x scope -> addFlex x scope
  Arc n1 k n2     -> addEdge n1 k n2

-- | Build the graph by adding the constraints, in list order, to the empty
--   graph.
buildGraph :: Constraints -> Graph
buildGraph = flip execState initGraph . mapM_ addConstraint

-- | Tabulate the weight function @g@ on @[0..n-1] x [0..n-1]@ as a matrix.
mkMatrix :: Int -> (Int -> Int -> Weight) -> Matrix Weight
mkMatrix n g = listArray bnds [ g i j | (i, j) <- range bnds ]
  where
    bnds = ((0, 0), (n - 1, n - 1))

-- displaying matrices with row and column labels --------------------

-- | A matrix with row descriptions in @b@ and column descriptions in @c@.
data LegendMatrix a b c = LegendMatrix
  { matrix   :: Matrix a
  , rowdescr :: Int -> b  -- ^ Label of the row with the given index.
  , coldescr :: Int -> c  -- ^ Label of the column with the given index.
  }

-- | Tab-separated table: a header line of column labels, then one line per
--   row, starting with the row label.
instance (Show a, Show b, Show c) => Show (LegendMatrix a b c) where
  show (LegendMatrix m rd cd) = header ++ concatMap row [r .. r']
    where
      ((r, c), (r', c')) = bounds m
      cols = [c .. c']

      cell :: Show x => x -> String
      cell x = '\t' : show x

      header = concatMap (cell . cd) cols
      row i  = '\n' : show (rd i) ++ concatMap (\ j -> cell (m ! (i, j))) cols

-- solving the constraints -------------------------------------------

-- | A solution assigns to each flexible variable a size expression,
--   which is either a constant or @v + n@ for a rigid variable @v@.
type Solution = Map Int SizeExpr

-- | The solution that assigns nothing.
emptySolution :: Solution
emptySolution = Map.empty

-- | @extendSolution subst x e@ assigns @e@ to the flexible variable @x@,
--   replacing any previous assignment of @x@.
extendSolution :: Solution -> Int -> SizeExpr -> Solution
extendSolution subst flex e = Map.insert flex e subst

-- | Size expressions assigned to flexible variables.
data SizeExpr
  = SizeVar RigidId Int  -- ^ A rigid variable plus an offset, e.g. x + 5.
  | SizeConst Weight     -- ^ A number or infinity.

-- | A variable prints like the corresponding rigid node, followed by
--   @+k@ if the offset is non-zero.
instance Show SizeExpr where
  show e = case e of
    SizeConst w -> show w
    SizeVar n k
      | k == 0    -> var
      | otherwise -> var ++ "+" ++ show k
      where
        var = show (Rigid (RVar n))

-- | @sizeRigid r n@ returns the size expression corresponding to @r + n@.
--   Constants are shifted with 'inc', so infinity stays infinity.
sizeRigid :: Rigid -> Int -> SizeExpr
sizeRigid r n = case r of
  RConst k -> SizeConst (inc k n)
  RVar i   -> SizeVar i n

{-
apply :: SizeExpr -> Solution -> SizeExpr
apply e@(SizeExpr (Rigid _) _) phi = e
apply e@(SizeExpr (Flex  x) i) phi = case Map.lookup x phi of
  Nothing -> e
  Just (SizeExpr v j) -> SizeExpr v (i + j)

after :: Solution -> Solution -> Solution
after psi phi = Map.map (\ e -> e `apply` phi) psi
-}

{- compute solution

a solution CANNOT exist if

  v < v  for a rigid variable v

  -- Andreas, 2012-09-19 OUTDATED are:

  -- v <= v' for rigid variables v,v'

  -- x < v   for a flexible variable x and a rigid variable v


thus, for each flexible x, only one of the following cases is possible

  r+n <= x+m <= infty  for a unique rigid r  (meaning r --(m-n)--> x)
  x <= r+n             for a unique rigid r  (meaning x --(n)--> r)

we are looking for the least values for flexible variables that solve
the constraints.  Algorithm

while flexible variables and rigid rows left
  find a rigid variable row i
    for all flexible columns j
      if i --n--> j with n<=0 (meaning i-n <= j) then j = i - n

while flexible variables j left
  search the row j for entry i
    if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

-}

-- | Try to solve a list of size constraints.
--
--   The steps are:
--
--   * Build the constraint graph and close it with 'warshall'.
--   * Reject the constraints if some rigid node has a negative weight on
--     the diagonal, i.e. lies strictly below itself.
--   * Assign flexible variables from rigid lower bounds ('loop1').
--   * Assign the remaining ones from rigid upper bounds or infinity
--     ('loop2').
--
--   'Nothing' means the constraints were found to be unsolvable.
solve :: Constraints -> Maybe Solution
solve cs = -- trace (show cs) $
   -- trace (show lm0) $
    -- trace (show lm) $ -- trace (show d) $
     let solution = if solvable then loop1 flexs rigids emptySolution
                    else Nothing
     in -- trace (show solution) $
         solution
   where -- compute the graph and its transitive closure m
         gr  = buildGraph cs
         n   = nextNode gr            -- number of nodes
         m0  = mkMatrix n (graph gr)
         m   = warshall m0

         -- tracing only: build output version of transitive graph
         -- (these bindings are unused unless a trace above is re-enabled)
         legend i = fromJust $ Map.lookup i (intMap gr) -- trace only
         lm0 = LegendMatrix m0 legend legend            -- trace only
         lm  = LegendMatrix m legend legend             -- trace only

         -- compute the sets of flexible and rigid node numbers
         -- Both lists are built with a left fold and (:), so they come out
         -- in DESCENDING order of the node keys.  For @rigids@ this means
         -- rigid variables (highest id first) precede constants, since the
         -- derived Ord on Rigid puts RConst before RVar.
         ns  = Map.keys (nodeMap gr)
         -- a set of flexible variables
         flexs  = List.foldl' (\ l -> \case (Flex i ) -> i : l
                                            (Rigid _) -> l)     [] ns
         -- a set of rigid variables
         rigids = List.foldl' (\ l -> \case (Flex _ ) -> l
                                            (Rigid i) -> i : l) [] ns

         -- rigid matrix indices (node numbers of all rigid nodes,
         -- constants included)
         rInds = List.foldl' (\ l r -> let Just i = Map.lookup (Rigid r) (nodeMap gr)
                                       in i : l) [] rigids

         -- check whether there is a solution
         -- d   = [ m!(i,i) | i <- [0 .. (n-1)] ]  -- diagonal
         -- A rigid node must not be strictly less than itself, so the rigid
         -- part of the diagonal must have no negative entries.
         -- (The trailing "&& True" is a leftover of the disabled checks in
         -- the comment below.)
         solvable = all (\ x -> x >= Finite 0) [ m!(i,i) | i <- rInds ] && True

{-  Andreas, 2012-09-19
    We now can have constraints between rigid variables, like i < j.
    Thus we skip the following two test.  However, a solution must be
    checked for consistency with the constraints on rigid vars.

-- a rigid variable might not be bounded below by infinity or
-- bounded above by a constant
-- it might not be related to another rigid variable
           all (\ (r,  r') -> r == r' ||
                let Just row = (Map.lookup (Rigid r)  (nodeMap gr))
                    Just col = (Map.lookup (Rigid r') (nodeMap gr))
                    edge = m!(row,col)
                in  isBelow r edge r' )
             [ (r,r') | r <- rigids, r' <- rigids ]
           &&
-- a flexible variable might not be strictly below a rigid variable
           all (\ (x, v) ->
                let Just row = (Map.lookup (Flex x)  (nodeMap gr))
                    Just col = (Map.lookup (Rigid (RVar v)) (nodeMap gr))
                    edge = m!(row,col)
                in  edge >= Finite 0)
             [ (x,v) | x <- flexs, (RVar v) <- rigids ]
-}

         -- Constants are always in scope.  A rigid variable is in scope if
         -- the Scope registered for the flexible variable admits it.
         -- NOTE(review): the pattern binding @Just scope@ fails at runtime
         -- if @x@ was never registered with NewFlex (e.g. it only occurs in
         -- an Arc).  Presumably callers always emit NewFlex; confirm.
         inScope :: FlexId -> Rigid -> Bool
         inScope x (RConst _) = True
         inScope x (RVar v)   = scope v
             where Just scope = Map.lookup x (flexScope gr)

{- loop1

while flexible variables and rigid rows left
  find a rigid variable row i
    for all flexible columns j
      if i --n--> j with n<=0 (meaning i - n <= j) then j = i - n

In the code, ANY finite weight n from an in-scope rigid i to j assigns
j = i + max 0 (-n) (see trunc); the n<=0 guard is commented out.
Rigid rows are visited in list order, and an assigned flexible variable
is removed from the pending list, so the first rigid that reaches j
determines its value.
-}
         loop1 :: [FlexId] -> [Rigid] -> Solution -> Maybe Solution
         loop1 [] rgds subst = Just subst
         loop1 flxs [] subst = loop2 flxs subst
         loop1 flxs (r:rgds) subst =
            let row = fromJust $ Map.lookup (Rigid r) (nodeMap gr)
                (flxs',subst') =
                  List.foldl' (\ (flx,sub) f ->
                          let col = fromJust $ Map.lookup (Flex f) (nodeMap gr)
                          in  case (inScope f r, m!(row,col)) of
--                                Finite z | z <= 0 ->
                                (True, Finite z) ->
                                   -- trunc z = max 0 (-z)
                                   let trunc z | z >= 0 = 0
                                            | otherwise = -z
                                   in (flx, extendSolution sub f (sizeRigid r (trunc z)))
                                _ -> (f : flx, sub)
                     ) ([], subst) flxs
            in loop1 flxs' rgds subst'

{- loop2

while flexible variables j left
  search the row j for entry i
    if j --n--> i with n >= 0 (meaning j <= i + n) then j = i + n

In the code, columns are scanned in increasing node number, skipping
flexible nodes and the infinite constant.  The first such rigid column
with a finite weight decides: if that rigid is in scope and n >= 0, the
flexible variable gets i + n; otherwise solving fails with Nothing.
If no column has a finite weight, the flexible variable is set to
infinity.
-}
         loop2 :: [FlexId] -> Solution -> Maybe Solution
         loop2 [] subst = Just subst
         loop2 (f:flxs) subst = loop3 0 subst
           where row = fromJust $ Map.lookup (Flex f) (nodeMap gr)
                 loop3 col subst | col >= n =
                   -- default to infinity
                    loop2 flxs (extendSolution subst f (SizeConst Infinite))
                 loop3 col subst =
                   case Map.lookup col (intMap gr) of
                     Just (Rigid r) | not (infinite r) ->
                       case (inScope f r, m!(row,col)) of
                        (True, Finite z) | z >= 0 ->
                            loop2 flxs (extendSolution subst f (sizeRigid r z))
                        (_, Infinite) -> loop3 (col+1) subst
                        _ -> -- trace ("unusable rigid: " ++ show r ++ " for flex " ++ show f)
                              Nothing  -- NOT: loop3 (col+1) subst
                     _ -> loop3 (col+1) subst