module Data.Graph.PageRank (pageRanks, RankDic(..)) where
import Control.Monad.RWS (RWS(..), asks, gets, execRWS, get, put)
import Data.Graph.Inductive
import Prelude hiding (map, lookup)
import Data.Map (Map, lookup, fromList, foldWithKey, findWithDefault)
import Data.Maybe (fromJust)
import Control.Monad

map :: (Functor f) => (a -> b) -> f a -> f b
map = fmap

data Env = Env {node :: [Node], from :: Map Node [Node], outdegrees :: Map Node Int}

-- |'RankDic' is the Map for holding PageRank data.
type RankDic = Map Node Double 
type PRMachine = RWS Env () RankDic

lookupEnv :: (Ord a) => (Env -> Map a b) -> a -> PRMachine b
lookupEnv f a = do{ dic<-asks f; return $ fromJust $ lookup a dic}

outdegree :: Node -> PRMachine Int
outdegree = lookupEnv outdegrees

froms :: Node -> PRMachine [Node]
froms = lookupEnv from

currentRank :: Node -> PRMachine Double
currentRank nd = gets (fromJust.lookup nd)

-- |'pageRanks' calculate the PageRank for each node in the Graph 'gr'
pageRanks :: (Graph gr) => gr a b -> Double -> Double -> RankDic
pageRanks gr epsilon error = fst $ execRWS steps Env{node=nds, from=froms, outdegrees=outdegs} initRanks
    where nds = nodes gr
          count :: (Num a) => a
          count = fromIntegral $ noNodes gr
          froms = fromList $ zip nds $ map (pre gr) nds
          outdegs = fromList $ zip nds $ map (outdeg gr) nds
          initRanks = fromList $ zip nds $ replicate count (1/count)
          steps = do
            old <- get
            new <- calcPageRank epsilon
            let cond = foldWithKey (\k a b -> b && ((findWithDefault (1/0) k new)-a < error)) True old
            if cond then return new else steps
            
            

calcPageRank :: Double -> PRMachine RankDic
calcPageRank epsilon = do
  nds <- asks node
  dic <- forM nds $ \n -> do
                 frms <- froms n
                 ranks <- forM frms $ \m -> do
                            deg <- outdegree m
                            rank <- currentRank m
                            return (rank/fromIntegral deg)
                 count <- liftM (fromIntegral.length) $ asks node
                 return (n, epsilon/count + (1-epsilon)*(sum ranks))
  let rdic = fromList dic
  put rdic
  return rdic