-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.STM.TSkipList
-- Copyright   :  Peter Robinson 2010-2012
-- License     :  LGPL
-- 
-- Maintainer  :  Peter Robinson <thaldyron@gmail.com>
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
-- This module provides an implementation of a skip list in the 'STM' monad.
-- The value of each element is stored in its own 'TVar', so it can be
-- overwritten in place.
--
-- A skip list is a probabilistic data structure with dictionary operations
-- (similar to Data.Map).
-- In contrast to a balanced tree, a skip list does not need any 
-- (expensive) rebalancing operation, which makes it particularly suitable 
-- for concurrent programming. 
-- 
-- See: /William Pugh. Skip Lists: A Probabilistic Alternative to Balanced Trees./
--
-- This module should be imported qualified.
--
-- /Example (GHCi):/ 
--
-- > t <- newIO 0.5 5  :: IO (TSkipList Int String) 
-- > atomically $ sequence_ [ insert i (show i) t | i <- [1..10] ]
-- >
-- > putStrLn =<< atomically (toString 100 t)
-- > 9
-- > 9
-- > 3 7 9
-- > 1 3 7 9
-- > 1 2 3 4 5 6 7 8 9 10
-- >
-- > atomically $ delete  7 t
-- > putStrLn =<< atomically (toString 100 t)
-- > 9
-- > 9
-- > 3 9
-- > 1 3 9
-- > 1 2 3 4 5 6 8 9 10
-- > 
-- > atomically $ sequence [ lookup i t | i <- [5..10] ]
-- > [Just "5",Just "6",Nothing,Just "8",Just "9",Just "10"]
-- >
-- > atomically $ update 8 "X" t
-- > atomically $ sequence [ lookup i t | i <- [5..10] ]
-- > [Just "5",Just "6",Nothing,Just "X",Just "9",Just "10"]
-----------------------------------------------------------------------------
module Control.Concurrent.STM.TSkipList(-- * Data type and Construction
                                         TSkipList,newIO,new,
                                         -- * Operations
                                         insert,lookup,update,delete,geq,leq,filter,
                                         -- * Utilities 
                                         chooseLevel,
                                         toString,
                                       ) 
where
-- 'traverse' is hidden because this module defines its own; since base-4.8
-- the Prelude exports 'Data.Traversable.traverse', which would otherwise make
-- every use of it ambiguous.
import Prelude hiding (filter, lookup, traverse)

import Control.Applicative
import Control.Concurrent.STM
import Control.Exception
import Control.Monad
import GHC.Conc

import Data.Array.MArray
import Data.Map (Map)
import qualified Data.Map as M
import System.Random

-- | Forward pointers of a node or of the list head. The entry at index @i@
-- is the successor on level @i@, or 'Nil' if there is none.
type ForwardPtrs k a = TArray Int (Node k a)


-- | A skip list in the 'STM' monad that maps keys of type @k@ to values of
-- type @a@.
data TSkipList k a = TSkipList 
  { maxLevel    :: Int              -- ^ Maximum number of levels (size of every forward-pointer array)
  , probability :: Float            -- ^ Parameter used by 'chooseLevel' when picking a node's level
  , curLevel    :: TVar Int         -- ^ Highest level currently in use; starts at 1
  , listHead    :: ForwardPtrs k a  -- ^ Forward pointers of the key-less list head
  }

-- | A node of the skip list; 'Nil' marks the end of a level.
data Node k a
  = Nil                                 -- ^ No successor (end of a level)
  | Node { key          :: k                -- ^ Key of this element
         , contentTVar  :: TVar a           -- ^ Value, held in its own 'TVar' so it can be overwritten in place
         , forwardPtrs  :: ForwardPtrs k a  -- ^ Successors of this node, indexed by level
         }

{-
newNode :: k -> TVar a -> Int -> STM (Node k a)
newNode k t maxLvl = Node k t `liftM` (newForwardPtrs maxLvl)
-}

-- | Tests whether a node is the 'Nil' end-of-level marker.
isNil :: Node k a -> Bool
isNil node = case node of
  Nil -> True
  _   -> False

-- | Creates an empty skip list from 'IO'. This is 'new' run in its own
-- 'atomically' transaction; levels of inserted elements are drawn from the
-- standard global random generator (see 'chooseLevel').
newIO :: Float  -- ^ Probability for choosing a new level
      -> Int    -- ^ Maximum number of levels
      -> IO (TSkipList k a)
newIO p = atomically . new p

