{-# LANGUAGE BangPatterns #-}
module Data.STM.LinkedList.Internal where

import Control.Concurrent.STM
import Data.Maybe (isJust, isNothing)
import System.IO (fixIO)

-- | List handle.  Used for insertion and traversal starting at the beginning
-- or end of the list.
--
-- The handle is a thin wrapper around the list head, a special 'Node' that
-- carries no value; use 'listHead' to unwrap it.
newtype LinkedList a = LinkedList (Node a)

-- | Expose the list head wrapped by a 'LinkedList'.
--
-- The list head is a special 'Node' that carries no value.  It obeys the
-- following properties:
--
-- * @'next' . 'listHead' == 'start'@
--
-- * @'prev' . 'listHead' == 'end'@
--
-- * @'insertBefore' v . 'listHead' == 'append' v@
--
-- * @'insertAfter' v . 'listHead' == 'prepend' v@
--
-- * @'value' . 'listHead' ==> /error/@
--
-- * @'delete' . 'listHead' ==> /error/@
listHead :: LinkedList a -> Node a
listHead list = case list of
    LinkedList sentinel -> sentinel

-- | List node.  Used for insertion, traversal, and removal starting at a given
-- item in the list.
--
-- A Node contains an immutable value of type @a@, and 'TVar's that point to
-- the previous and next nodes.  The list head is also a Node; it is the only
-- kind of node whose 'nodeValue' is 'Nothing'.
--
-- Node equality can be likened to pointer equality in C.  Two Node values are
-- considered equal if and only if they were created with the same insertion
-- operation: the 'Eq' instance compares the 'nodeNext' 'TVar's, and every
-- node is created with its own fresh pair of 'TVar's.
data Node a
    = Node
        { nodePrev  :: NodePtr a
            -- ^ Link to the previous node.  Points back at this very node
            -- when the node has been 'delete'd, or when this is the head of
            -- an empty list.
        , nodeNext  :: NodePtr a
            -- ^ Link to the next node.  Points back at this very node when
            -- the node has been 'delete'd, or when this is the head of an
            -- empty list.
        , nodeValue :: Maybe a
            -- ^ 'Nothing' if this is the list head, otherwise the item
            -- stored in this node.
        }

-- | Mutable link from one 'Node' to a neighbouring 'Node'.
type NodePtr a = TVar (Node a)

-- Identity comparison: two nodes are the same node exactly when they share
-- the same 'nodeNext' 'TVar'.  The stored values are never inspected, so no
-- 'Eq' constraint on @a@ is needed.
instance Eq (Node a) where
    (Node { nodeNext = lhsPtr }) == (Node { nodeNext = rhsPtr }) =
        lhsPtr == rhsPtr

-- | Extract the item stored in a node.
--
-- This is partial: the list head carries no item, so applying 'value' to it
-- calls 'error'.
value :: Node a -> a
value Node { nodeValue = Just v } = v
value _                           = error "LinkedList.value: list head"

-- | /O(1)/. Is the list empty?
--
-- The list is empty exactly when the node following the list head is the
-- list head itself, i.e. a node without a value.
null :: LinkedList a -> STM Bool
null (LinkedList sentinel) =
    fmap (isNothing . nodeValue) (readTVar (nodeNext sentinel))

-- | /O(n)/. Count the number of items in the list.
--
-- Walks the list forwards from the list head, adding one per item visited.
length :: LinkedList a -> STM Int
length (LinkedList sentinel) = foldlHelper bump 0 nodeNext sentinel
  where
    -- Ignore the item itself; only the number of visited nodes matters.
    bump count _ = count + 1

-- | /O(1)/. Create an empty linked list.
--
-- The result wraps a fresh list head whose previous and next links both
-- point back at the list head itself.
empty :: STM (LinkedList a)
empty = do
    -- The links cannot refer to the head before it exists, so they start out
    -- holding a placeholder that is overwritten below before anyone reads it.
    backPtr <- newTVar undefined
    fwdPtr  <- newTVar undefined
    let sentinel = Node
            { nodePrev  = backPtr
            , nodeNext  = fwdPtr
            , nodeValue = Nothing
            }
    mapM_ (`writeTVar` sentinel) [backPtr, fwdPtr]
    return (LinkedList sentinel)

-- | /O(1)/. Version of 'empty' that can be used in the 'IO' monad.
--
-- Uses 'fixIO' so that both links can be created already pointing at the
-- list head they belong to.
emptyIO :: IO (LinkedList a)
emptyIO = fmap LinkedList (fixIO knot)
  where
    -- @self@ is the node being built; it must only be stored, never forced,
    -- inside this action.
    knot self = do
        backPtr <- newTVarIO self
        fwdPtr  <- newTVarIO self
        return Node
            { nodePrev  = backPtr
            , nodeNext  = fwdPtr
            , nodeValue = Nothing
            }

-- | Create a node holding the given item and splice it in between two
-- adjacent nodes.
--
-- The caller supplies the left and right neighbours; this function does not
-- check that they really are adjacent.  Returns the newly created node.
insertBetween :: a -> Node a -> Node a -> STM (Node a)
insertBetween v left right = do
    fresh <- allocate
    -- Redirect both neighbours to the new node.
    writeTVar (nodeNext left) fresh
    writeTVar (nodePrev right) fresh
    return fresh
  where
    -- Build the new node with its links already aimed at the neighbours.
    allocate = do
        backPtr <- newTVar left
        fwdPtr  <- newTVar right
        return Node
            { nodePrev  = backPtr
            , nodeNext  = fwdPtr
            , nodeValue = Just v
            }

-- | /O(1)/. Add a node to the beginning of a linked list.
--
-- The new node is placed between the list head and whatever currently
-- follows it.  Returns the new node.
prepend :: a -> LinkedList a -> STM (Node a)
prepend v (LinkedList sentinel) =
    readTVar (nodeNext sentinel) >>= insertBetween v sentinel

-- | /O(1)/. Add a node to the end of a linked list.
--
-- The new node is placed between whatever currently precedes the list head
-- and the list head itself.  Returns the new node.
append :: a -> LinkedList a -> STM (Node a)
append v (LinkedList sentinel) =
    readTVar (nodePrev sentinel) >>= \lastNode ->
        insertBetween v lastNode sentinel

-- | /O(1)/. Insert an item before the given node.
--
-- Calls 'error' if the given node has been removed from its list.  A removed
-- node is recognised by its previous link pointing back at itself while it
-- still carries a value (the list head of an empty list also points at
-- itself, but has no value).
insertBefore :: a -> Node a -> STM (Node a)
insertBefore v node = readTVar (nodePrev node) >>= place
  where
    place left
        | left == node && isJust (nodeValue node) =
            error "LinkedList.insertBefore: node removed from list"
        | otherwise =
            insertBetween v left node

-- | /O(1)/. Insert an item after the given node.
--
-- Calls 'error' if the given node has been removed from its list.  A removed
-- node is recognised by its next link pointing back at itself while it still
-- carries a value (the list head of an empty list also points at itself, but
-- has no value).
insertAfter :: a -> Node a -> STM (Node a)
insertAfter v node = readTVar (nodeNext node) >>= place
  where
    place right
        | right == node && isJust (nodeValue node) =
            error "LinkedList.insertAfter: node removed from list"
        | otherwise =
            insertBetween v node right

-- | /O(1)/. Remove a node from whatever 'LinkedList' it is in.  If the node
-- has already been removed, this is a no-op.
--
-- Calls 'error' when applied to the list head.
delete :: Node a -> STM ()
delete node = case nodeValue node of
    Nothing -> error "LinkedList.delete: list head"
    Just _  -> do
        before <- readTVar backPtr
        after  <- readTVar fwdPtr

        -- Make the two neighbours point at each other, bypassing this node.
        writeTVar (nodeNext before) after
        writeTVar (nodePrev after) before

        -- Link the node to itself so subsequent 'delete' calls will be
        -- harmless.
        mapM_ (`writeTVar` node) [backPtr, fwdPtr]
  where
    backPtr = nodePrev node
    fwdPtr  = nodeNext node

-- | Follow one link ('nodePrev' or 'nodeNext') from the given node.
--
-- Returns 'Nothing' when the link leads back to the node itself, or when it
-- leads to a node without a value (the list head); otherwise returns the
-- neighbouring node.
stepHelper :: (Node a -> NodePtr a) -> Node a -> STM (Maybe (Node a))
stepHelper step node = readTVar (step node) >>= classify
  where
    classify neighbour
        | neighbour == node            = return Nothing
        | isJust (nodeValue neighbour) = return (Just neighbour)
        | otherwise                    = return Nothing

-- | /O(1)/. Get the node before the given one.  The result is 'Nothing' if
-- the given node is the first item, or if it has been 'delete'd from its
-- list.
prev :: Node a -> STM (Maybe (Node a))
prev node = stepHelper nodePrev node

-- | /O(1)/. Get the node after the given one.  The result is 'Nothing' if
-- the given node is the last item, or if it has been 'delete'd from its
-- list.
next :: Node a -> STM (Maybe (Node a))
next node = stepHelper nodeNext node

-- | /O(1)/. Get the node holding the first item of the list.  The result is
-- 'Nothing' if the list is empty.
start :: LinkedList a -> STM (Maybe (Node a))
start (LinkedList sentinel) = stepHelper nodeNext sentinel

-- | /O(1)/. Get the node holding the last item of the list.  The result is
-- 'Nothing' if the list is empty.
end :: LinkedList a -> STM (Maybe (Node a))
end (LinkedList sentinel) = stepHelper nodePrev sentinel

-- | Traverse list nodes with a fold function.  The traversal begins with the
-- node /after/ the starting node (in the direction given by the step
-- function) and terminates when the list head is reached.
--
-- If the starting node has been 'delete'd from its list, it is linked only
-- to itself and the list head can never be reached; in that case the
-- traversal stops immediately and the initial value is returned.
--
-- This is strict in the accumulator.
foldlHelper :: (a -> b -> a)            -- ^ Fold function
            -> a                        -- ^ Initial value
            -> (Node b -> NodePtr b)    -- ^ Step function ('nodePrev' or 'nodeNext')
            -> Node b                   -- ^ Starting node.  This node's value is not used!
            -> STM a
foldlHelper f z nodeStep start_node =
        loop z start_node
    where
        loop !accum node = do
            node' <- readTVar $ nodeStep node
            case nodeValue node' of
                -- Reached the list head: the traversal is complete.
                Nothing -> return accum
                Just v
                    -- Stepping landed on the same node again.  Only a
                    -- deleted (self-linked) node does this; following it
                    -- further would loop forever, so stop here.
                    | node' == node -> return accum
                    | otherwise     -> loop (f accum v) node'

-- | /O(n)/. Return all of the items in a 'LinkedList', first item first.
--
-- The list is walked backwards from the list head and each item is consed
-- onto the front of the accumulator, so the result comes out in forward
-- order.
toList :: LinkedList a -> STM [a]
toList (LinkedList sentinel) = foldlHelper consOnto [] nodePrev sentinel
  where
    consOnto items item = item : items

-- | /O(n)/. Return all of the items in a 'LinkedList', in reverse order
-- (last item first).
--
-- The list is walked forwards from the list head and each item is consed
-- onto the front of the accumulator, so the result comes out reversed.
toListRev :: LinkedList a -> STM [a]
toListRev (LinkedList sentinel) = foldlHelper consOnto [] nodeNext sentinel
  where
    consOnto items item = item : items