-- |
-- Stability   :  Ultra-Violence
-- Portability :  I'm too young to die
-- Listening on sockets for the incoming requests.
{-# LANGUAGE OverloadedStrings, ScopedTypeVariables #-}

module Network.NineP.Server
	( module Network.NineP.Internal.File
	, Config(..)
	, run9PServer
	) where

import Control.Concurrent
import Control.Concurrent.MState hiding (get, put)
import Control.Exception (assert)
import Control.Monad
import Control.Monad.Catch
import Control.Monad.EmbedIO
import Control.Monad.Loops
import Control.Monad.Reader
import Control.Monad.Trans
import Data.Binary.Get
import Data.Binary.Put
import Data.Bits
import qualified Data.ByteString as BS
import Data.ByteString.Lazy.Char8 (ByteString)
import qualified Data.ByteString.Lazy.Char8 as B
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import Data.NineP
import Data.Word
import Network.BSD
import Network.Socket hiding (send, sendTo, recv, recvFrom)
import System.IO
import System.Log.Logger
import Text.Read (readMaybe)
import Text.Regex.Posix ((=~))

import Network.NineP.Error
import Network.NineP.Internal.File
import Network.NineP.Internal.Msg
import Network.NineP.Internal.State

-- | Parse a value with its 'Read' instance.
--
-- Returns 'Nothing' unless the whole string, apart from surrounding
-- whitespace, is exactly one parse. Taking the first result of 'reads'
-- would silently accept trailing garbage (e.g. @"12abc"@ as @Just 12@).
maybeRead :: Read a => String -> Maybe a
maybeRead = readMaybe

-- | Create a Unix-domain stream socket, bind it to @addr@ and start
-- listening on it with a backlog of 5.
--
-- The address family is fixed to 'AF_UNIX', so @addr@ is expected to be a
-- 'SockAddrUnix'. If 'bind' or 'listen' throws (for instance because the
-- socket path is already in use), the new socket is closed before the
-- exception propagates instead of leaking the descriptor.
listenOn :: SockAddr -> IO Socket
listenOn addr =
    bracketOnError (socket AF_UNIX Stream defaultProtocol) close $ \sock -> do
        bind sock addr
        listen sock 5
        return sock

-- | Open a listening socket for a 9P address string.
--
-- The whole string must have one of these forms:
--
-- * @tcp!HOST!PORT@: a TCP socket on @HOST@; an empty @PORT@ selects 2358.
--
-- * @unix!PATH@: a Unix-domain socket bound to @PATH@.
--
-- Any other string, and any TCP port above 65535, fails with an 'IOError'
-- built by 'userError' ("wrong 9p connection address: ...").
connection :: String -> IO Socket
connection s
    | bef /= "" || aft /= "" || null grps = wrongAddr
    | otherwise = case grps of
        -- All three groups are always returned; the group belonging to the
        -- alternative that did not match comes back empty.
        [host, port, ""] -> maybe wrongAddr (listen' host) (parsePort port)
        ["", "", path] -> listenOn $ SockAddrUnix path
        _ -> wrongAddr
  where
    pat = "tcp!(.*)!([0-9]*)|unix!(.*)" :: ByteString
    (bef, _, aft, grps) = s =~ pat :: (String, String, String, [String])

    wrongAddr :: IO Socket
    wrongAddr = ioError $ userError $ "wrong 9p connection address: " ++ s

    -- The regex admits only digits here. Parse as 'Integer' so long digit
    -- strings cannot wrap around, and reject anything outside the 16-bit
    -- port range instead of letting 'toEnum' raise an 'ErrorCall'.
    parsePort :: String -> Maybe PortNumber
    parsePort "" = Just 2358
    parsePort digits = case maybeRead digits :: Maybe Integer of
        Just n | n <= 65535 -> Just (fromInteger n)
        _ -> Nothing

-- | Open an IPv4 TCP listening socket on @hostname@ and @port@.
--
-- The host name is resolved with 'getHostByName' and the entry's
-- 'hostAddress' is bound. 'ReuseAddr' is set before binding, and the
-- listen backlog is 'maxListenQueue'. If any step after the socket is
-- created throws, the socket is closed before the exception propagates.
listen' :: HostName -> PortNumber -> IO Socket
listen' hostname port = do
    tcp <- getProtocolNumber "tcp"
    bracketOnError (socket AF_INET Stream tcp) close bindAndListen
  where
    bindAndListen sock = do
        setSocketOption sock ReuseAddr 1
        entry <- getHostByName hostname
        bind sock $ SockAddrInet port $ hostAddress entry
        listen sock maxListenQueue
        pure sock

-- |Run the actual server using the supplied configuration.
--
-- Opens the listening socket described by the config's 'addr' field (see
-- 'connection' for the accepted syntax) and serves clients on it. Serving
-- never returns normally. When it is ended by an exception, the listening
-- socket is closed before the exception propagates, so the port or socket
-- is not left bound by a dead server.
run9PServer :: (EmbedIO m) => Config m -> IO ()
run9PServer cfg =
    bracket (connection $ addr cfg) close (\listener -> serve listener cfg)

-- | Accept connections on the listening socket forever, converting each
-- accepted socket to a read/write 'Handle' and serving it with 'doClient'.
--
-- Every client runs on its own thread. Serving clients inline would let a
-- single connected client block 'accept' for everyone else until it
-- disconnected. It would also let an exception raised while serving one
-- client terminate the whole accept loop.
serve :: (EmbedIO m) => Socket -> Config m -> IO ()
serve s cfg = forever $ do
    (client, _) <- accept s
    void $ forkIO $ socketToHandle client ReadWriteMode >>= doClient cfg

-- | Serve one client connection over the handle @h@.
--
-- Outgoing replies are queued on a 'Chan' and written to @h@ one at a time
-- by a dedicated sender thread. 'receiver' reads requests on the calling
-- thread. When the receiver finishes, whether normally or with an
-- exception, the sender thread is killed and @h@ is closed.
doClient :: (EmbedIO m) => Config m -> Handle -> IO ()
doClient cfg h = flip finally (hClose h) $ do
    hSetBuffering h NoBuffering
    chan <- newChan :: IO (Chan Msg)
    bracket
        (forkIO $ sender (readChan chan) (BS.hPut h . BS.concat . B.toChunks)) -- make a strict bytestring
        killThread
        (\_ -> receiver cfg h (writeChan chan))

-- | Read and decode one complete message from @h@.
--
-- A message starts with a little-endian 32-bit size that counts the whole
-- message, including the size field itself. The size is read first, then
-- the rest of the message, and the two parts are decoded together.
--
-- Throws an 'IOError' when the peer closes the connection before a full
-- message arrives or announces an impossible size. The decoded message is
-- forced before it is returned, so malformed input fails here, inside the
-- caller's exception handler. Otherwise it would fail later in whichever
-- thread first inspects the message.
recvPacket :: Handle -> IO Msg
recvPacket h = do
    header <- B.hGet h 4
    -- 'assert' is compiled out under -O, so check the framing explicitly.
    when (B.length header /= 4) $
        failWith "connection closed while reading the message size"
    let size = fromIntegral (runGet getWord32le header) :: Int
    -- size[4] type[1] tag[2] is the smallest well-formed message; this also
    -- keeps the length passed to 'B.hGet' below non-negative.
    when (size < 7) $
        failWith $ "invalid message size " ++ show size
    body <- B.hGet h (size - 4)
    when (B.length body /= fromIntegral (size - 4)) $
        failWith "connection closed in the middle of a message"
    msg <- return $! runGet (get :: Get Msg) (B.append header body)
    debugM "Network.NineP.Server" $ show msg
    return msg
  where
    failWith :: String -> IO ()
    failWith reason = ioError $ userError $ "Network.NineP.Server: " ++ reason

-- | Loop forever: take the next outgoing message from @nextMsg@, log it at
-- debug level, serialise it with its 'Bin' instance and pass the resulting
-- bytes to @emit@.
sender :: IO Msg -> (ByteString -> IO ()) -> IO ()
sender nextMsg emit = forever $ nextMsg >>= \msg -> do
    debugM "Network.NineP.Server" (show msg)
    emit . runPut $ put msg

-- | Read requests from @h@ until reading fails.
--
-- Each request is handled by 'handleMsg' on a thread started with
-- 'forkM', so the next read does not wait for the previous request. All
-- handlers share one 'NineState', initialised from the config's
-- 'monadState', and run under a 'ReaderT' carrying @cfg@. Replies are
-- passed to @say@. The first exception from 'recvPacket' is logged at
-- error level and ends the loop.
receiver :: (EmbedIO m) => Config m -> Handle -> (Msg -> IO ()) -> IO ()
receiver cfg h say =
    void $ runReaderT (runMState readLoop (emptyState $ monadState cfg)) cfg
  where
    -- 'iterateUntil id' repeats the step until it returns True.
    readLoop = void $ iterateUntil id $ do
        mp <- liftIO $ try $ recvPacket h
        case mp of
            Left (e :: SomeException) -> do
                -- The logging action must actually run; wrapping it in
                -- 'pure' would only build it and throw it away.
                liftIO $ errorM "Network.NineP.Server" $ show e
                return True
            Right p -> do
                void $ forkM $ handleMsg say p
                return False

-- | Dispatch one request to the handler for its message type and send the
-- handler's replies through @say@, in order.
--
-- If the handler throws, a single 'TRerror' reply is sent instead. It
-- carries the request's tag and the 'show' text of the exception.
handleMsg
    :: (EmbedIO m)
    => (Msg -> IO ())
    -> Msg
    -> MState (NineState m) (ReaderT (Config m) IO) ()
handleMsg say p = do
    let Msg typ t _ = p
    outcome <- try $ case typ of
        TTversion -> rversion p
        TTattach  -> rattach p
        TTwalk    -> rwalk p
        TTstat    -> rstat p
        TTwstat   -> rwstat p
        TTclunk   -> rclunk p
        TTauth    -> rauth p
        TTopen    -> ropen p
        TTread    -> rread p
        TTwrite   -> rwrite p
        TTremove  -> rremove p
        TTcreate  -> rcreate p
        TTflush   -> rflush p
    liftIO $ case outcome of
        Right replies -> mapM_ say replies
        -- FIXME which exceptions should i catch?
        Left (err :: SomeException) -> say $ Msg TRerror t $ Rerror $ show err