{-# LANGUAGE
     BangPatterns
   , DataKinds
   , KindSignatures
   , LambdaCase
   #-}

{-# LANGUAGE NoIncoherentInstances #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE NoUndecidableInstances #-}

module Vivid.SCServer.State (
     BufferId(..)
   , NodeId(..)
   , SyncId(..)

   , SCServerState(..)

   , ConnProtocol(..)

   , setServerClientId

   , setServerMaxBufferIds

   , numberOfSyncIdsToDrop

   , makeEmptySCServerState

   -- We might not need to export these (or the global equivalents) at all:
   , getNextAvailable'
   , getNextAvailables'
   ) where

import Vivid.OSC (OSC)
import Vivid.SC.Server.Types
import Vivid.SynthDef.Types

import Network.Socket (Socket)

import Control.Concurrent (ThreadId)
import Control.Concurrent.MVar
import Control.Concurrent.STM -- (readTVar, atomically, writeTVar, newTVar, TVar, TMVar)
import Control.Monad (when)
import Data.Bits
import Data.Int (Int32)
-- import Data.IORef
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Set (Set)
import Prelude

-- | Mutable state kept for a connection to a SuperCollider server.
--
--   Every field is a 'TVar' or a 'TMVar'. 'makeEmptySCServerState' builds
--   the initial value.
data SCServerState
   = SCServerState
   -- We use 'IORef Maybe's instead of MVars so we can use weak pointer
   --   finalizers with older versions of GHC:
   -- NOTE(review): the fields below are 'TVar's/'TMVar's, so the remark
   --   above looks stale -- confirm before relying on it.
  { -- Starts out 'False'; presumably flipped once a connection attempt has
    -- begun -- confirm against the connecting code.
    _scServerState_socketConnectStarted :: TVar Bool
    -- Starts out empty.
  , _scServerState_socket :: !(TMVar Socket) -- !(TVar (Maybe Socket))
    -- Starts out empty.
  , _scServerState_listener :: !(TMVar ThreadId) -- !(TVar (Maybe ThreadId))

    -- Pool of buffer ids that have not been handed out yet.
  , _scServerState_availableBufferIds :: !(TVar [BufferId])
    -- Overwritten by 'setServerMaxBufferIds'.
  , _scServerState_maxBufIds :: !(TVar Int32)
    -- Pool of node ids that have not been handed out yet; replaced wholesale
    -- by 'setServerClientId'.
  , _scServerState_availableNodeIds :: !(TVar [NodeId])
    -- Pool of sync ids that have not been handed out yet.
  , _scServerState_availableSyncIds :: !(TVar [SyncId])
    -- One 'MVar' per sync id; starts out as an empty map.
  , _scServerState_syncIdMailboxes :: !(TVar (Map SyncId (MVar ())))
    -- Starts out as a function that ignores its argument.
  , _scServerState_serverMessageFunction :: !(TVar (OSC -> IO ()))
  , _scServerState_definedSDs :: !(TVar (Set (SDName, Int))) -- Int is the hash
  , _scServerState_connProtocol :: TVar ConnProtocol -- This doesn't change after boot, but we could e.g. disconnect and reconnect after boot
  }

-- | Transport protocol used for the connection to the server.
data ConnProtocol
   = ConnProtocol_UDP
   | ConnProtocol_TCP
 deriving (Show, Read, Eq, Ord)

-- | Replace the pool of available node ids with one whose ids all carry the
--   given client id.
--
--   The client id must lie between 0 and 31 inclusive; any other value
--   raises an 'error' before the pool is touched.
--
--   Ids are built from a counter that starts at 1000.
setServerClientId :: SCServerState -> Int32 -> IO ()
setServerClientId serverState clientId = do
   when (clientId < 0 || clientId > 31) $
      error "client id must be betw 0 and 31"
   atomically $ writeTVar (_scServerState_availableNodeIds serverState) $
      -- The client id is the first 5 bits of a positive int:
      -- Note the incrementing gets weird once we hit the (.&.) -- should
      -- fix if anyone plans to use more than 33 million nodes
      map f [1000..]
 where
   -- Layout of the resulting 32-bit id: the top (sign) bit stays clear, the
   -- next 5 bits hold the client id (shifted left by 32 - 5 - 1 = 26), and
   -- the low 26 bits hold the counter. The mask @maxBound `shiftR` 5@ keeps
   -- only those low 26 bits, so counters past 2^26 - 1 wrap around.
   f :: Int32 -> NodeId
   f nodeNum = NodeId $
         ((clientId `shiftL` ((finiteBitSize nodeNum - 5) - 1)) .|.) $
            ((maxBound `shiftR` 5) .&. nodeNum)

-- | How many sync ids, counting up from 0, are skipped when
--   'makeEmptySCServerState' builds the initial pool of available sync ids.
numberOfSyncIdsToDrop :: Int
numberOfSyncIdsToDrop = 10000

-- | Build a fresh 'SCServerState':
--
--   * the connect-started flag is 'False' and the socket and listener slots
--     are empty;
--   * buffer ids start at 512 and the maximum is 1024;
--   * sync ids start at 'numberOfSyncIdsToDrop';
--   * node ids are a placeholder pool (see below);
--   * no sync mailboxes and no defined synthdefs are recorded;
--   * the server-message function ignores its argument;
--   * the protocol is 'ConnProtocol_UDP'.
makeEmptySCServerState :: IO SCServerState
-- We don't do this with 'atomically' because you can't put 'atomically' in
--   'unsafePerformIO' (or, apparently you can with the "!_ =" hack I was
--   doing, but let's do the recommended way):
makeEmptySCServerState = do -- atomically $ do
   sockConnectStarted <- newTVarIO False
   sockVar <- newEmptyTMVarIO -- newTVar Nothing -- newIORef Nothing
   listenerVar <- newEmptyTMVarIO -- newTVar Nothing -- newIORef Nothing

   availBufIds <- newTVarIO $ drop 512 $ map BufferId [0..]
   -- these'll be allocated when we connect (and get a clientId):
   -- (until then, the pool uses the same bit layout as 'setServerClientId'
   -- would produce for client id 1, counting up from 1000)
   availNodeIds <- newTVarIO $ map (NodeId . ((1 `shiftL` 26) .|.)) [1000..]
   maxBufIds <- newTVarIO 1024
   syncIds <- newTVarIO $ drop numberOfSyncIdsToDrop $ map SyncId [0..]
   syncMailboxes <- newTVarIO Map.empty
   serverMessageFunction <- newTVarIO $ \_ -> return ()
   definedSDs <- newTVarIO Set.empty
   connProtocolVar <- newTVarIO ConnProtocol_UDP

   return $ SCServerState
          { _scServerState_socketConnectStarted = sockConnectStarted
          , _scServerState_socket = sockVar
          , _scServerState_listener = listenerVar
          , _scServerState_availableBufferIds = availBufIds
          , _scServerState_maxBufIds = maxBufIds
          , _scServerState_availableNodeIds = availNodeIds
          , _scServerState_availableSyncIds = syncIds
          , _scServerState_syncIdMailboxes = syncMailboxes
          , _scServerState_serverMessageFunction = serverMessageFunction
          , _scServerState_definedSDs = definedSDs
          , _scServerState_connProtocol = connProtocolVar
          }

-- | If you've started the SC server with a non-default number of buffer ids,
--   (e.g. with the \"-b\" argument), you can reflect that here
--
--   Note that the buffer ids start at 512, to not clash with any that
--   another client (e.g. sclang) has allocated
--
--   This only overwrites the stored maximum; the list of available buffer
--   ids is left as it is.
setServerMaxBufferIds :: SCServerState -> Int32 -> IO ()
setServerMaxBufferIds serverState newMax =
   atomically $
      writeTVar (_scServerState_maxBufIds serverState) newMax

-- | Take a single id off the front of the pool selected by the getter.
--
--   This asks 'getNextAvailables'' for one element and unwraps it. If the
--   result is not exactly one element (i.e. the pool was empty), it raises
--   an 'error'.
getNextAvailable' :: SCServerState -> (SCServerState -> TVar [a]) -> IO a
getNextAvailable' serverState getter =
   getNextAvailables' serverState 1 getter >>= \case
      [x] -> return x
      _ -> error "i don't even - 938"

-- | Take up to @numToGet@ ids off the front of the pool selected by the
--   getter, in a single STM transaction, and return them in pool order.
--
--   The remainder is written back to the pool. If the pool holds fewer than
--   @numToGet@ elements, the result is correspondingly shorter; a
--   non-positive @numToGet@ yields an empty list and leaves the pool
--   contents unchanged.
getNextAvailables' :: SCServerState -> Int -> (SCServerState -> TVar [a]) -> IO [a]
getNextAvailables' serverState numToGet getter = do
   -- Force the state record to weak head normal form before entering the
   -- transaction.
   -- NOTE(review): presumably this is the "!_ =" hack mentioned next to
   --   'makeEmptySCServerState' (for state created via 'unsafePerformIO')
   --   -- confirm before removing.
   let !_ = serverState
   atomically $ do
      let avail = getter serverState
      (ns, rest) <- splitAt numToGet <$> readTVar avail
      writeTVar avail rest
      return ns