{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | Socket-level connection handling for the TDS protocol: opening the TCP
-- connection, sending request packets, reading responses, and tracking the
-- per-connection protocol state.
module Database.TDS.Connection
  ( newConnection
  ) where

import Database.TDS.Types
import Database.TDS.Proto

import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception ( SomeException, IOException
                         , Exception, throwIO, bracket
                         , catch, finally, onException
                         , mask )
import qualified Control.Exception as E
import Control.Monad.IO.Class
import Control.Monad.Identity
import Control.Monad.Trans

import Data.Bits
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as IBS
import qualified Data.ByteString.Streaming as SBS
import Data.Char
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
import Data.Proxy
import Data.Text (Text)
import Data.Word

import Foreign.C.Types
import Foreign.ForeignPtr
import Foreign.Marshal.Alloc
import Foreign.Ptr
import Foreign.Storable

import qualified Network.Socket as BSD

import qualified Streaming as S
import qualified Streaming.Prelude as S

-- | Whether a request with the given response info can be canceled.
-- Anything other than an 'ExpectsResponse' is never cancelable.
type family IsCancelable (resp :: ResponseInfo a) :: Bool where
  IsCancelable ('ExpectsResponse ('ResponseType r a)) = r
  IsCancelable _ = 'False

-- | A request packet that has been handed to the connection, paired with the
-- cancellation handle created for it.
data SomeSentPacket where
  SomeSentPacket
    :: Show d
    => Packet sender resp d Identity
    -> CancelInfo (IsCancelable resp)
    -> SomeSentPacket

-- | Caught by 'sendPacket' while sending a cancelable request, which then
-- issues an attention packet and rethrows it.
data TDSCanceledException = TDSCanceled
  deriving Show

instance Exception TDSCanceledException

-- | Thrown by a response stream whose request has been flagged as canceled.
data TDSWasCanceledException = TDSWasCanceled
  deriving Show

instance Exception TDSWasCanceledException

-- | The sending half of a connection.
data WriteEnd
  = WriteEnd
      !BSD.Socket
      -- Output buffer and its size, if one has been allocated
      !(Maybe (ForeignPtr (), CSize))

-- | The receiving half of a connection.
data ReadEnd = ReadEnd !BSD.Socket

-- | Mutable state of a live connection, or 'ConnectionQuit' once it has been
-- shut down.
data ConnectionState
  = ConnectionQuit
  | ConnectionState
      { connectionWriteSock :: TVar (Either ThreadId WriteEnd)
        -- ^ 'Left' records the thread that has claimed the write end
      , connectionReadSock :: TVar (Either ThreadId ReadEnd)
        -- ^ 'Left' records the thread that has claimed the read end
      , connectionCurReq :: TVar (Maybe SomeSentPacket)
        -- ^ The request currently awaiting a response, if any
      , connectionCurState :: TVar ClientState
      , connectionEnvironment :: !ClientEnv
        -- ^ Information on the current database, locale, etc.
      }

-- | Release the write end. Both ends built by 'newConnection' wrap the same
-- socket, so the actual close is left to 'closeReadEnd'.
closeWriteEnd :: WriteEnd -> IO ()
closeWriteEnd (WriteEnd _ _) = pure ()

-- | Close the socket behind the read end.
closeReadEnd :: ReadEnd -> IO ()
closeReadEnd (ReadEnd s) = BSD.close s

