{-# Language GADTs #-}
module Network.AMQP.ChannelAllocator where

import qualified Data.Vector.Mutable as V
import Control.Exception (throwIO)
import Data.Word
import Data.Bits

import Network.AMQP.Types

-- | Bookkeeping for which AMQP channel ids are currently in use.
--
-- Channel id @n@ is represented by bit @n \`mod\` 64@ of bitmap word
-- @n \`div\` 64@; a set bit means the id is allocated.
data ChannelAllocator
    = ChannelAllocator
        Int                 -- ^ highest permitted channel id
        (V.IOVector Word64) -- ^ allocation bitmap, 64 channel ids per word


-- | Create an allocator in which every channel id is free.
--
-- The bitmap has 1024 zeroed words (room for ids 0..65535). @maxChannel@
-- caps which of those ids 'allocateChannel' is willing to hand out.
newChannelAllocator :: Int -> IO ChannelAllocator
newChannelAllocator maxChannel = do
    bitmap <- V.replicate 1024 0
    return (ChannelAllocator maxChannel bitmap)

-- | Claim the lowest free channel id, mark it as used, and return it.
--
-- Throws 'AllChannelsAllocatedException' (carrying the configured maximum)
-- when the bitmap is full or when the lowest free id exceeds the maximum.
--
-- NOTE(review): this is an unsynchronised read-modify-write on the bitmap;
-- callers presumably serialise access — confirm at the call sites.
allocateChannel :: ChannelAllocator -> IO Int
allocateChannel (ChannelAllocator maxChannel c) =
    findFreeIndex c >>= maybe exhausted claim
  where
    exhausted = throwIO (AllChannelsAllocatedException maxChannel)

    -- Take the lowest clear bit of a word known to have one.
    claim chunk = do
        word <- V.read c chunk
        let offset    = findUnsetBit word
            channelID = chunk * 64 + offset
        if channelID > maxChannel
            then exhausted
            else V.write c chunk (setBit word offset) >> return channelID

-- | Release channel id @ix@ so that 'allocateChannel' may hand it out again.
--
-- Returns 'True' if the id was allocated (its bit is now cleared) and
-- 'False' if it was already free. Ids that fall outside the bitmap
-- (negative, or past the last word) are reported as not allocated
-- ('False') instead of failing with an out-of-bounds error from 'V.read'.
freeChannel :: ChannelAllocator -> Int -> IO Bool
freeChannel (ChannelAllocator _maxChannel c) ix
    | chunk < 0 || chunk >= V.length c = return False
    | otherwise = do
        word <- V.read c chunk
        if testBit word offset
            then do
                V.write c chunk (clearBit word offset)
                return True
            else return False
  where
    -- Word index and bit position within that word for this channel id.
    -- divMod keeps offset in 0..63 and makes chunk negative for ix < 0.
    (chunk, offset) = ix `divMod` 64

-- | Index (0..63) of the least significant clear bit of a word.
--
-- Must only be called on a word that is not all ones; a full word is a
-- caller bug and raises an error. (The previous bit-by-bit scan used 65 as
-- its sentinel, so a full word silently yielded the invalid index 64.)
findUnsetBit :: Word64 -> Int
findUnsetBit w
    | ix >= 64  = error "findUnsetBit"
    | otherwise = ix
  where
    -- The lowest clear bit of w is the lowest set bit of its complement;
    -- countTrailingZeros returns 64 when there is none.
    ix = countTrailingZeros (complement w)

-- | Index of the first bitmap word that still has a clear bit (i.e. is not
-- all ones), or 'Nothing' if every word is full.
--
-- The scan is bounded by the vector's actual length rather than a
-- hard-coded word count, so it cannot read past the end of a smaller bitmap.
findFreeIndex :: V.IOVector Word64 -> IO (Maybe Int)
findFreeIndex vec = go 0
  where
    len = V.length vec

    -- TODO: make this faster (currently a linear scan over the words).
    go :: Int -> IO (Maybe Int)
    go ix
        | ix >= len = return Nothing
        | otherwise = do
            v <- V.read vec ix
            if v /= 0xffffffffffffffff
                then return (Just ix)
                else go $! ix + 1