{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards #-}
module Neovim.RPC.SocketReader (
runSocketReader,
parseParams,
) where
import Neovim.Classes
import Neovim.Context
import qualified Neovim.Context.Internal as Internal
import Neovim.Plugin.Classes (CommandArguments (..),
CommandOption (..),
FunctionName (..),
NvimMethod (..),
FunctionalityDescription (..),
getCommandOptions)
import Neovim.Plugin.IPC.Classes
import qualified Neovim.RPC.Classes as MsgpackRPC
import Neovim.RPC.Common
import Neovim.RPC.FunctionCall
import Control.Applicative
import Control.Concurrent.STM
import Control.Monad (void)
import Control.Monad.Trans.Class (lift)
import Conduit as C
import Data.Conduit.Cereal (conduitGet2)
import Data.Default (def)
import Data.Foldable (foldl', forM_)
import qualified Data.Map as Map
import Data.MessagePack
import Data.Monoid
import qualified Data.Serialize (get)
import System.IO (Handle)
import System.Log.Logger
import UnliftIO.Async (async, race)
import UnliftIO.Concurrent (threadDelay)
import Prelude
-- | Logger name under which all messages of this module are reported.
logger :: String
logger = "Socket Reader"

-- | Monad in which the socket reader's message handlers run: 'Neovim'
-- specialised to the 'RPCConfig' environment.
type SocketHandler = Neovim RPCConfig
-- | Continuously read msgpack-encoded objects from the given handle and
-- feed each decoded 'Object' into 'messageHandlerSink'.
--
-- The pipeline runs in 'Neovim' with the RPC-specific part of the supplied
-- configuration as its custom environment. The result of 'runNeovim' is
-- discarded.
runSocketReader :: Handle
                -> Internal.Config RPCConfig
                -> IO ()
runSocketReader readableHandle cfg =
    void $ runNeovim readerConfig (runConduit pipeline)
  where
    -- Re-type the configuration so its custom environment is the RPCConfig.
    readerConfig = Internal.retypeConfig (Internal.customConfig cfg) cfg

    -- Raw bytes -> msgpack objects -> message dispatch.
    pipeline =
        sourceHandle readableHandle
            .| conduitGet2 Data.Serialize.get
            .| messageHandlerSink
-- | Sink that decodes every incoming msgpack 'Object' as an RPC message and
-- dispatches it.
--
-- * Requests and notifications go to 'handleRequestOrNotification'.
-- * Responses go to 'handleResponse'.
-- * Objects that cannot be decoded are logged at error level and otherwise
--   ignored.
--
-- Every received object is logged at debug level first.
messageHandlerSink :: ConduitT Object Void SocketHandler ()
messageHandlerSink = awaitForever dispatch
  where
    dispatch rpc = do
        liftIO . debugM logger $ "Received: " <> show rpc
        either reportUnhandled route (fromObject rpc)

    route (MsgpackRPC.Request (Request fn i ps)) =
        handleRequestOrNotification (Just i) fn ps
    route (MsgpackRPC.Response i r) =
        handleResponse i r
    route (MsgpackRPC.Notification (Notification fn ps)) =
        handleRequestOrNotification Nothing fn ps

    reportUnhandled e = liftIO . errorM logger $
        "Unhandled rpc message: " <> show e
-- | Deliver the result of an RPC response to the recipient registered under
-- the response's message ID in 'recipients'.
--
-- A single STM transaction does three things:
--
-- 1. looks up the recipient,
-- 2. removes it from the map,
-- 3. puts the result into its reply 'TMVar'.
--
-- No other thread can therefore observe or change the entry between the
-- lookup and the removal. If no recipient is registered for the ID, a
-- warning is logged and the response is dropped.
handleResponse :: Int64 -> Either Object Object
               -> ConduitT a Void SocketHandler ()
handleResponse i result = do
    answerMap <- asks recipients
    delivered <- atomically' $ do
        mReply <- Map.lookup i <$> readTVar answerMap
        case mReply of
            Nothing -> return False
            Just (_, reply) -> do
                modifyTVar' answerMap (Map.delete i)
                putTMVar reply result
                return True
    if delivered
        then return ()
        else liftIO $ warningM logger
            "Received response but could not find a matching recipient."
-- | Dispatch an incoming request (@Just requestId@) or notification
-- (@Nothing@) to the plugin function registered under the given name.
--
-- The work happens in a separate 'async' thread. That thread races the
-- handler against a 10 second timer; if the timer finishes first, a debug
-- message is logged and the handler is cancelled. The conduit itself never
-- blocks on the handler.
--
-- * If no function is registered under the name and this is a request, an
--   error response is written to the event queue. Notifications to unknown
--   functions are only logged.
-- * If a stateful function is found, the call (with parameters converted by
--   'parseParams') is written to that function's queue. For requests, a
--   fresh reply 'TMVar' and the current time are first recorded under the
--   request ID in 'recipients'.
handleRequestOrNotification :: Maybe Int64 -> FunctionName -> [Object]
                            -> ConduitT a Void SocketHandler ()
handleRequestOrNotification requestId functionToCall@(F functionName) params = do
    cfg <- lift Internal.ask'
    void . liftIO . async $ race logTimeout (handle cfg)
  where
    lookupFunction
        :: TMVar Internal.FunctionMap
        -> STM (Maybe (FunctionalityDescription, Internal.FunctionType))
    lookupFunction funMap =
        Map.lookup (NvimMethod functionName) <$> readTMVar funMap

    logTimeout :: IO ()
    logTimeout = do
        let seconds = 1000 * 1000
        threadDelay (10 * seconds)
        debugM logger "Cancelled another action before it was finished"

    handle :: Internal.Config RPCConfig -> IO ()
    handle rpc = atomically (lookupFunction (Internal.globalFunctionMap rpc)) >>= \case
        Nothing -> do
            let errM = "No provider for: " <> show functionToCall
            debugM logger errM
            forM_ requestId $ \i -> writeMessage (Internal.eventQueue rpc) $
                MsgpackRPC.Response i (Left (toObject errM))
        Just (copts, Internal.Stateful c) -> do
            debugM logger $ "Executing stateful function with ID: " <> show requestId
            case requestId of
                Just i -> do
                    -- Timestamp and reply slot are only needed for requests.
                    now <- getCurrentTime
                    reply <- newEmptyTMVarIO
                    let q = (recipients . Internal.customConfig) rpc
                    -- NOTE(review): this stores an incoming request's ID in the
                    -- same 'recipients' map that 'handleResponse' consults for
                    -- response IDs; confirm the two ID spaces cannot collide.
                    atomically' . modifyTVar' q $ Map.insert i (now, reply)
                    writeMessage c $ Request functionToCall i (parseParams copts params)
                Nothing ->
                    writeMessage c $ Notification functionToCall (parseParams copts params)
-- | Convert the raw argument list of an incoming RPC call into the argument
-- list expected by the registered function.
--
-- * 'Function' and 'Autocmd': a single 'ObjectArray' argument is unwrapped
--   into its elements; any other argument list is passed through unchanged.
-- * 'Command', arguments already starting with an 'ObjectArray': each
--   command option passed via RPC (every option except 'CmdSync') is paired
--   positionally with an argument.
--   - Range, count, bang and register values are decoded into a
--     'CommandArguments' record. Values that fail to decode are ignored.
--   - The @nargs@ option determines the remaining arguments.
--   The result is the encoded 'CommandArguments' followed by those remaining
--   arguments.
-- * 'Command', any other arguments: they are first wrapped in a single
--   'ObjectArray'.
parseParams :: FunctionalityDescription -> [Object] -> [Object]
parseParams (Command _ opts) args@(ObjectArray _ : _) =
    toObject parsedArgs : extraArgs
  where
    rpcOptions = filter viaRPC (getCommandOptions opts)

    (parsedArgs, extraArgs) = foldl' step (def, []) (zip rpcOptions args)

    viaRPC :: CommandOption -> Bool
    viaRPC CmdSync{} = False
    viaRPC _         = True

    step :: (CommandArguments, [Object])
         -> (CommandOption, Object)
         -> (CommandArguments, [Object])
    step acc@(ca, rest) (CmdRange _, o) =
        either (const acc) (\r -> (ca { range = Just r }, rest)) (fromObject o)
    step acc@(ca, rest) (CmdCount _, o) =
        either (const acc) (\n -> (ca { count = Just n }, rest)) (fromObject o)
    step acc@(ca, rest) (CmdBang, o) =
        either (const acc) (\b -> (ca { bang = Just b }, rest)) (fromObject o)
    step (ca, _) (CmdNargs "*", ObjectArray os) =
        (ca, os)
    step (ca, _) (CmdNargs "+", ObjectArray (o:os)) =
        (ca, o : [ObjectArray os])
    step (ca, _) (CmdNargs "?", ObjectArray [o]) =
        (ca, [toObject (Just o)])
    step (ca, _) (CmdNargs "?", ObjectArray []) =
        (ca, [toObject (Nothing :: Maybe Object)])
    step (ca, _) (CmdNargs "0", ObjectArray []) =
        (ca, [])
    step (ca, _) (CmdNargs "1", ObjectArray [o]) =
        (ca, [o])
    step acc@(ca, rest) (CmdRegister, o) =
        either (const acc) (\r -> (ca { register = Just r }, rest)) (fromObject o)
    step acc _ = acc
parseParams cmd@(Command _ _) args =
    -- Not yet wrapped in an array: wrap and handle as above.
    parseParams cmd [ObjectArray args]
parseParams _ [ObjectArray fArgs] = fArgs
parseParams _ args = args