module Language.Haskell.LSP.Test.Replay
  ( replaySession
  )
where
import           Prelude hiding (id)
import           Control.Concurrent
import           Control.Exception (finally)
import           Control.Monad.IO.Class
import qualified Data.ByteString.Lazy.Char8    as B
import qualified Data.Text                     as T
import           Language.Haskell.LSP.Capture
import           Language.Haskell.LSP.Messages
import           Language.Haskell.LSP.Types
import           Language.Haskell.LSP.Types.Lens as LSP hiding (error)
import           Data.Aeson
import           Data.Default
import           Data.List
import           Data.Maybe
import           Control.Lens hiding (List)
import           Control.Monad
import           System.FilePath
import           System.IO
import           Language.Haskell.LSP.Test
import           Language.Haskell.LSP.Test.Files
import           Language.Haskell.LSP.Test.Decoding
import           Language.Haskell.LSP.Test.Messages
import           Language.Haskell.LSP.Test.Server
import           Language.Haskell.LSP.Test.Session
-- | Replay a captured session against a freshly started server and check the
-- server's output against the captured server messages.
--
-- The session directory must contain the capture in a file named
-- @session.log@, holding one JSON-encoded event per line. Calls 'error',
-- naming the offending line, if a line cannot be decoded.
--
-- Blocks until the listener signals that every expected server message has
-- been seen. If the listener throws an exception to this thread instead, the
-- forked session thread is still killed before the exception propagates.
replaySession :: String -- ^ The command to run the server.
              -> FilePath -- ^ The recorded session directory.
              -> IO ()
replaySession serverExe sessionDir = do
  let logPath = sessionDir </> "session.log"
  entries <- B.lines <$> B.readFile logPath

  -- Decoding stays lazy, but a malformed line now reports where it is
  -- instead of failing with an anonymous 'fromJust' error.
  let decodeEntry lineNo entry = case eitherDecode entry of
        Right event -> event
        Left err ->
          error $ "replaySession: cannot decode line " ++ show lineNo
            ++ " of " ++ logPath ++ ": " ++ err
      unswappedEvents = zipWith decodeEntry [1 :: Int ..] entries

  withServer serverExe False $ \serverIn serverOut pid -> do
    events <- swapCommands pid <$> swapFiles sessionDir unswappedEvents

    -- Split the capture by direction; server messages that are skipped
    -- during comparison are dropped from the expectation up front.
    let clientMsgs = [msg | FromClient _ msg <- events]
        serverMsgs = [msg | FromServer _ msg <- events, not (shouldSkip msg)]
        reqMap = getRequestMap clientMsgs

    reqSema <- newEmptyMVar
    rspSema <- newEmptyMVar
    passSema <- newEmptyMVar
    mainThread <- myThreadId

    sessionThread <- forkIO $
      runSessionWithHandles serverIn
                            serverOut
                            (listenServer serverMsgs reqMap reqSema rspSema passSema mainThread)
                            def
                            fullCaps
                            sessionDir
                            (sendMessages clientMsgs reqSema rspSema)

    -- The listener may interrupt this wait with 'throwTo'; 'finally' makes
    -- sure the session thread is cleaned up on that path as well.
    takeMVar passSema `finally` killThread sessionThread
-- | Send the captured client messages to the server one at a time, in order.
--
-- * After a request, wait on @rspSema@ for the id of the server's response
--   and fail with 'error' if it does not match the request's id.
-- * Before a response, wait on @reqSema@ for the id of the server's request
--   and fail with 'error' if it does not match the response's id.
-- * Notifications are sent without waiting, except for the exit
--   notification, which ends the replay (see below).
sendMessages :: [FromClientMessage] -> MVar LspId -> MVar LspIdRsp -> Session ()
sendMessages [] _ _ = return ()
sendMessages (nextMsg:remainingMsgs) reqSema rspSema =
  handleClientMessage request response notification nextMsg
 where
  -- The exit notification is held back for 10 seconds (threadDelay takes
  -- microseconds), sent, and then this thread is aborted with 'error'; no
  -- remaining messages are processed.
  -- NOTE(review): the delay presumably gives the listener time to receive
  -- the outstanding server messages first -- confirm.
  notification msg@(NotificationMessage _ Exit _) = do
    liftIO $ putStrLn "Will send exit notification soon"
    liftIO $ threadDelay 10000000
    sendMessage msg
    liftIO $ error "Done"
  notification msg@(NotificationMessage _ m _) = do
    sendMessage msg
    liftIO $ putStrLn $ "Sent a notification " ++ show m
    sendMessages remainingMsgs reqSema rspSema
  request msg@(RequestMessage _ id m _) = do
    sendRequestMessage msg
    liftIO $ putStrLn $  "Sent a request id " ++ show id ++ ": " ++ show m ++ "\nWaiting for a response"
    rsp <- liftIO $ takeMVar rspSema
    when (responseId id /= rsp) $
      error $ "Expected id " ++ show id ++ ", got " ++ show rsp
    sendMessages remainingMsgs reqSema rspSema
  response msg@(ResponseMessage _ id _ _) = do
    liftIO $ putStrLn $ "Waiting for request id " ++ show id ++ " from the server"
    reqId <- liftIO $ takeMVar reqSema
    -- Report the captured response id as the expectation and the id the
    -- server actually sent as the observation.
    when (responseId reqId /= id) $
      error $ "Expected id " ++ show id ++ ", got " ++ show reqId
    sendResponse msg
    liftIO $ putStrLn $ "Sent response to request id " ++ show id
    sendMessages remainingMsgs reqSema rspSema
