{-
--------------------------------------------------------------------------------
--
-- Copyright (C) 2008 Martin Sulzmann, Edmund Lam. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.

    * Redistributions in binary form must reproduce the above
      copyright notice, this list of conditions and the following
      disclaimer in the documentation and/or other materials provided
      with the distribution.

    * Neither the name of Isaac Jones nor the names of other
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-}
 

module MultiSetRewrite.ConcurrentList where

import IO
import GHC.IOBase
import Monad
import Data.IORef
import Control.Concurrent.STM

---------------------------------------------------------
-- API for a thread-safe singly-linked list using CAS
 
-- Contents of one list cell. Cells themselves are IORefs; a cell's contents
-- are replaced in place (e.g. Null -> Node when appending).
data List a
  -- A live element.
  = Node
      { val    :: a                -- the stored element
      , verify :: TVar Bool        -- flag used when re-verifying nodes
      , next   :: IORef (List a)   -- the successor cell
      }
  -- A logically deleted element; it no longer carries a value.
  | DelNode
      { verify :: TVar Bool        -- flag used when re-verifying nodes
      , next   :: IORef (List a)   -- the successor cell
      }
  -- End-of-list marker.
  | Null
  -- Dummy first node of every list.
  | Head
      { next :: IORef (List a)     -- the first real cell
      }
  deriving Eq

{-
The verify field is necessary later when 'atomically' re-verifying
a set of nodes. The following invariant must be satisfied:

- Initially, the field is True
- For each DelNode, the field is False
- For a Node, the field can be either True or False,
  only the 'owner' can set the field to False

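None of the functions below read or modify the Bool stored in the verify
field: addToTail merely creates it (as True) and iterateList merely copies
the TVar reference when it relinks a node.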

-}



-- Handle through which a list is accessed. Both fields are references to
-- cell references, so the cell they designate can be swapped.
data ListHandle a = ListHandle
  { headList :: IORef (IORef (List a))  -- designates the cell holding Head
  , tailList :: IORef (IORef (List a))  -- designates the last cell (Null)
  }


-- We assume a static head pointer: headList designates the first cell, which
-- must always hold a Head node. Head carries no verify flag, only a next
-- pointer; it is a dummy node that is only there to make some of the code
-- more uniform.
-- tailList designates the last cell, which must hold Null.


-- A cursor into a list: a mutable reference to the list cell from which the
-- traversal continues (newIterator starts it at the head cell, iterateList
-- moves it to the cell it returns).
type Iterator a = IORef (IORef (List a))


-------------------------------------------
-- auxiliary functions



-- Run cmd over and over for as long as b holds.
-- Note that b is a plain Bool, so it is never re-evaluated: with b == False
-- nothing is run, with b == True the loop only ends if cmd itself aborts the
-- computation (e.g. Nothing in Maybe, or an exception in IO).
while :: Monad m => Bool -> m a -> m ()
while b cmd
  | b         = cmd >> while b cmd
  | otherwise = return ()

-- Run cmd repeatedly (at least once) until it yields True.
repeatUntil :: Monad m => m Bool -> m ()
repeatUntil cmd = loop
  where
    loop = cmd >>= \finished -> if finished then return () else loop

-- Run cmd repeatedly (at least once) until the Bool it yields is True, then
-- return the value paired with that True. Values paired with False are
-- discarded.
repeatUntilCnt :: Monad m => m (Bool, c) -> m c
repeatUntilCnt cmd = cmd >>= step
  where
    step (done, result)
      | done      = return result
      | otherwise = repeatUntilCnt cmd

-- Compare-and-swap on an IORef: if ptr currently holds a value equal (==) to
-- old, replace it by new and return True; otherwise leave ptr untouched and
-- return False. The check and the update happen in one atomicModifyIORef.
atomCAS :: Eq a => IORef a -> a -> a -> IO Bool
atomCAS ptr old new = atomicModifyIORef ptr swapIfExpected
  where
    swapIfExpected cur
      | cur == old = (new, True)
      | otherwise  = (cur, False)

-- Overwrite the contents of ptr with x, going through atomicModifyIORef
-- (the previous contents are ignored).
atomicWrite :: IORef a -> a -> IO ()
atomicWrite ptr x = atomicModifyIORef ptr (const (x, ()))


----------------------------------------------
-- functions operating on lists


-- Create a new, empty list: a Head cell whose successor is a Null cell.
-- The handle's headList designates the Head cell and its tailList the Null
-- cell.
newList :: IO (ListHandle a)
newList = do
  endCell  <- newIORef Null
  headCell <- newIORef (Head {next = endCell})
  hdRef    <- newIORef headCell
  tlRef    <- newIORef endCell
  return (ListHandle {headList = hdRef, tailList = tlRef})


-- Append x to the list by overwriting the Null tail cell with a new Node
-- whose successor is a fresh Null cell.
-- Only tailList needs adjusting, not headList, because of the static Head.
-- Returns the location (cell) of the newly added node.
addToTail :: Eq a => ListHandle a -> a -> IO (IORef (List a))
addToTail (ListHandle {tailList = tailPtrPtr}) x =
   do nullPtr <- newIORef Null
      -- The verify flag (initially True) and the node itself do not depend on
      -- which cell ends up being overwritten, so they are built once here
      -- rather than on every spin of the retry loop below.
      v <- atomically $ newTVar True
      let newNode = Node {val = x, verify = v, next = nullPtr}
      tPtr <- repeatUntilCnt
               (do tailPtr <- readIORef tailPtrPtr
                   -- succeeds only if the cell still holds Null; otherwise we
                   -- re-read tailList and try again
                   b <- atomCAS tailPtr Null newNode
                   return (b,tailPtr) )
        -- we atomically update the tail
        -- (by spinning on the tailPtr)
      -- publish the fresh Null cell as the new tail
      atomicWrite tailPtrPtr nullPtr
      return tPtr


-- Create an iterator positioned at the head cell of the list.
-- An iterator always designates the PREVIOUS cell, i.e. the one whose
-- successor is examined next; this works for the first element too because
-- every list starts with a static dummy Head.
-- Assumption: iterators are private, i.e. they are not shared among threads.
newIterator :: ListHandle a -> IO (Iterator a)
newIterator (ListHandle {headList = hd}) = readIORef hd >>= newIORef


-- Make the lhs iterator designate the same cell as the rhs iterator.
-- A plain (non-atomic) write suffices: iterators are not shared between
-- threads.
assignIterator :: Iterator a -> Iterator a -> IO ()
assignIterator lhs rhs = readIORef rhs >>= writeIORef lhs

-- Advance to the next "not deleted" node.
-- Starting from the cell the iterator designates, we look at successor cells:
--   * a Node: the iterator is moved onto that cell and the cell is returned;
--   * Null: the end of the list, Nothing is returned and the iterator is left
--     where it was;
--   * a DelNode: we try to delink it and keep going.
-- There is no need to adjust headList or tailList, because headList has a
-- static Head and tailList designates Null.
iterateList :: Eq a => Iterator a -> IO (Maybe (IORef (List a)))
iterateList itPtrPtr = readIORef itPtrPtr >>= walk
  where
    -- walk prevPtr examines the successor of the cell prevPtr
    walk prevPtr = do
      prevNode <- readIORef prevPtr
      let curPtr = next prevNode   -- Head, Node and DelNode all have a next
          -- Atomically replace the snapshot prevNode stored in prevPtr by
          -- replacement (which skips the deleted successor). On success we
          -- look again from prevPtr; if the CAS fails we simply move ahead
          -- onto the deleted cell.
          unlinkWith replacement = do
            swapped <- atomCAS prevPtr prevNode replacement
            walk (if swapped then prevPtr else curPtr)
      curNode <- readIORef curPtr
      case curNode of
        Null -> return Nothing
        Node {} -> do
          writeIORef itPtrPtr curPtr   -- adjust iterator
          return (Just curPtr)
        DelNode {next = afterCur} ->
          case prevNode of
            Head {} -> unlinkWith (Head {next = afterCur})
            Node {} -> unlinkWith (Node { val = val prevNode
                                        , verify = verify prevNode
                                        , next = afterCur })
            -- if the predecessor is deleted as well, simply move ahead
            DelNode {} -> walk curPtr


--printing and counting

-- Print the whole list to stdout, starting at the successor of the Head
-- node (see printListHelp for the format). The head cell is expected to hold
-- a Head node; anything else makes the pattern bind fail.
printList :: Show a => ListHandle a -> IO ()
printList (ListHandle {headList = ptrPtr}) = do
  hdCell <- readIORef ptrPtr
  Head {next = firstCell} <- readIORef hdCell
  printListHelp firstCell


-- Print the cells from curNodePtr up to the end of the list on stdout:
-- each Node as its shown value followed by " -> ", each DelNode as
-- "DEAD -> ", and the terminating Null as "Nil". No newline is written.
-- A Head node is not handled here.
printListHelp :: Show a => IORef (List a) -> IO ()
printListHelp curNodePtr = readIORef curNodePtr >>= render
  where
    render Null = putStr "Nil"
    render (Node {val = curval, next = rest}) =
      putStr (show curval ++ " -> ") >> printListHelp rest
    render (DelNode {next = rest}) =
      putStr "DEAD -> " >> printListHelp rest

-- Print the single cell curNodePtr on stdout: "Nil" for Null, "DEL " for a
-- DelNode, and the shown value followed by a space for a Node.
-- A Head node is not handled here.
printElement :: Show a => IORef (List a) -> IO ()
printElement curNodePtr = readIORef curNodePtr >>= display
  where
    display Null                  = putStr "Nil"
    display (DelNode {})          = putStr "DEL "
    display (Node {val = curval}) = putStr (show curval ++ " ")


-- Count the cells of the list, starting at the successor of the Head node
-- (see cntListHelp for what exactly is counted). The head cell is expected
-- to hold a Head node; anything else makes the pattern bind fail.
cntList :: Show a => ListHandle a -> IO Int
cntList (ListHandle {headList = ptrPtr}) = do
  hdCell <- readIORef ptrPtr
  Head {next = firstCell} <- readIORef hdCell
  cntListHelp firstCell 0


-- Count the cells from curNodePtr up to (not including) the terminating
-- Null, adding them to the running total i, and return the final total.
-- Note that DelNode cells are counted just like Node cells, i.e. this is the
-- number of cells still linked into the list, deleted or not.
-- A Head node is not handled here.
cntListHelp :: Show a => IORef (List a) -> Int -> IO Int
cntListHelp curNodePtr i =
   do { curNode <- readIORef curNodePtr
      ; case curNode of
          Null -> return i
          -- ($!) forces the running total at every step, so that we do not
          -- build up a chain of (+1) thunks as long as the list
          Node {next = curnext} ->
                cntListHelp curnext $! i + 1
          DelNode {next = curnext} ->
                cntListHelp curnext $! i + 1
      }