{-# LANGUAGE LambdaCase #-}
module Database.Liszt.Network
  ( startServer
  , Connection
  , withConnection
  , connect
  , disconnect
  , fetch) where

import Control.Concurrent
import Control.Exception
import Control.Monad
import Database.Liszt.Tracker
import Database.Liszt.Internal (hPayload, RawPointer(..))
import Data.Binary
import Data.Binary.Get
import Data.Winery
import qualified Data.Winery.Internal.Builder as WB
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as BL
import qualified Network.Socket.SendFile.Handle as SF
import qualified Network.Socket.ByteString as SB
import qualified Network.Socket as S
import System.IO
import Text.Read (readMaybe)

-- | Serve one request read from @conn@.
--
-- Reads up to 4096 bytes, decodes them with 'decodeCurrent' and, unless the
-- received chunk was empty, passes the request to 'handleRequest'. The
-- callback given to 'handleRequest' writes the reply:
--
-- 1. the number of results, encoded by 'encodeResp' as a @Right@;
-- 2. for every result: its sequence number, the tag length, the tag bytes and
--    the payload length (the three numbers as 'WB.word64'), followed by the
--    payload itself, streamed from @hPayload lh@ with 'SF.sendFile''.
--
-- A decoding failure is rethrown as 'WineryError'.
respond :: Tracker -> S.Socket -> IO ()
respond tracker conn = do
  msg <- SB.recv conn 4096
  -- NOTE(review): decoding is forced before the empty-message check below, so
  -- an empty read (peer closed the connection) presumably fails here rather
  -- than being skipped -- confirm against 'decodeCurrent'.
  req <- try (evaluate $ decodeCurrent msg) >>= \case
    Left e -> throwIO $ WineryError e
    Right a -> return a
  unless (B.null msg) $ handleRequest tracker req $ \lh lastSeqNo offsets -> do
    let count = length offsets
    -- 'SB.send' may transmit only a prefix of the buffer; a truncated header
    -- would desynchronise the client, so insist on sending all of it.
    SB.sendAll conn $ encodeResp $ Right count
    -- Sequence numbers are assigned so that the last offset gets lastSeqNo.
    forM_ (zip [lastSeqNo - count + 1..] offsets) $ \(i, (tag, RP pos len)) -> do
      SB.sendAll conn $ WB.toByteString $ mconcat
        [ WB.word64 (fromIntegral i)
        , WB.word64 (fromIntegral $ WB.getSize tag), tag
        , WB.word64 $ fromIntegral len]
      SF.sendFile' conn (hPayload lh) (fromIntegral pos) (fromIntegral len)

-- | Run the server on @0.0.0.0:port@, inside @withLisztReader prefix@.
--
-- Every accepted connection is handled in its own thread:
--
-- 1. the first message (up to 4096 bytes) is 'decode'd into the path handed
--    to 'withTracker';
-- 2. once the tracker is available the literal @READY@ is sent;
-- 3. requests are then served in a loop by 'respond'.
--
-- When the handler thread ends with a 'LisztError', the shown error is sent
-- to the client as an encoded @Left@; any other exception is printed to
-- 'stderr'. The connection socket is closed in every case.
--
-- This action only returns by throwing: the accept loop runs forever.
startServer :: Int -> FilePath -> IO ()
startServer port prefix = withLisztReader prefix $ \env -> do
  let hints = S.defaultHints { S.addrFlags = [S.AI_NUMERICHOST, S.AI_NUMERICSERV], S.addrSocketType = S.Stream }
  addr:_ <- S.getAddrInfo (Just hints) (Just "0.0.0.0") (Just $ show port)
  bracket (S.socket (S.addrFamily addr) (S.addrSocketType addr) (S.addrProtocol addr)) S.close $ \sock -> do
    S.setSocketOption sock S.ReuseAddr 1
    S.setSocketOption sock S.NoDelay 1
    S.bind sock $ S.SockAddrInet (fromIntegral port) (S.tupleToHostAddress (0,0,0,0))
    S.listen sock 2
    forever $ do
      (conn, _) <- S.accept sock
      let -- Handshake followed by the request loop for this client.
          serve = do
            path <- decode .  BL.fromStrict <$> SB.recv conn 4096
            withTracker env path $ \t -> do
              SB.sendAll conn $ B.pack "READY"
              forever $ respond t conn
          -- Tell the client (or the operator) why the handler stopped.
          report = \case
            Left ex -> case fromException ex of
              Just e -> SB.sendAll conn $ encodeResp $ Left $ show (e :: LisztError)
              Nothing -> hPutStrLn stderr $ show ex
            Right _ -> return ()
      -- Reporting can itself throw (e.g. the peer has already gone away);
      -- 'finally' makes sure the socket is still closed in that case.
      forkFinally serve $ \result -> report result `finally` S.close conn

-- | Serialise a response header with "Data.Binary" into a strict
-- 'B.ByteString': @Left@ carries an error message, @Right@ a result count.
encodeResp :: Either String Int -> B.ByteString
encodeResp resp = BL.toStrict (encode resp)

-- | Client-side handle to a server connection: the socket is kept in an
-- 'MVar', which 'fetch' holds for the duration of a request and which
-- 'disconnect' empties before closing the socket.
newtype Connection = Connection (MVar S.Socket)

-- | Open a connection with 'connect', run the given action on it, and
-- 'disconnect' afterwards, whether the action returns or throws.
withConnection :: String -> Int -> B.ByteString -> (Connection -> IO r) -> IO r
withConnection host port path action =
  bracket acquire disconnect action
  where
    acquire = connect host port path

-- | Connect to the server at @host:port@ and select the stream at @path@.
--
-- The first address returned by 'S.getAddrInfo' is used. After connecting,
-- the binary-encoded @path@ is sent and one reply of up to 4096 bytes is read
-- and discarded (the server in this module answers @READY@).
--
-- If any step after creating the socket throws, the socket is closed before
-- the exception propagates, so a failed 'connect' does not leak a descriptor.
connect :: String -> Int -> B.ByteString -> IO Connection
connect host port path = do
  let hints = S.defaultHints { S.addrFlags = [S.AI_NUMERICSERV], S.addrSocketType = S.Stream }
  addr:_ <- S.getAddrInfo (Just hints) (Just host) (Just $ show port)
  bracketOnError
    (S.socket (S.addrFamily addr) (S.addrSocketType addr) (S.addrProtocol addr))
    S.close $ \sock -> do
      S.connect sock $ S.addrAddress addr
      SB.sendAll sock $ BL.toStrict $ encode path
      -- NOTE(review): the handshake reply is not inspected, so an error
      -- response or an early close from the server goes unnoticed here.
      _ <- SB.recv sock 4096
      Connection <$> newMVar sock

-- | Take the socket out of the connection's 'MVar' and close it. The 'MVar'
-- is left empty afterwards.
disconnect :: Connection -> IO ()
disconnect (Connection ref) = do
  sock <- takeMVar ref
  S.close sock

-- | Send a request and read back the complete response.
--
-- The socket is held via 'modifyMVar' for the whole exchange. The reply is
-- parsed incrementally, reading up to 4096 bytes at a time: first an
-- @Either String Int@ header, then, for @Right n@, exactly @n@ triples.
--
-- A @Left@ header that parses (via 'readMaybe') as a 'LisztError' is thrown
-- as that error; any other @Left@, or a malformed reply, makes the call fail
-- with a message prefixed by the shown request.
fetch :: Connection -> Request -> IO [(Int, B.ByteString, B.ByteString)]
fetch (Connection lock) req = modifyMVar lock $ \sock -> do
  SB.sendAll sock (serialiseOnly req)
  feed sock (runGetIncremental reply)
  where
    -- Decoder for one whole response: header first, then the rows.
    reply :: Get [(Int, B.ByteString, B.ByteString)]
    reply = do
      header <- get
      case header of
        Right n -> replicateM n $ do
          seqNo <- get
          tag <- get
          payload <- get
          return (seqNo, tag, payload)
        Left msg ->
          maybe (fail $ "Unknown error: " ++ show msg) throw
            (readMaybe msg :: Maybe LisztError)

    -- Keep supplying input from the socket until the decoder finishes; an
    -- empty read is passed on as end of input.
    feed sock = \case
      Partial k -> do
        chunk <- SB.recv sock 4096
        feed sock $ k $ if B.null chunk then Nothing else Just chunk
      Done _ _ rows -> return (sock, rows)
      Fail _ _ why -> fail $ show req ++ ": " ++ why