-- | Send a client request, first recording its id and method in the request
-- map held by the session context (presumably so the matching response can
-- be decoded later -- confirm against the decoder).
sendRequestMessage :: (ToJSON a, ToJSON b) => RequestMessage ClientMethod a b -> Session ()
sendRequestMessage req = do
  ctx <- ask
  let register pending =
        pure (updateRequestMap pending (req ^. LSP.id) (req ^. method))
  liftIO (modifyMVar_ (requestMap ctx) register)
  sendMessage req
-- | Whether a server message is one of the notification kinds recognised by
-- the replay: publish-diagnostics, log-message, show-message or
-- cancel-request. Every other message yields 'False'.
isNotification :: FromServerMessage -> Bool
isNotification msg = case msg of
  NotPublishDiagnostics{}      -> True
  NotLogMessage{}              -> True
  NotShowMessage{}             -> True
  NotCancelRequestFromServer{} -> True
  _                            -> False
-- | Read messages from the server and check them against the expected
-- (captured) server messages.
--
-- Every incoming message is logged; the id of a server request is put into
-- @reqSema@ and the id of a server response into @rspSema@ so the sending
-- side can synchronise on them. Messages for which 'shouldSkip' holds are
-- not compared against the expectation.
--
-- Once no expected messages remain, @()@ is put into @passSema@. If a
-- message arrives that is not acceptable at this point (see 'inRightOrder'),
-- a 'ReplayOutOfOrder' exception carrying the received message and the
-- messages that would have been accepted is thrown to @mainThreadId@ and
-- listening stops.
listenServer :: [FromServerMessage]
             -> RequestMap
             -> MVar LspId
             -> MVar LspIdRsp
             -> MVar ()
             -> ThreadId
             -> Handle
             -> SessionContext
             -> IO ()
listenServer [] _ _ _ passSema _ _ _ = putMVar passSema ()
listenServer expectedMsgs reqMap reqSema rspSema passSema mainThreadId serverOut ctx = do
  msgBytes <- getNextMessage serverOut
  let msg = decodeFromServerMsg reqMap msgBytes
      continueWith remaining =
        listenServer remaining reqMap reqSema rspSema passSema mainThreadId serverOut ctx
  handleServerMessage request response notification msg
  if shouldSkip msg
    then continueWith expectedMsgs
    else if inRightOrder msg expectedMsgs
      then continueWith (delete msg expectedMsgs)
      else
        -- The acceptable messages are the leading notifications plus the
        -- first non-notification after them, if there is one.
        let (pendingNotifications, rest) = span isNotification expectedMsgs
            remainingMsgs = pendingNotifications ++ take 1 rest
         in throwTo mainThreadId (ReplayOutOfOrder msg remainingMsgs)
  where
  response :: ResponseMessage a -> IO ()
  response res = do
    putStrLn $ "Got response for id " ++ show (res ^. id)
    putMVar rspSema (res ^. id) 
  request :: RequestMessage ServerMethod a b -> IO ()
  request req = do
    putStrLn
      $  "Got request for id "
      ++ show (req ^. id)
      ++ " "
      ++ show (req ^. method)
    putMVar reqSema (req ^. id) 
  notification :: NotificationMessage ServerMethod a -> IO ()
  notification n = putStrLn $ "Got notification " ++ show (n ^. method)

-- | Whether a received message is acceptable given the expected messages.
--
-- It is acceptable when it equals the next expected message, where expected
-- notifications at the front of the list may be passed over. It is not
-- acceptable when a non-notification that differs from it is reached first,
-- or when the expected messages run out without a match.
inRightOrder :: FromServerMessage -> [FromServerMessage] -> Bool
inRightOrder _ [] = False
inRightOrder received (expected : msgs)
  | received == expected               = True
  | isNotification expected            = inRightOrder received msgs
  | otherwise                          = False
-- | Whether a server message is left out of the replay comparison:
-- log-message and show-message notifications and show-message requests are
-- skipped, everything else is compared.
shouldSkip :: FromServerMessage -> Bool
shouldSkip msg = case msg of
  NotLogMessage{}  -> True
  NotShowMessage{} -> True
  ReqShowMessage{} -> True
  _                -> False
-- | Rewrite the pid prefix of command names in captured events to the given
-- pid (see 'swapPid').
--
-- Two kinds of event are rewritten: the command of a client
-- execute-command request, and the commands listed under the execute-command
-- provider of a server initialize response (when present). All other events
-- are passed through unchanged.
swapCommands :: Int -> [Event] -> [Event]
swapCommands pid = map swapEvent
  where
    swapEvent (FromClient time (ReqExecuteCommand req)) =
      FromClient time (ReqExecuteCommand (req & params . command %~ swapPid pid))
    swapEvent (FromServer time (RspInitialize rsp)) =
      FromServer time $ RspInitialize $
        rsp & result . _Just . LSP.capabilities . executeCommandProvider . _Just . commands
            %~ fmap (swapPid pid)
    swapEvent event = event
-- | Whether a command name is treated as carrying a pid prefix, i.e. whether
-- it contains at least two @':'@ characters.
hasPid :: T.Text -> Bool
hasPid name = colonCount >= 2
  where
    colonCount = T.length (T.filter (== ':') name)
-- | Replace the pid prefix of a command name.
--
-- When the name satisfies 'hasPid', everything before its first @':'@ is
-- replaced by the decimal rendering of the given pid, e.g. with pid 42
-- @"1234:plugin:cmd"@ becomes @"42:plugin:cmd"@. Other names are returned
-- unchanged.
swapPid :: Int -> T.Text -> T.Text
swapPid pid name =
  if hasPid name
    then T.concat [T.pack (show pid), fromFirstColon]
    else name
  where
    fromFirstColon = T.dropWhile (/= ':') name