-- | Start a new connection to a TDS server.
--
-- The configured host and port (merged over 'defaultOptions') are resolved
-- with 'BSD.getAddrInfo', and each returned address is tried in order until
-- one accepts a TCP connection.
--
-- Throws a 'TDSNoSuchHost' error when resolution fails or returns no
-- addresses, and a 'TDSSocketError' describing the last failure when no
-- address can be connected to.
newConnection :: Options -> IO Connection
newConnection opts = do
  -- TODO UNIX, pipe, or other transport possibilities
  let opts' = defaultOptions <> opts
      connInfo = _tdsConnInfo opts'
      hints = BSD.defaultHints { BSD.addrSocketType = BSD.Stream }

  addrs <- BSD.getAddrInfo (Just hints) (_tdsConnHost connInfo) (show <$> _tdsConnPort connInfo)
             `catch` \(e :: IOException) ->
               throwIO (tdsErrorNoReq TDSNoSuchHost Connecting (show e))

  sock <- connectAny addrs

  writeV <- newTVarIO (Right (WriteEnd sock Nothing))
  readV <- newTVarIO (Right (ReadEnd sock))
  reqV <- newTVarIO Nothing
  stV <- newTVarIO Connecting

  let connSt = ConnectionState writeV readV reqV stV (clientEnvFromOptions opts)
  connStV <- newTVarIO connSt

  pure (Connection (sendPacket opts connStV) (cancel connStV) (quit connStV) stV opts)
  where
    -- Try each resolved address in turn; only the failure of the last one
    -- is reported to the caller.
    connectAny :: [BSD.AddrInfo] -> IO BSD.Socket
    connectAny [] =
      throwIO (tdsErrorNoReq TDSNoSuchHost Connecting "No addresses returned by getAddrInfo")
    connectAny (addr : rest) = do
      result <- E.try (openSocket addr) :: IO (Either IOException BSD.Socket)
      case result of
        Right sock -> pure sock
        Left e
          | null rest -> throwIO (tdsErrorNoReq TDSSocketError Connecting (show e))
          | otherwise -> connectAny rest

    -- Create a socket for one address and connect it, closing the socket
    -- again if the connect fails.
    openSocket :: BSD.AddrInfo -> IO BSD.Socket
    openSocket addr = do
      sock <- BSD.socket (BSD.addrFamily addr) (BSD.addrSocketType addr) (BSD.addrProtocol addr)
      BSD.connect sock (BSD.addrAddress addr) `onException` BSD.close sock
      pure sock

