{-# LANGUAGE CPP, ScopedTypeVariables #-}

-- | A channel module with transparent network communication.
module Control.CUtils.NetChan (NetSend, NetRecv, localHost, newNetChan, newNetSend, newNetRecv, send, recv, recvSend, sendRecv, recvRecv, activateSend, activateRecv) where

-- This module has a strategy for routing around dead nodes. See 'routeAround'.

import System.IO
import System.Process
import Data.List (find, isPrefixOf, isInfixOf, (\\))
import Network
import Network.Socket (socketToHandle, SockAddr(..))
import Network.BSD
import Control.Concurrent
import Control.Monad
import Data.ByteString.Lazy (ByteString, hGet, hPut, length, fromChunks, append, empty)
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString as B
import Data.Binary
import qualified Data.Map as M
import Data.Maybe
import Data.Char
import Data.IORef
import Data.Bits
import Control.Exception
import System.IO.Unsafe
import Prelude hiding (lookup, length, catch)

import Control.CUtils.Split

-- | Wire identifier of a channel end. It is sent raw at the start of every
-- connection; 'server' reads exactly 12 bytes of it and looks them up in
-- 'table'.
type Ident = ByteString

{-# NOINLINE serverup #-}
-- | Whether this process's accept server has been started. '__newNetRecv'
-- reads and sets it so the server is started at most once.
--
-- This is a process-wide global created with 'unsafePerformIO'. The NOINLINE
-- pragma keeps it a single shared 'MVar'. The explicit monomorphic signature
-- guarantees the global can never be given a polymorphic type.
serverup :: MVar Bool
serverup = unsafePerformIO (newMVar False)

{-# NOINLINE table #-}
-- | Process-wide registry mapping each tagged channel identifier to the
-- listener that handles payloads from connections presenting it. 'server'
-- looks identifiers up here; '__newNetRecv' inserts into it.
--
-- It starts with a single no-op entry under the empty identifier. Because
-- 'newNetChan' takes the table's size as its serial number, the first serial
-- it hands out (if nothing else has been registered yet) is 1, not the 0 used
-- by 'newNetRecv' and 'newNetSend'.
--
-- A global created with 'unsafePerformIO'; NOINLINE keeps it one shared 'MVar'.
table :: MVar (M.Map Ident (ByteString -> IO ()))
table = unsafePerformIO (newMVar (M.singleton empty (\_ -> return ())))

-- | One outgoing connection of a send end. It holds a flag and the socket
-- 'Handle'; 'send' sets the flag to True while a delayed flush of the handle
-- is pending.
data ChannelFibre t = ChannelFibre (MVar Bool) Handle

-- | The send end of a channel carrying values of type @t@. Fields, in order:
--
-- * the host name given at creation (@""@ for some internal ends);
-- * its identifier, tagged by 'modifyIdent';
-- * a buffer of upstream host names that 'routeAround' reconnects to;
-- * its open connections.
--
-- Ends produced by 'get' carry only the host name and identifier. The other
-- two fields are 'undefined' until 'activateSend' is applied.
data NetSend t = NetSend HostName Ident (MVar [HostName]) (MVar [ChannelFibre t])

-- | The receive end of a channel carrying values of type @t@. Fields, in
-- order:
--
-- * its tagged identifier;
-- * the downstream send end that received values are forwarded to;
-- * the back-channel send end ('undefined' for back-channel receive ends);
-- * the local 'Chan' that 'recv' reads.
--
-- Ends produced by 'get' carry only the identifier; use 'activateRecv'.
data NetRecv t = NetRecv Ident (NetSend t) (NetSend HostName) (Chan t)

-- | Two fibres are equal when they wrap the same socket 'Handle'; the flush
-- flag is ignored. 'routeAround' relies on this to drop a failed fibre.
instance Eq (ChannelFibre t) where
	(==) (ChannelFibre _ left) (ChannelFibre _ right) = left == right

-- | Send ends are compared by identifier alone.
instance Eq (NetSend t) where
	a == b = case (a, b) of
		(NetSend _ i _ _, NetSend _ j _ _) -> i == j

-- | Receive ends are compared by identifier alone.
instance Eq (NetRecv t) where
	(==) x y = identOf x == identOf y where
		identOf (NetRecv i _ _ _) = i

-- | TCP port used both by the accept server ('listenOn' in '__newNetRecv')
-- and for outgoing connections ('connectTo' in '__addConnection').
port = 2999

-- | Packs a dotted-quad IPv4 address into a 'Word32'. The first octet goes in
-- the least-significant byte, so @"a.b.c.d"@ becomes
-- @d<<24 .|. c<<16 .|. b<<8 .|. a@.
--
-- Calls 'error' with a descriptive message unless the string is exactly four
-- dot-separated decimal octets in the range 0-255. Previously an octet such
-- as 300 was accepted and silently spilled into the neighbouring byte.
getIPAddress :: String -> Word32
getIPAddress ip = maybe malformed pack (mapM octet (split '.' ip)) where
	pack [n1, n2, n3, n4] = shiftL n4 24 .|. shiftL n3 16 .|. shiftL n2 8 .|. n1
	pack _ = malformed
	malformed = error ("Control.CUtils.NetChan.getIPAddress: not a dotted-quad IPv4 address: " ++ show ip)
	-- One component: a decimal number in 0-255 (surrounding spaces allowed,
	-- as 'read' allowed them).
	octet s = case reads s of
		[(n, rest)] | all isSpace rest && n >= 0 && n <= (255 :: Integer) -> Just (fromIntegral n)
		_ -> Nothing

-- | Returns this machine's IPv4 address as a dotted-quad string.
--
-- This is a hack that only works where @ipconfig@ exists (Windows). It runs
-- @ipconfig@ and takes the first line starting with @"   IPv4"@. From that
-- line it keeps only the digits and dots after the colon. Trailing text such
-- as a carriage return or a @(Preferred)@ suffix therefore cannot reach
-- 'getIPAddress', and the label before the colon may be any length.
-- Calls 'error' if no such line is present.
localHost :: IO String
localHost = liftM extract $ readProcess "ipconfig" [] [] where
	extract out = case find (isPrefixOf "   IPv4") (lines out) of
		Just line -> takeWhile isAddrChar (dropWhile isSpace (drop 1 (dropWhile (/= ':') line)))
		Nothing -> error "Control.CUtils.NetChan.localHost: no IPv4 line in ipconfig output"
	isAddrChar c = isDigit c || c == '.'

-- | Builds a channel identifier from the originating host's dotted-quad
-- address and a host-unique serial number. It is the 'encode'd pair
-- @(serial, packed address)@.
identifier :: String -> Word32 -> Ident
identifier ip serial = encode pair where
	pair = (serial, getIPAddress ip)

--- Channel creation.

-- | Creates a new channel, with receive and send ends.
--
-- The identifier combines this host's address (from 'localHost') with a
-- serial number equal to the current size of 'table'.
--
-- NOTE(review): the size is read before '__newNetRecv' registers the new
-- entries, so two concurrent calls could pick the same serial. Confirm that
-- callers serialise channel creation.
newNetChan :: (Binary t) => IO (NetRecv t, NetSend t)
newNetChan = do
	serial <- liftM (fromIntegral . M.size) (readMVar table)
	host <- localHost
	let ident = identifier host serial
	recvEnd <- __newNetRecv True Nothing ident
	sendEnd <- __newNetSend True host ident
	return (recvEnd, sendEnd)

-- | Prefixes an identifier with a 4-byte tag naming the channel it belongs
-- to: @"main"@ when the flag is True (forward channel), @"back"@ otherwise
-- (back channel). Each tag character becomes one byte via 'ord'.
modifyIdent :: Bool -> ByteString -> ByteString
modifyIdent isMain ident = tag `append` ident where
	tag = fromChunks [B.pack (map (fromIntegral . ord) label)]
	label = if isMain then "main" else "back"

-- | Builds a send end for @hostName@ that has no connections yet.
--
-- The stored identifier is @ident@ tagged by 'modifyIdent' @b@. When @b@ is
-- True, a back-channel receive end is also created, with @backDown@ as its
-- downstream. A thread is then forked that prepends every host name received
-- on that back channel to this send end's host buffer.
__emptyNetSend :: Bool -> NetSend HostName -> HostName -> Ident -> IO (NetSend t)
__emptyNetSend b backDown hostName ident = do
	hostBuffer <- newMVar []
	-- Fill the buffer as soon as host names arrive, so this host has them
	-- before any downstream node dies.
	when b $ do
		backR <- __newNetRecv False (Just backDown) ident
		void $ forkIO $ forever $ do
			upstream <- recv backR
			modifyMVar_ hostBuffer (return . (upstream :))
	fibres <- newMVar []
	return (NetSend hostName (modifyIdent b ident) hostBuffer fibres)

-- | Opens a new connection from a send end to @hostName@ and adds it to the
-- send end's list of fibres.
--
-- The handshake written on the new socket is:
--
-- 1. the raw identifier;
-- 2. an 'encode'd length;
-- 3. the 'encode'd list @hostName : upstreams@, where @upstreams@ is the
--    current content of the send end's host buffer.
--
-- If any part of the handshake fails, the socket handle is closed before the
-- exception propagates. Previously it was leaked.
__addConnection (NetSend _ ident buffer mvar) hostName = do
	-- Per-fibre flag used by 'send' to track a pending delayed flush.
	pending <- newMVar False

	-- Open a TCP/IP socket to send on.
	hdl <- withSocketsDo $ connectTo hostName (PortNumber port)

	let handshake = do
		hSetBuffering hdl (BlockBuffering (Just 1024))
		-- Send identifier
		hPut hdl ident
		-- Send list of upstreams, length-prefixed
		upstreams <- readMVar buffer
		let bs = encode (hostName : upstreams)
		hPut hdl $ encode $ length bs
		hPut hdl bs
		hFlush hdl
	handshake `onException` hClose hdl

	modifyMVar_ mvar (return . (ChannelFibre pending hdl :))

-- | Creates a send end for @hostName@ (tagged according to @b@ by
-- '__emptyNetSend') and opens its first connection to that host.
--
-- When @b@ is True, an unconnected back-channel send end is created first and
-- passed to '__emptyNetSend' as the back channel's downstream. Otherwise that
-- argument is left 'undefined'.
__newNetSend b hostName ident = do
	backDown <- if b
		then __emptyNetSend False undefined "" ident
		else return undefined
	end <- __emptyNetSend b backDown hostName ident
	__addConnection end hostName
	return end

-- | Open a channel to another host. The result is a send end addressed to
-- that host's serial-0 channel, which is the identifier 'newNetRecv' uses on
-- its own host.
--
-- @hostName@ must be a dotted-quad IPv4 address, because the identifier is
-- computed from it with 'identifier'.
newNetSend hostName = __newNetSend True hostName remoteIdent where
	remoteIdent = identifier hostName 0

-- | Reads length-prefixed payloads from @hdl@ and passes each one to @f@.
--
-- Each frame is an 8-byte 'encode'd length followed by that many bytes.
-- Reading stops cleanly when the peer closes the connection (end of file
-- before a frame starts). Previously 'decode' failed on the empty length
-- instead. The handle is closed when reading stops, whether normally or
-- because @f@ or a read threw.
readLoop :: (ByteString -> IO a) -> Handle -> IO ()
readLoop f hdl = loop `finally` hClose hdl where
	loop = do
		eof <- hIsEOF hdl
		unless eof $ do
			n <- liftM decode (hGet hdl 8)
			bs <- hGet hdl n
			f bs
			loop

-- | Accept loop for this process's channel server, run on the socket
-- returned by 'listenOn'.
--
-- Each accepted connection is handled on its own thread. It must first send
-- a 12-byte identifier, which is looked up in 'table'. For a known
-- identifier, 'readLoop' runs its listener over the rest of the connection.
-- An unknown identifier is reported on stderr and the connection is closed.
--
-- Identification happens on the per-connection thread. A client that is slow
-- to send its identifier therefore cannot stall other accepts, and a client
-- whose connection fails mid-read cannot kill the accept loop.
server socket = withSocketsDo $ forever $ do
	(hdl, host, _) <- accept socket
	void $ forkIO $ withSocketsDo $ do
		ident <- hGet hdl 12
		may <- liftM (M.lookup ident) $ readMVar table
		case may of
			Nothing -> do
				hPutStrLn stderr ("The host " ++ host ++ " used an invalid Ident: " ++ show ident)
				hClose hdl
			Just f -> readLoop f hdl

-- | Creates a receive end for the channel identified by @ident@. It also
-- registers the end's listener in 'table' and starts the accept server if it
-- is not running yet.
--
-- * @b@: True for a forward channel, False for a back channel. It selects the
--   tag that 'modifyIdent' prepends to @ident@. A back-channel send end is
--   created here only when it is True.
-- * @may@: a send end that received values are forwarded to. When 'Nothing',
--   a new one is created with '__emptyNetSend'.
--
-- Each value decoded by the listener is written to the returned end's 'Chan'
-- (read by 'recv') and also sent on to the downstream send end.
__newNetRecv :: (Binary t) => Bool -> Maybe (NetSend t) -> Ident -> IO (NetRecv t)
__newNetRecv b may ident = do
	chan <- newChan

	-- Create a back channel
	--
	-- The downstream of the back channel is the upstream of the main channel.
	-- Only forward receive ends (b == True) get one. Otherwise backS is left
	-- undefined and must not be used.
	backS <- if b then
			__emptyNetSend False undefined "" ident
		else
			return undefined

	-- Where received values are forwarded: the supplied send end, or a fresh
	-- unconnected one tagged the same way as this receive end.
	downstream <- maybe
		(__emptyNetSend b backS "" ident)
		return
		may

	let ident' = modifyIdent b ident

	-- Set once the first payload (the host list) has been handled.
	-- NOTE(review): this flag belongs to the listener, but 'server' runs the
	-- same listener for every accepted connection presenting ident'. Only the
	-- first connection's leading host list is treated as such. The leading
	-- host list of any later connection would be decoded as a value of type t.
	-- Confirm whether this should be tracked per connection.
	gotUpstreams <- newIORef False
	let listener bs = do
		got <- readIORef gotUpstreams
		if got then do
				let x = decode bs
				writeChan chan x

				-- Send the value to downstream receive ends.
				send downstream x
			else do
				-- First payload: the host list written by '__addConnection'.
				-- Its head x is the host name the sender connected to.
				writeIORef gotUpstreams True
				let x:xs = decode bs
				when b $ do
					-- Replace the back channel's host buffer with the sender's
					-- upstreams, then connect the back channel to host x.
					-- NOTE(review): x is the name the sender dialled, i.e.
					-- presumably this host itself. Confirm this is the
					-- intended target for the back channel.
					let NetSend _ _ buffer _ = backS
					modifyMVar_ buffer (\_ -> return xs)
					__addConnection backS x

	-- Put a listener in the table.
	modifyMVar_ table (return . M.insert ident' listener)

	-- Start the server singleton. serverup records whether it is already
	-- running, so listenOn/forkIO happen at most once per process. The
	-- lambda's b shadows the function's b.
	modifyMVar_ serverup (\b -> unless b (withSocketsDo $ listenOn (PortNumber port) >>= forkIO . server >> return ()) >> return True)

	return (NetRecv ident' downstream backS chan)

-- | Creates a receive end of this host's channel: the serial-0 channel that
-- 'newNetSend' on another host addresses. Type unsafe! Nothing checks that
-- senders use the same element type @t@.
newNetRecv :: (Binary t) => IO (NetRecv t)
newNetRecv = do
	host <- localHost
	__newNetRecv True Nothing (identifier host 0)

--- Send and receive.

-- | If send fails, route around the node.
--
-- First the dead fibre @fib@ is removed from the send end's connection list.
-- Then every host name is taken out of the send end's host buffer, and a new
-- connection is opened to each of them.
--
-- Removing the dead fibre first means that, if reconnecting to one of the
-- buffered hosts also fails, the dead connection is no longer in the list for
-- later sends to hit again.
routeAround fib s@(NetSend _ _ buffer mvar) = do
	modifyMVar_ mvar (return . (\\ [fib]))
	hosts <- modifyMVar buffer (\ls -> return ([], ls))
	mapM_ (__addConnection s) hosts

-- | Sends something on a channel.
--
-- The value is 'encode'd once and written, prefixed by its encoded length, to
-- every connection (fibre) of the send end. Handles are block-buffered. The
-- first write to a fibre since its last flush forks a thread that flushes
-- that handle after a 100 ms delay, so a burst of sends shares one flush.
--
-- If a write or a delayed flush throws, 'routeAround' replaces the failed
-- fibre and the value is sent again with a recursive call.
--
-- NOTE(review): the delayed-flush failure handler re-sends only the value of
-- the send that forked it. Other values buffered on the handle since then are
-- not re-sent.
send :: (Binary t) => NetSend t -> t -> IO ()
send snd@(NetSend _ ident _ mvar) x = readMVar mvar >>= mapM_ (\fib@(ChannelFibre mvar hdl) -> do
	-- Inside this lambda, mvar is the fibre's flush-pending flag, shadowing the
	-- send end's fibre list. It is set to True and its old value kept as b.
	-- The write-failure handler below runs while this flag's MVar is held.
	b <- modifyMVar mvar (\b -> let s = encode x in
		s `seq` catch (hPut hdl (encode (length s)) >> hPut hdl s) (\(_ :: SomeException) -> routeAround fib snd >> send snd x)
		>> return (True, b))
	-- Buffering: if no flush was already pending, schedule one. The flusher
	-- clears the flag before flushing, so sends after that point schedule a
	-- new flush.
	unless b $ void $ forkIO $ do
		threadDelay 100000
		modifyMVar_ mvar (\_ -> return False)
		catch (hFlush hdl) (\(_ :: SomeException) -> routeAround fib snd >> send snd x))

-- | Receives something from a channel. It reads the receive end's local
-- 'Chan' and blocks until a value has arrived.
recv r = case r of
	NetRecv _ _ _ chan -> readChan chan

--- Sending and receiving channels.

-- | Receives the send end of a channel, on a channel. The decoded end is made
-- usable with 'activateSend' before it is returned.
recvSend r = do
	end <- recv r
	activateSend end

-- | Sends the receive end of a channel, on a channel.
--
-- After sending, this node connects the receive end's downstream send end to
-- @s@'s host. This node is then responsible for passing messages on to the
-- destination. Finally, that host name is sent on the receive end's back
-- channel to inform upstream.
sendRecv s@(NetSend destination _ _ _) r@(NetRecv _ forwardTo backChannel _) =
	send s r >> __addConnection forwardTo destination >> send backChannel destination

-- | Receives the receive end of a channel, on a channel. The decoded end is
-- made usable with 'activateRecv' before it is returned.
recvRecv r = activateRecv =<< recv r

--- Channel data utilities.

-- | Only the host name and identifier go over the wire. The decoded end has
-- its host buffer and connection list left 'undefined' until 'activateSend'
-- is applied.
instance Binary (NetSend t) where
	put (NetSend host i _ _) = do
		put host
		put i
	get = do
		host <- get
		i <- get
		return (NetSend host i undefined undefined)

-- | Only the identifier goes over the wire. The other fields of a decoded end
-- are 'undefined' until 'activateRecv' is applied.
instance Binary (NetRecv t) where
	put r = case r of
		NetRecv i _ _ _ -> put i
	get = fmap (\i -> NetRecv i undefined undefined undefined) get

-- | 'get' produces channel ends with some data missing. Use this to make a
-- decoded send end usable.
--
-- It rebuilds the send end from its transmitted host name and identifier and
-- opens a connection to that host. The transmitted identifier already carries
-- the 4-byte tag added by 'modifyIdent', and '__newNetSend' adds that tag
-- again. The tag is therefore stripped first. Otherwise the new end would
-- present an identifier 4 bytes longer than the one the originating host
-- registered, which 'server' can never match.
activateSend :: NetSend t -> IO (NetSend t)
activateSend (NetSend hostName ident _ _) = __newNetSend True hostName (L.drop 4 ident)

-- | Makes a receive end decoded by 'get' usable. It registers the end in this
-- process's 'table' and starts the accept server if needed.
--
-- The transmitted identifier already carries the 4-byte tag added by
-- 'modifyIdent', and '__newNetRecv' adds that tag again. The tag is stripped
-- first, so the registered key matches the identifier that the original
-- host's downstream send end connects with (see 'sendRecv').
activateRecv :: (Binary t) => NetRecv t -> IO (NetRecv t)
activateRecv (NetRecv x _ _ _) = __newNetRecv True Nothing (L.drop 4 x)