-- | Creates an empty skip list. Its current level starts at 1 and its head
-- has one 'Nil' forward pointer per level, up to the maximum level.
new :: Float  -- ^ Probability for choosing a new level
    -> Int    -- ^ Maximum number of levels
    -> STM (TSkipList k a)
new p maxLvl = do
  lvlVar   <- newTVar 1
  headPtrs <- newForwardPtrs maxLvl
  return TSkipList { maxLevel    = maxLvl
                   , probability = p
                   , curLevel    = lvlVar
                   , listHead    = headPtrs
                   }


-- | Allocates forward pointers for levels 1 to @maxLvl@, every entry
-- initialised to 'Nil'.
newForwardPtrs :: Int -> STM (ForwardPtrs k a)
newForwardPtrs maxLvl = newArray (1, maxLvl) Nil


-- | Returns a randomly chosen level. Is used for inserting new elements.
--
-- Starting at level 1, the level is raised by one for every consecutive
-- random sample (uniform in @[0,1]@) that is smaller than 'probability', up
-- to 'maxLevel'. Each additional level is therefore chosen with probability
-- 'probability', as in Pugh's original algorithm.
--
-- For performance reasons, this function uses 'unsafeIOToSTM' to access the
-- random number generator. (It would be possible to store the random number
-- generator in a 'TVar' and thus be able to access it safely from within the
-- STM monad. This, however, might cause high contention among threads.)
chooseLevel :: TSkipList k a -> STM Int
chooseLevel tskip = do
  stdG <- unsafeIOToSTM newStdGen
  -- At most maxLevel - 1 promotions, so the result never exceeds maxLevel
  -- (and is 1 if maxLevel <= 1).
  let samples = take (maxLevel tskip - 1) (randomRs (0, 1 :: Float) stdG)
  -- Promote while the sample falls below the probability. The previous
  -- comparison (probability < sample) promoted with probability 1 - p.
  return $ 1 + length (takeWhile (< probability tskip) samples)