-- | Hex-dump a buffer for debugging. Currently a no-op; the commented-out
-- implementation below prints 8 bytes per line.
debugPtr :: Ptr a -> CSize -> IO ()
debugPtr _ _ = pure ()
-- debugPtr ptr' sz
--   | sz > width =
--       do putStrLn . foldMap formatHex =<< forM [0..width - 1] (peek . plusPtr ptr)
--          debugPtr (ptr' `plusPtr` width) (sz - width)
--   | otherwise =
--       putStrLn . foldMap formatHex =<< forM [0..fromIntegral sz - 1] (peek . plusPtr ptr)
--
--   where
--     width :: Num a => a
--     width = 8
--
--     ptr = castPtr ptr'
--
--     formatHex :: Word8 -> String
--     formatHex b =
--       let hi = fromIntegral ((b `shiftR` 4) .&. 0xF)
--           lo = fromIntegral (b .&. 0xF)
--       in intToDigit hi : intToDigit lo : ' ' : []
-- | Write a (possibly multi-packet) client message to the socket.
--
-- @fPtr@ is scratch space the caller has allocated for one packet.
-- Intermediate packets ('OnePacket') are sent as exactly @bufSz@ bytes. The
-- final packet ('LastPacket') is sized from the payload length its writer
-- reports, and it gets the end-of-message status bit.
--
-- Short writes are retried until the whole packet has been sent. Socket
-- errors propagate to the caller unchanged.
sendPackets :: BSD.Socket -> ForeignPtr () -> CSize -> SplitPacket 'Client resp d -> IO ()
sendPackets sock fPtr bufSz = withForeignPtr fPtr . go
  where
    go pkt ptr =
      case pktData pkt of
        LastPacket writePkt -> do
          pktSz <- writePkt (ptr `plusPtr` fromIntegral pktHdrSz)
          let totalSz = fromIntegral (pktSz + pktHdrSz)
              hdr = pktHdr pkt
          writeHdr ptr (hdr { pktHdrStatus = pktHdrStatus hdr <> pktStatusEndOfMessage
                            , pktHdrLength = fromIntegral totalSz
                            })
          debugPtr ptr totalSz
          sendAll ptr totalSz
        OnePacket writePkt -> do
          pkt' <- writePkt (ptr `plusPtr` fromIntegral pktHdrSz)
          let hdr = pktHdr pkt
          writeHdr ptr (hdr { pktHdrLength = fromIntegral bufSz })
          debugPtr ptr bufSz
          sendAll ptr bufSz
          go pkt' ptr

    -- BSD.sendBuf may accept fewer bytes than requested, so keep sending
    -- the remainder until the whole packet has been written.
    sendAll :: Ptr a -> CSize -> IO ()
    sendAll p remaining
      | remaining == 0 = pure ()
      | otherwise = do
          sent <- BSD.sendBuf sock (castPtr p) (fromIntegral remaining)
          if sent <= 0
            then fail "sendPackets: socket accepted no data"
            else sendAll (p `plusPtr` sent) (remaining - fromIntegral sent)

-- | Read exactly @sz@ bytes from the socket into @p@, looping over short
-- reads. Fails if the peer closes the connection before @sz@ bytes arrive.
recvExactly :: BSD.Socket -> Ptr a -> Word16 -> IO ()
recvExactly _ _ 0 = pure ()
recvExactly sock p sz = do
  recvd <- BSD.recvBuf sock (castPtr p) (fromIntegral sz)
  -- A zero-byte read means the peer closed the connection; looping again
  -- would spin forever.
  when (recvd <= 0) $
    fail "recvExactly: connection closed by server"
  recvExactly sock (p `plusPtr` recvd) (sz - fromIntegral recvd)

-- | Stream the payload of a server message, which may span several packets.
--
-- Each packet's 8-byte header is read and checked, and its payload is then
-- yielded chunk by chunk. The stream ends after a packet carrying the
-- end-of-message status bit. For a 'Cancelable' request, each chunk is read
-- while holding the request's @sync@ 'TMVar'. Once the cancel flag is set,
-- the stream throws 'TDSWasCanceled'.
dataStream :: ReadEnd -> CancelInfo resp -> SBS.ByteString IO ()
dataStream (ReadEnd sock) cancel = readPackets True
  where
    readPackets :: Bool -> SBS.ByteString IO ()
    readPackets isFirstPacket = do
      hdr <- liftIO . alloca $ \hdrP ->
        recvExactly sock (hdrP :: Ptr Word64) 8 >>
          readHdr (TabularResult :: PacketType 'Server 'NoResponse ()) hdrP
      case hdr of
        Nothing -> fail "Invalid header"
        Just hdr'
          -- A length shorter than the header itself would underflow the
          -- Word16 payload size.
          | pktHdrLength hdr' < 8 -> fail "Invalid header"
          | otherwise -> do
              -- Release the sync TMVar once per response only. After the
              -- first packet it is already full (each chunk read puts it
              -- back), and putTMVar on a full TMVar blocks.
              when isFirstPacket (liftIO endRead)
              getChunk (pktHdrLength hdr' - 8)
              if pktHdrStatus hdr' `hasStatus` pktStatusEndOfMessage
                then pure ()
                else readPackets False

    getChunk :: Word16 -> SBS.ByteString IO ()
    getChunk 0 = pure ()
    getChunk len = do
      case cancel of
        Cancelable isCanceled _ -> do
          cancelSt <- lift (atomically (readTVar isCanceled))
          when cancelSt (liftIO (throwIO TDSWasCanceled))
        _ -> pure ()

      chunk <- liftIO $ bracket startRead (\_ -> endRead) $ \_ -> do
        fPtr <- mallocForeignPtrBytes (fromIntegral len)
        actuallyRead <- withForeignPtr fPtr $ \ptr -> do
          bytesRead <- BSD.recvBuf sock ptr (fromIntegral len)
          debugPtr ptr (fromIntegral bytesRead)
          pure bytesRead
        pure (IBS.fromForeignPtr fPtr 0 actuallyRead)

      -- An empty read means the peer closed the connection mid-packet;
      -- without this check the loop would never make progress.
      when (BS.null chunk) $
        liftIO (fail "dataStream: connection closed by server")

      SBS.chunk chunk
      getChunk (len - fromIntegral (BS.length chunk))

    startRead, endRead :: IO ()
    (startRead, endRead) =
      case cancel of
        Cancelable _ sync -> (atomically $ takeTMVar sync, atomically $ putTMVar sync ())
        _ -> (pure (), pure ())

-- | Run @onState@ against the live connection state, or @onQuit@ if the
-- connection has already been shut down.
withConnectionState :: TVar ConnectionState -> STM a -> (ConnectionState -> STM a) -> STM a
withConnectionState stV onQuit onState = do
  connSt <- readTVar stV
  case connSt of
    ConnectionQuit -> onQuit
    ConnectionState {} -> onState connSt

-- | Wait until a new request may be sent, then claim both socket ends for
-- @threadId@.
--
-- Returns 'Left' straight away when the protocol state rules out sending. It
-- 'retry's while any of the following holds:
--
-- * another thread holds either socket end;
-- * a previous request is still registered;
-- * an attention is in flight;
-- * the connection has quit.
waitUntilSendable :: ThreadId -> ConnectionState -> (ClientState -> Bool)
                  -> STM (Either TDSError (WriteEnd, ReadEnd))
waitUntilSendable _ ConnectionQuit _ = retry
waitUntilSendable threadId
                  (ConnectionState { connectionWriteSock = writeEndV
                                   , connectionReadSock = readEndV
                                   , connectionCurState = stateV
                                   , connectionCurReq = reqV })
                  canSendInState = do
  state <- readTVar stateV
  case state of
    _ | canSendInState state -> do
          writeSock <- either (\_ -> retry) pure =<< readTVar writeEndV
          readSock <- either (\_ -> retry) pure =<< readTVar readEndV
          maybe (pure ()) (\_ -> retry) =<< readTVar reqV

          writeTVar writeEndV (Left threadId)
          writeTVar readEndV (Left threadId)

          pure (Right (writeSock, readSock))
    SentClientRequest ->
      tdsError TDSServerBusy SentClientRequest "Can't send request while server is still processing"
    SentAttention -> retry
    Final -> tdsError TDSServerQuit Final "Connection is closing"
    -- Report the state the connection is actually in.
    _ -> tdsError TDSServerUninitialized state "The connection is not yet ready"
  where
    tdsError ty st msg = pure (Left (tdsErrorNoReq ty st msg))

-- | Hand the write end back to the connection so another request can claim
-- it. Does nothing if the connection has quit or the write end is not
-- currently claimed.
surrenderWrite :: TVar ConnectionState -> WriteEnd -> STM ()
surrenderWrite stV we = do
  st <- readTVar stV
  case st of
    ConnectionQuit -> pure ()
    ConnectionState { connectionWriteSock = writeEndV } -> do
      we' <- readTVar writeEndV
      case we' of
        Left {} -> writeTVar writeEndV (Right we)
        -- TODO output warning or something
        _ -> pure ()

-- | Send a request packet on the connection.
--
-- The outer action claims the sockets, sends the packet and advances the
-- protocol state. It returns an inner action that reads and decodes the
-- response. If anything goes wrong while sending, the connection is shut
-- down. For cancelable requests, a 'TDSCanceledException' during the send
-- triggers an attention packet and is then rethrown.
sendPacket :: forall cancelable r d
            . ( Payload d, Response r, MkCancelable cancelable
              , KnownBool cancelable )
           => Options -> TVar ConnectionState
           -> Packet 'Client ('ExpectsResponse ('ResponseType cancelable r)) d Identity
           -> IO (IO (ResponseResult ('ResponseType cancelable r)))
sendPacket options stV pkt = myThreadId >>= go
  where
    go threadId = mask $ \unmask ->
      join . atomically . withConnectionState stV
        (pure (throwIO (tdsErrorNoReq TDSServerQuit Final "Can't send request to closed connection"))) $ \st -> do
          sock <- waitUntilSendable threadId st (sendableInState (pktHdrType (pktHdr pkt)))
          case sock of
            Left err -> pure (throwIO err)
            Right (writeEnd, readEnd) -> do
              cancel <- mkCancelable
              writeTVar (connectionCurReq st) (Just (SomeSentPacket pkt cancel))
              let send = doSend writeEnd readEnd cancel
                  send' =
                    if boolVal (Proxy :: Proxy cancelable)
                      then send `catch` \(e :: TDSCanceledException) ->
                             cancelRequest st writeEnd >> throwIO e
                      else send
                  -- Any failure mid-request leaves the wire protocol in an
                  -- unknown state, so shut the connection down.
                  internalError = join . atomically $ quitSTM' stV writeEnd readEnd
              pure (unmask (send' `onException` internalError))

    encoding = packetEncoding pkt

    doSend we@(WriteEnd sock buf) readEnd cancel = do
      let (maxSz, splitEncoding) =
            splitPacket (maybe maximumPayloadPacketSize snd buf - pktHdrSz) encoding
      (buf', sz') <-
        case buf of
          Nothing ->
            case maxSz of
              Nothing -> fail "Don't know what size of buffer"
              Just maxSz' -> (, maxSz') <$> mallocForeignPtrBytes (fromIntegral (maxSz' + pktHdrSz))
          Just (existing, sz) -> pure (existing, sz)

      sendPackets sock buf' sz' splitEncoding

      join . atomically . withConnectionState stV
        (pure (throwIO (TDSError TDSInvalidStateTransition Final (Just pkt) "Server quit before response received"))) $ \st -> do
          surrenderWrite stV we

          oldSt <- readTVar (connectionCurState st)
          let st' = stateTransition (pktHdrType (pktHdr pkt)) oldSt
          writeTVar (connectionCurState st) st'

          disconnectOnFinal st' we readEnd
            (throwIO (TDSError TDSInvalidStateTransition st' (Just pkt) "The state transitioned to Final before a response could be received"))
            (pure (getResult readEnd cancel))

    getResult readEnd@(ReadEnd sock) cancel =
      case responseDecoder :: ResponseDecoder (ResponseStreaming r) r of
        DecodeBatchResponse decode -> do
          hdr <- alloca $ \hdrP -> do
            recvExactly sock (hdrP :: Ptr Word64) 8
            readHdr (TabularResult :: PacketType 'Server 'NoResponse r) hdrP
          case hdr of
            Nothing -> invalidResponse readEnd
            Just hdr'
              -- A length shorter than the header would underflow the Word16
              -- payload size below and overrun pktBuf.
              | pktHdrLength hdr' < 8 -> invalidResponse readEnd
              | not (pktHdrStatus hdr' `hasStatus` pktStatusEndOfMessage) ->
                  fail "Cannot decode batch message split over multiple packets"
              | otherwise ->
                  allocaBytes (fromIntegral $ pktHdrLength hdr') $ \pktBuf -> do
                    recvExactly sock pktBuf (pktHdrLength hdr' - 8)
                    debugPtr pktBuf (fromIntegral $ pktHdrLength hdr' - 8)
                    -- NOTE(review): only (length - 8) bytes of pktBuf were
                    -- filled, but decode receives the full packet length;
                    -- confirm which one decode expects.
                    resp <- decode pktBuf (pktHdrLength hdr')
                    case resp of
                      Nothing -> fail "Invalid response"
                      Just resp' -> do
                        atomically . withConnectionState stV (pure ()) $ \st -> do
                          writeTVar (connectionReadSock st) (Right readEnd)
                          writeTVar (connectionCurReq st) Nothing
                        pure (ResponseResultReceived resp')

        DecodeTokenStream streamDecoder -> do
          let tokenStream = parseTokenStream (dataStream readEnd cancel)

              -- Feed each token to handleToken, keeping only those it
              -- reports should reach the decoder.
              validTokens s = do
                res <- S.lift (S.inspect s)
                case res of
                  Left a -> pure a
                  Right (OneToken tok next) -> do
                    includeToken <- S.lift (handleToken pkt options stV tok)
                    if includeToken
                      then S.wrap (OneToken tok (validTokens next))
                      else validTokens next
                  Right (ContParse tok next) ->
                    S.wrap (ContParse tok (validTokens . next))

              finishUp = atomically . withConnectionState stV (pure ()) $ \st -> do
                writeTVar (connectionCurState st) LoggedIn
                writeTVar (connectionReadSock st) (Right readEnd)
                writeTVar (connectionCurReq st) Nothing

          -- TODO statically determine the kind of DONE message to
          -- expect as the end of this stream
          --
          -- TODO if this throws an exception, we should close the connection
          res <- streamDecoder finishUp (validTokens tokenStream)
          pure (ResponseResultReceived res)

    -- Mark the connection Final, shut it down (actually closing the socket)
    -- and throw a TDSInvalidResponse error.
    invalidResponse :: ReadEnd -> IO (ResponseResult ('ResponseType cancelable r))
    invalidResponse readEnd =
      join . atomically . withConnectionState stV
        (pure $ throwIO (TDSError TDSInvalidResponse Final (Just pkt) "Invalid response received, but we've already quit")) $ \st -> do
          clientSt <- readTVar (connectionCurState st)
          writeTVar (connectionCurState st) Final
          writeTVar (connectionReadSock st) (Right readEnd)
          closeSockets <- quitSTM stV
          pure (throwIO (TDSError TDSInvalidResponse clientSt (Just pkt) "Invalid response received")
                  `finally` closeSockets)

    -- When the new state is Final, shut the connection down and run
    -- @failer@; otherwise hand back @action@.
    disconnectOnFinal :: ClientState -> WriteEnd -> ReadEnd -> IO a -> IO a -> STM (IO a)
    disconnectOnFinal Final we re failer _ = do
      closeSockets <- quitSTM' stV we re
      -- Run the close action returned by quitSTM' so the socket is not leaked.
      pure (failer `finally` closeSockets)
    disconnectOnFinal _ _ _ _ action = pure action

    cancelRequest st (WriteEnd sock _) = do
      curReq <- atomically (readTVar (connectionCurReq st))
      case curReq of
        Just (SomeSentPacket (Packet (PacketHeader { pktHdrType = SQLBatch }) _) (Cancelable signal sync)) -> do
          let bufSz = cancelPacketSize
          buf <- mallocForeignPtrBytes (fromIntegral bufSz)

          -- Flag the request as canceled, then take the read lock so the
          -- response stream stops reading.
          atomically (writeTVar signal True)
          atomically (takeTMVar sync)

          let (_, splitEncoding) =
                splitPacket bufSz (Packet (mkPacketHeader Attention pktStatusEndOfMessage)
                                          (PacketEncoding (encodePayload ())))
          sendPackets sock buf bufSz splitEncoding

          -- TODO create read end, and read until done token
          let tokenStream = parseTokenStream (dataStream (ReadEnd sock) NonCancelable)
          readUntilDoneToken tokenStream

          -- Cancel all
          atomically $ do
            writeTVar (connectionCurState st) LoggedIn
            writeTVar (connectionReadSock st) (Right (ReadEnd sock))
            writeTVar (connectionCurReq st) Nothing
        _ -> pure ()

      -- Spawn a new thread to handle the shutdown
      -- void . forkIO $ do
      --   -- Sending a cancel request in the middle of a send operation
      --   -- This just means sending the packet header with the EOM bit set
      --
      --   -- TODO sendPackets sock _ (pkt { pktData = LastPacket (\_ -> pure 0) })
      --
      --   atomically $ do
      --     writeTVar sockV (Just writeEnd)
      --     writeTVar sentV Nothing
      --
      --   -- We should be back in the old state
      --
      --   fail "TODO figure out cancel while sending request"
-- | Process one response token before it reaches the decoder, and report
-- whether the decoder should still see it.
--
-- * 'Info' and 'Error' messages are passed to '_tdsOnMessage' and dropped.
-- * 'EnvChange' updates the connection's 'ClientEnv', is passed to
--   '_tdsOnEnvChange', and is dropped.
-- * 'LoginAck' received for a 'Login7' request moves a pending login state
--   to 'LoggedIn', and the token is kept.
-- * Every other token is kept.
handleToken :: Packet clientServer respType d f -> Options -> TVar ConnectionState -> Token -> IO Bool
handleToken _ options _ (Info msg) = False <$ _tdsOnMessage options msg
handleToken _ options _ (Error msg) = False <$ _tdsOnMessage options msg
handleToken _ options stV (EnvChange chg) = join . atomically $ do
  st <- readTVar stV
  case st of
    ConnectionQuit -> pure (pure False)
    ConnectionState {} -> do
      let st' = st { connectionEnvironment = updateEnv chg (connectionEnvironment st) }
      writeTVar stV st'
      pure (False <$ _tdsOnEnvChange options chg)
handleToken pkt _ stV (LoginAck {})
  | Login7 <- pktHdrType (pktHdr pkt) = atomically $ do
      st <- readTVar stV
      case st of
        ConnectionQuit -> pure True
        ConnectionState { connectionCurState = protoStV } -> do
          protoSt <- readTVar protoStV
          let setLoggedIn = writeTVar protoStV LoggedIn
          case protoSt of
            SentLogin7WithCompleteAuthenticationToken -> setLoggedIn
            SentLogin7WithSPNEGO -> setLoggedIn
            SentLogin7WithFAIR -> setLoggedIn
            _ -> pure ()
          pure True
handleToken _ _ _ _ = pure True

-- | Cancel the in-flight request. Not implemented yet: always fails.
cancel :: TVar ConnectionState -> IO ()
cancel _ = fail "Cancel"

-- | Shut the connection down once both socket ends are available. Retries
-- while another thread holds either end. Returns the action that closes the
-- sockets, which the caller must run. Does nothing if the connection has
-- already quit.
quitSTM :: TVar ConnectionState -> STM (IO ())
quitSTM stV =
  withConnectionState stV (pure (pure ())) $ \st -> do
    we <- either (\_ -> retry) pure =<< readTVar (connectionWriteSock st)
    re <- either (\_ -> retry) pure =<< readTVar (connectionReadSock st)
    quitSTM' stV we re

-- | Mark the connection as quit and return the action that closes the given
-- socket ends. Does nothing if the connection has already quit.
quitSTM' :: TVar ConnectionState -> WriteEnd -> ReadEnd -> STM (IO ())
quitSTM' stV we re =
  withConnectionState stV (pure (pure ())) $ \_ -> do
    writeTVar stV ConnectionQuit
    pure (closeWriteEnd we >> closeReadEnd re)

-- | Shut the connection down and close its sockets.
quit :: TVar ConnectionState -> IO ()
quit = join . atomically . quitSTM

-- * Client environment

-- | Session settings reported by the server.
data ClientEnv
  = ClientEnv
  { clientEnvDatabase :: Text
  , clientEnvLanguage :: Text
  } deriving Show

-- | Initial client environment. The options are currently ignored.
-- TODO This should work
clientEnvFromOptions :: Options -> ClientEnv
clientEnvFromOptions _ = ClientEnv "master" "us_english"

-- | Apply a server environment change. Only database changes are tracked.
updateEnv :: EnvChange -> ClientEnv -> ClientEnv
updateEnv (EnvChangeDatabase _ new) env = env { clientEnvDatabase = new }
updateEnv _ env = env

-- | Drain a token stream after an attention packet.
--
-- Fully parsed tokens of any kind are skipped. (Previously only 'Done' was
-- matched here, so any other token crashed with a pattern-match failure.)
-- Draining stops at the end of the stream or at a continuation parse of a
-- 'Done' token. Any other continuation parse is not supported yet and
-- fails.
readUntilDoneToken :: S.Stream TokenStream IO () -> IO ()
readUntilDoneToken s = do
  res <- S.inspect s
  case res of
    Left {} -> pure ()
    Right (OneToken _ s') -> readUntilDoneToken s'
    Right (ContParse Done {} _) -> pure ()
    Right (ContParse _ _) -> fail "Can't read (TODO)"