{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}

module System.Semaphore
  ( -- * System semaphores

    Semaphore(..), SemaphoreName(..)
  , createSemaphore, freshSemaphore, openSemaphore
  , waitOnSemaphore, tryWaitOnSemaphore
  , WaitId(..)
  , forkWaitOnSemaphoreInterruptible
  , interruptWaitOnSemaphore
  , getSemaphoreValue
  , releaseSemaphore
  , destroySemaphore

  -- * Abstract semaphores

  , AbstractSem(..)
  , withAbstractSem
  ) where

-- base

import Control.Concurrent
import Control.Monad
import Data.List.NonEmpty ( NonEmpty(..) )
import GHC.Exts ( Char(..), Int(..), indexCharOffAddr# )

-- exceptions

import qualified Control.Monad.Catch as MC

#if defined(mingw32_HOST_OS)
-- Win32

import qualified System.Win32.Event     as Win32
  ( createEvent, setEvent
  , waitForSingleObject, waitForMultipleObjects
  , wAIT_OBJECT_0 )
import qualified System.Win32.File      as Win32
  ( closeHandle )
import qualified System.Win32.Process   as Win32
  ( iNFINITE )
import qualified System.Win32.Semaphore as Win32
  ( Semaphore(..), sEMAPHORE_ALL_ACCESS
  , createSemaphore, openSemaphore, releaseSemaphore )
import qualified System.Win32.Time      as Win32
  ( FILETIME(..), getSystemTimeAsFileTime )
import qualified System.Win32.Types     as Win32
  ( HANDLE, errorWin )
#else
-- base

import Foreign.C.Types
  ( CClock(..) )

-- unix

import qualified System.Posix.Semaphore as Posix
  ( Semaphore, OpenSemFlags(..)
  , semOpen, semWaitInterruptible, semTryWait, semThreadWait
  , semGetValue, semPost, semUnlink )
import qualified System.Posix.Files     as Posix
  ( stdFileMode )
import qualified System.Posix.Process   as Posix
  ( ProcessTimes(systemTime), getProcessTimes )
#endif

---------------------------------------

-- System-specific semaphores


-- | The name of a system semaphore, wrapping the underlying 'String'.
--
-- This is the name passed to the operating system when creating, opening
-- or (on POSIX) unlinking a semaphore.
newtype SemaphoreName =
  SemaphoreName { getSemaphoreName :: String }
  deriving Eq

-- | A system semaphore (POSIX or Win32).
--
-- Pairs the platform-specific semaphore object with the name it was
-- created or opened under.
data Semaphore =
  Semaphore
    { semaphoreName :: !SemaphoreName
    , semaphore     ::
#if defined(mingw32_HOST_OS)
      !Win32.Semaphore
#else
      !Posix.Semaphore
#endif
    }

-- | Create a new semaphore with the given name and initial amount of
-- available resources.
--
-- Throws an error if a semaphore by this name already exists.
createSemaphore :: SemaphoreName
                -> Int -- ^ number of tokens on the semaphore
                -> IO Semaphore
createSemaphore (SemaphoreName sem_name) init_toks = do
  mb_sem <- create_sem sem_name init_toks
  case mb_sem of
    -- 'create_sem' hands back the failure as an IO action; running it
    -- is what actually raises the error.
    Left  err -> err
    Right sem -> return sem

-- | Create a fresh semaphore with the given amount of tokens.
--
-- Its name will start with the given prefix, followed by an underscore and
-- a suffix taken from 'random_strings'. If creation fails, the next suffix
-- is tried; after the retry budget is exhausted, the error from the last
-- attempt is thrown.
freshSemaphore :: String -- ^ prefix
               -> Int    -- ^ number of tokens on the semaphore
               -> IO Semaphore
freshSemaphore prefix init_toks = do
  suffixes <- random_strings
  go 0 suffixes
  where
    -- @go i suffs@: @i@ counts the retries already performed.
    go :: Int -> NonEmpty String -> IO Semaphore
    go i (suffix :| suffs) = do
      mb_sem <- create_sem (prefix ++ "_" ++ suffix) init_toks
      case mb_sem of
        Right sem -> return sem
        Left  err
          | next : nexts <- suffs
          , i < 32 -- retry at most 32 times before giving up
          -> go (i+1) (next :| nexts)
          | otherwise
          -> err

-- | Internal worker shared by 'createSemaphore' and 'freshSemaphore'.
--
-- Attempts to exclusively create a semaphore with the given name and
-- initial token count. Instead of throwing, a failure is returned as
-- @Left act@, where running @act@ raises the error; this lets callers
-- decide whether to retry or to propagate the failure.
create_sem :: String -> Int -> IO (Either (IO Semaphore) Semaphore)
create_sem sem_str init_toks = do
#if defined(mingw32_HOST_OS)
  let toks = fromIntegral init_toks
  mb_sem <- MC.try @_ @MC.SomeException $
    Win32.createSemaphore Nothing toks toks (Just sem_str)
  return $ case mb_sem of
    Right (sem, exists)
      | exists
      -> Left (Win32.errorWin $ "semaphore-compat: semaphore " ++ sem_str ++ " already exists")
      | otherwise
      -> Right $ mk_sem sem
    Left err
      -> Left $ MC.throwM err
#else
  let flags =
        Posix.OpenSemFlags
          { Posix.semCreate    = True
          , Posix.semExclusive = True }
  mb_sem <- MC.try @_ @MC.SomeException $
    Posix.semOpen sem_str flags Posix.stdFileMode init_toks
  return $ case mb_sem of
    Left  err -> Left $ MC.throwM err
    Right sem -> Right $ mk_sem sem
#endif
  where
    sem_nm = SemaphoreName sem_str
    -- Wrap the platform-specific semaphore together with its name.
    mk_sem sem =
      Semaphore
        { semaphore     = sem
        , semaphoreName = sem_nm }

-- | Open a semaphore with the given name.
--
-- If no such semaphore exists, throws an error.
openSemaphore :: SemaphoreName -> IO Semaphore
openSemaphore nm@(SemaphoreName sem_name) = do
#if defined(mingw32_HOST_OS)
  sem <- Win32.openSemaphore Win32.sEMAPHORE_ALL_ACCESS True sem_name
#else
  let
    -- Neither create nor require exclusivity: only open what is there.
    flags = Posix.OpenSemFlags
          { Posix.semCreate    = False
          , Posix.semExclusive = False }
  sem <- Posix.semOpen sem_name flags Posix.stdFileMode 0
#endif
  return $
    Semaphore
      { semaphore     = sem
      , semaphoreName = nm }

-- | Indefinitely wait on a semaphore.
--
-- If you want to be able to cancel a wait operation, use
-- 'forkWaitOnSemaphoreInterruptible' instead.
waitOnSemaphore :: Semaphore -> IO ()
waitOnSemaphore (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  MC.mask_ $ do
    () <$ Win32.waitForSingleObject (Win32.semaphoreHandle sem) Win32.iNFINITE
#else
  Posix.semThreadWait sem
#endif

-- | Try to obtain a token from the semaphore, without blocking.
--
-- Immediately returns 'False' if no resources are available.
tryWaitOnSemaphore :: Semaphore -> IO Bool
tryWaitOnSemaphore (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  MC.mask_ $ do
    wait_res <- Win32.waitForSingleObject (Win32.semaphoreHandle sem) 0
    return $ wait_res == Win32.wAIT_OBJECT_0
#else
  Posix.semTryWait sem
#endif

-- | Release a semaphore: add @n@ to its internal counter.
--
-- No-op when @n <= 0@.
--
-- The release runs with asynchronous exceptions masked.
releaseSemaphore :: Semaphore -> Int -> IO ()
releaseSemaphore (Semaphore { semaphore = sem }) n
  | n <= 0
  = return ()
  | otherwise
  = MC.mask_ $ do
#if defined(mingw32_HOST_OS)
    void $ Win32.releaseSemaphore sem (fromIntegral n)
#else
    -- POSIX: post to the semaphore once per token.
    replicateM_ n (Posix.semPost sem)
#endif

-- | Destroy the given semaphore.
--
-- On Windows this closes the semaphore handle; on POSIX it unlinks the
-- semaphore's name.
destroySemaphore :: Semaphore -> IO ()
destroySemaphore sem =
#if defined(mingw32_HOST_OS)
  Win32.closeHandle (Win32.semaphoreHandle $ semaphore sem)
#else
  Posix.semUnlink (getSemaphoreName $ semaphoreName sem)
#endif

-- | Query the current semaphore value (how many tokens it has available).
--
-- This is mainly for debugging use, as it is easy to introduce race conditions
-- when nontrivial program logic depends on the value returned by this function.
getSemaphoreValue :: Semaphore -> IO Int
getSemaphoreValue (Semaphore { semaphore = sem }) =
#if defined(mingw32_HOST_OS)
  MC.mask_ $ do
    wait_res <- Win32.waitForSingleObject (Win32.semaphoreHandle sem) 0
    if wait_res == Win32.wAIT_OBJECT_0
      -- We were able to acquire a resource from the semaphore without waiting:
      -- release it immediately, thus obtaining the total number of available
      -- resources.
    then
      (+1) . fromIntegral <$> Win32.releaseSemaphore sem 1
    else
      return 0
#else
  Posix.semGetValue sem
#endif

-- | 'WaitId' stores the information we need to cancel a thread
-- which is waiting on a semaphore: the 'ThreadId' of the waiting thread
-- and, on Windows, the event handle used to signal cancellation.
--
-- See 'forkWaitOnSemaphoreInterruptible' and 'interruptWaitOnSemaphore'.
data WaitId = WaitId { waitingThreadId :: ThreadId
#if defined(mingw32_HOST_OS)
                     , cancelHandle    :: Win32.HANDLE
#endif
                     }

-- | Spawn a thread that waits on the given semaphore.
--
-- In this thread, asynchronous exceptions will be masked.
--
-- Once the wait finishes, the supplied action is run in the spawned thread
-- with the outcome: @Right True@ if a token was obtained, @Right False@ if
-- the wait ended without obtaining one, or @Left e@ if the wait threw.
--
-- The waiting operation can be interrupted using the
-- 'interruptWaitOnSemaphore' function.
forkWaitOnSemaphoreInterruptible
  :: Semaphore
  -> ( Either MC.SomeException Bool -> IO () ) -- ^ wait result action
  -> IO WaitId
forkWaitOnSemaphoreInterruptible
  (Semaphore { semaphore = sem })
  wait_result_action = do
#if defined(mingw32_HOST_OS)
    cancelHandle <- Win32.createEvent Nothing True False ""
#endif
    let
      interruptible_wait :: IO Bool
      interruptible_wait =
#if defined(mingw32_HOST_OS)
        -- Windows: wait on both the handle used for cancelling the wait
        -- and on the semaphore.
          do
            wait_res <-
              Win32.waitForMultipleObjects
                [ Win32.semaphoreHandle sem
                , cancelHandle ]
                False -- False <=> WaitAny
                Win32.iNFINITE
            return $ wait_res == Win32.wAIT_OBJECT_0
            -- Only in the case that the wait result is WAIT_OBJECT_0 will
            -- we have succeeded in obtaining a token from the semaphore.
#else
        -- POSIX: use the 'semWaitInterruptible' interruptible FFI call
        -- that can be interrupted when we send a killThread signal.
          Posix.semWaitInterruptible sem
#endif
    waitingThreadId <- forkIO $ MC.mask_ $ do
      wait_res <- MC.try interruptible_wait
      wait_result_action wait_res
    -- RecordWildCards: picks up 'waitingThreadId' (and, on Windows,
    -- 'cancelHandle') from the local bindings above.
    return $ WaitId { .. }

-- | Interrupt a semaphore wait operation initiated by
-- 'forkWaitOnSemaphoreInterruptible'.
interruptWaitOnSemaphore :: WaitId -> IO ()
interruptWaitOnSemaphore ( WaitId { .. } ) = do
#if defined(mingw32_HOST_OS)
  Win32.setEvent cancelHandle
    -- On Windows, we signal to stop waiting.
#endif
  killThread waitingThreadId
    -- On POSIX, killing the thread will cancel the wait on the semaphore
    -- due to the FFI call being interruptible ('semWaitInterruptible').


---------------------------------------

-- Abstract semaphores


-- | Abstraction over the operations of a semaphore: a pair of actions,
-- one to acquire and one to release.
data AbstractSem =
  AbstractSem
    { acquireSem :: IO ()
    , releaseSem :: IO ()
    }

-- | Run an action bracketed by the semaphore's operations: 'acquireSem'
-- is run before the action and 'releaseSem' after it, using 'MC.bracket_'.
withAbstractSem :: AbstractSem -> IO b -> IO b
withAbstractSem sem = MC.bracket_ (acquireSem sem) (releaseSem sem)

---------------------------------------

-- Utility


-- | Render the absolute value of an 'Int' in base 62, using the digits
-- @0-9@, @a-z@, @A-Z@ (in that order), most significant digit first.
--
-- 'minBound' has no representable absolute value, so it is rendered
-- the same as 'maxBound'.
iToBase62 :: Int -> String
iToBase62 m = go m' ""
  where
    -- Non-negative stand-in for the input.
    m'
      | m == minBound
      = maxBound
      | otherwise
      = abs m
    -- Peel off the least significant digit and prepend it to the
    -- accumulator, so the result comes out most-significant-first.
    go n cs | n < 62
            = let !c = chooseChar62 n
              in c : cs
            | otherwise
            = let !(!q, r) = quotRem n 62
                  !c       = chooseChar62 r
              in go q (c : cs)

    -- Look up a digit in the unboxed string literal; the argument must be
    -- in the range @[0, 61]@, which the callers above ensure.
    chooseChar62 :: Int -> Char
    {-# INLINE chooseChar62 #-}
    chooseChar62 (I# n) = C# (indexCharOffAddr# chars62 n)
    chars62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"#

-- | An infinite, non-empty stream of candidate name suffixes.
--
-- The strings are derived from a time reading (the system time as a
-- file time on Windows, the process's system CPU time on POSIX): the
-- @i@-th element is the base-62 rendering of that reading plus @i@.
random_strings :: IO (NonEmpty String)
random_strings = do
#if defined(mingw32_HOST_OS)
  Win32.FILETIME t <- Win32.getSystemTimeAsFileTime
#else
  CClock t <- Posix.systemTime <$> Posix.getProcessTimes
#endif
  return $ fmap ( \ i -> iToBase62 (i + fromIntegral t) ) (0 :| [1..])