{-
chooseLevel tskip = do
  stdG <- unsafeIOToSTM newStdGen
  let rs :: StdGen -> [(Float,StdGen)]
      rs g = (x,g') : rs g' where (x,g') = randomR (0,1) g
  let (samples,newStdGs) = unzip $ take (maxLevel tskip) (rs stdG) 
  return $ 1 + length (takeWhile ((<) (probability tskip)) $ take (maxLevel tskip - 1) samples)
-}


-- | Returns all elements whose key is less than or equal to the given key.
--
-- Walks the bottom level (level 1) from the head of the list, collecting
-- nodes until it reaches a key larger than @k@, the node with key @k@
-- itself, or the end of the list.
leq :: (Ord k{- ,Show k -}) => k -> TSkipList k a -> STM (Map k a)
leq k tskip = collect (listHead tskip) M.empty
  where
  collect ptrs acc = do
    next <- readArray ptrs 1
    if isNil next
      then return acc
      else case k `compare` key next of
             LT -> return acc
             EQ -> withElem next acc
             GT -> withElem next acc >>= collect (forwardPtrs next)

  -- Add the node's key and current value to the accumulated map.
  withElem node acc = do
    val <- readTVar (contentTVar node)
    return (M.insert (key node) val acc)


-- | Returns all elements whose key is greater than or equal to the given key.
--
-- The upper levels are used to skip to the last node whose key is smaller
-- than @k@. From there, the bottom level is walked to the end of the list.
-- Unlike a full scan, the values of elements with smaller keys are never
-- read.
geq :: (Ord k{- ,Show k -}) => k -> TSkipList k a -> STM (Map k a)
geq k tskip = do
  topLvl <- readTVar (curLevel tskip)
  start  <- descend (listHead tskip) topLvl
  collect start M.empty
  where
  -- Find the forward pointers of the last node with a key smaller than k,
  -- or those of the list head if there is no such node.
  descend ptrs lvl
    | lvl < 1   = return ptrs
    | otherwise = do
        next <- readArray ptrs lvl
        if not (isNil next) && key next < k
          then descend (forwardPtrs next) lvl
          else descend ptrs (lvl - 1)

  -- Collect every remaining node on the bottom level. Keys are kept in
  -- ascending order, so all of them are >= k.
  collect ptrs acc = do
    next <- readArray ptrs 1
    if isNil next
      then return acc
      else do
        a <- readTVar (contentTVar next)
        collect (forwardPtrs next) (M.insert (key next) a acc)


-- | Finds the node holding the given key, if any.
--
-- The search starts at the current top level. It moves right while the
-- successor's key is smaller than @k@ and drops one level otherwise.
lookupNode :: (Ord k{- ,Show k -}) => k -> TSkipList k a -> STM (Maybe (Node k a))
lookupNode k tskip = do
  topLvl <- readTVar (curLevel tskip)
  search (listHead tskip) topLvl
  where
  search ptrs lvl = traverse k ptrs lvl down right found down Nothing
    where
    down _ l     = search ptrs (l - 1)
    right next   = search (forwardPtrs next)
    found next _ = return (Just next)


-- | Returns the value stored under the given key, or 'Nothing' if the key is
-- not in the list.
lookup :: (Ord k{- ,Show k -}) => k -> TSkipList k a -> STM (Maybe a)
lookup k tskip = do
  found <- lookupNode k tskip
  case found of
    Nothing   -> return Nothing
    Just node -> fmap Just (readTVar (contentTVar node))


-- | Overwrites the value stored under an existing key.
-- Throws 'AssertionFailed' if the key is not in the list.
update :: (Ord k{- ,Show k -}) => k -> a -> TSkipList k a -> STM ()
update k a tskip = do
  found <- lookupNode k tskip
  case found of
    Just node -> writeTVar (contentTVar node) a
    Nothing   -> throw $ AssertionFailed "TSkipList.update: element not found!"

-- | Deletes an element. Does nothing if the element is not found.
--
-- The search descends from the current top level towards the key. On every
-- level where the node is found, its predecessor is made to point past it.
-- The list's current level is left unchanged.
delete :: (Ord k{- ,Show k -}) => k -> TSkipList k a -> STM ()
delete k tskip = do
  topLvl <- readTVar (curLevel tskip)
  unlink (listHead tskip) topLvl
  where
  unlink ptrs lvl = traverse k ptrs lvl down right splice down ()
    where
    down _ l        = unlink ptrs (l - 1)
    right next      = unlink (forwardPtrs next)
    splice victim l = do
      readArray (forwardPtrs victim) l >>= writeArray ptrs l
      down victim l


-- | Inserts a key/value pair. If the key is already present, only its value
-- is replaced. Otherwise a new node is allocated and linked into the list.
insert :: (Ord k{- ,Show k -}) => k -> a -> TSkipList k a ->  STM ()
insert k a tskip =
  lookupNode k tskip >>= maybe addNew overwrite
  where
  overwrite node = writeTVar (contentTVar node) a
  addNew = do
    valVar <- newTVar a
    ptrs   <- newForwardPtrs (maxLevel tskip)
    insertNode k (Node k valVar ptrs) tskip
{-
insert :: (Ord k{- ,Show k -}) => k -> a -> TSkipList k a ->  STM ()
insert k a tskip = do
  tvar    <- newTVar a 
  newPtrs <- newForwardPtrs (maxLevel tskip)
  let node = Node k tvar newPtrs
  insertNode k node tskip
-}


-- | Links @node@ (whose key is @k@) into the skip list.
--
-- A random height is chosen with 'chooseLevel', and the list's current
-- level is raised if necessary. The search then starts at the current top
-- level, as a lookup does, instead of scanning level @newLevel@ from the
-- head. @node@ is spliced in at every level from its height down to 1.
--
-- 'insert' only calls this for keys that are not yet present. If the key is
-- found anyway, the existing node is bypassed on every level and replaced by
-- @node@, so the list stays consistent.
insertNode :: (Ord k{- ,Show k -}) => k -> Node k a -> TSkipList k a ->  STM ()
insertNode k node tskip = do
  newLevel <-  chooseLevel tskip
  -- Adapt current maximum level if necessary:
  curLvl   <- readTVar (curLevel tskip)
  when (curLvl < newLevel) $ 
    writeTVar (curLevel tskip) newLevel
  insertAcc newLevel (listHead tskip) (max curLvl newLevel)
  where
  insertAcc height fwdPtrs lvl = do
    -- fwdPtrs belongs to the predecessor of k on this level; link the node
    -- in, but only on levels the node actually has.
    let moveDown succNode level = do
          when (level <= height) $ do
            writeArray (forwardPtrs node) level succNode
            writeArray fwdPtrs level node
          insertAcc height fwdPtrs (level-1)
    let moveRight succNode =
          insertAcc height (forwardPtrs succNode)
    -- Key already present: skip the old node on this level and, if the new
    -- node reaches this level, put it in the old node's place.
    let onFound oldNode level = do
          oldSucc <- readArray (forwardPtrs oldNode) level
          if level <= height
            then do
              writeArray (forwardPtrs node) level oldSucc
              writeArray fwdPtrs level node
            else writeArray fwdPtrs level oldSucc
          insertAcc height fwdPtrs (level-1)
    traverse k fwdPtrs lvl moveDown moveRight onFound moveDown ()


-- | One search step on a single level, with the action supplied by the
-- caller.
--
-- Reads the successor stored at @level@ in @fwdPtrs@ and dispatches on it:
--
-- * successor is 'Nil': @onNil@
--
-- * @k@ is smaller than the successor's key: @onLT@
--
-- * @k@ is greater than the successor's key: @onGT@
--
-- * @k@ equals the successor's key: @onFound@
--
-- Each callback receives the successor node and @level@. If @level@ is below
-- 1, nothing is read and @def@ is returned.
traverse :: (Ord k{- ,Show k -}) 
         => k -> ForwardPtrs k a -> Int 
         -> (Node k a -> Int -> STM b)
         -> (Node k a -> Int -> STM b)
         -> (Node k a -> Int -> STM b)
         -> (Node k a -> Int -> STM b)
         -> b
         -> STM b
traverse k fwdPtrs level onLT onGT onFound onNil def
  | level < 1 = return def
  | otherwise = readArray fwdPtrs level >>= step
  where
  step succNode
    | isNil succNode = onNil succNode level
    | otherwise      = case compare k (key succNode) of
        LT -> onLT succNode level
        EQ -> onFound succNode level
        GT -> onGT succNode level


-- | Returns all elements that satisfy the predicate. /O(n)/.
--
-- Walks the bottom level from the head of the list, reads every value, and
-- keeps the key/value pairs for which @p@ holds.
filter :: (Ord k{- ,Show k -}) 
      => (k -> a -> Bool) -> TSkipList k a -> STM (Map k a)
filter p tskip = walk M.empty (listHead tskip)
  where
  walk found ptrs = readArray ptrs 1 >>= \next -> case next of
    Nil -> return found
    _   -> do
      val <- readTVar (contentTVar next)
      walk (keepIf next val found) (forwardPtrs next)

  -- Add the pair to the result only if it satisfies the predicate.
  keepIf node val m
    | p (key node) val = M.insert (key node) val m
    | otherwise        = m

-- | Debug helper. Renders the skip list as a string with one line per level,
-- from the highest level in use down to level 1.
-- Only keys less than or equal to the given key are written, each preceded
-- by a single space. Every level is still walked to its end.
toString :: (Show k,Ord k) => k -> TSkipList k a -> STM String
toString k tskip = do
  topLvl <- readTVar (curLevel tskip)
  rows   <- mapM (renderLevel (listHead tskip) "") [topLvl, topLvl - 1 .. 1]
  return (unlines rows)
  where
  -- Walk one level from left to right, appending every key that does not
  -- exceed k. Larger keys are skipped, but the walk continues.
  renderLevel ptrs acc lvl
    | lvl < 1   = return ""
    | otherwise = do
        next <- readArray ptrs lvl
        if isNil next
          then return acc
          else do
            let acc' = case compare k (key next) of
                         LT -> acc
                         _  -> acc ++ (' ' : show (key next))
            renderLevel (forwardPtrs next) acc' lvl


{-
-- | Debug helper. Uses 'unsafeIOToSTM' to print the skip list.
-- For properly ordered output, the key must be greater than the maximum key in
-- the skip list.
printTSkipList :: (Ord k, Show k) => k -> TSkipList k a -> STM ()
printTSkipList k tskip = do
  curLvl   <- readTVar (curLevel tskip)
  forM_ (reverse [1..curLvl]) $ \l -> printAcc (listHead tskip) l
  where
  printAcc fwdPtrs curLvl = do
    let moveDown succNode level = 
          if isNil succNode
            then unsafeIOToSTM $ putStrLn "" -- (show curLvl++": ")
            else do 
              printAcc (forwardPtrs succNode) level
              unsafeIOToSTM $ putStr (' ':show (key succNode))
    let moveRight succNode level = do
          unsafeIOToSTM $ putStr (' ':show (key succNode))
          printAcc (forwardPtrs succNode) level
    let onFound succNode level = do
          unsafeIOToSTM $ putStr (' ':show (key succNode))
          printAcc (forwardPtrs succNode) level
    traverse k fwdPtrs curLvl moveDown moveRight onFound moveDown ()
-}