{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE CPP #-}

module Data.Conduit.RemoteOp
    ( remoteOp
    , OpOutputType(..)
#ifdef TEST
    , sshargs
#endif
    )
where

import           Control.Concurrent (readChan, writeChan, newChan)
import           Control.Concurrent.Async (async)
import qualified Control.Exception as E
import           Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as S
import           Data.Conduit
import qualified Data.Conduit.List as CL
import           Data.Conduit.Process ( streamingProcess, proc, ClosedStream(..)
                            , waitForStreamingProcess)
import           Data.Monoid ((<>))
import qualified Data.Text as T
import           System.Exit (ExitCode(..))


-- | Build the argument list for an @ssh@ invocation that runs @command@ on
-- @host@.
--
-- With @directSSH@ set, the result is the common options followed by the
-- host and the command.  Otherwise ssh first connects to @localhost@ (with
-- two @-t@ flags) and runs a nested @ssh@ there.  The nested invocation's
-- options are each wrapped in single quotes because they pass through the
-- localhost shell.
--
-- The lines of @command@ are joined with 'T.unlines' into a single argument.
--
-- NOTE(review): in the bounced form the host and the command text are not
-- quoted, so the localhost shell parses them first (e.g. a multi-line
-- command would be split into separate commands there) — confirm intended.
sshargs :: Bool -> T.Text -> [T.Text] -> [T.Text]
sshargs directSSH host command
  | directSSH = commonOpts ++ target
  | otherwise = concat [ ["-t", "-t"]
                       , commonOpts
                       , ["localhost", "ssh"]
                       , map singleQuote commonOpts
                       , target
                       ]
  where
    commonOpts :: [T.Text]
    commonOpts = "-A" : concatMap (\opt -> ["-o", opt]) sshOptions

    -- Each entry becomes an "-o <entry>" pair.  Not enabled: "-n" and
    -- "-o BatchMode yes".
    sshOptions :: [T.Text]
    sshOptions = [ "ControlPath none"
                 , "VisualHostKey no"
                 , "KbdInteractiveAuthentication no"
                 , "StrictHostKeyChecking no"
                 , "CheckHostIP no"
                 , "ForwardX11 no"
                 , "ForwardX11Trusted no"
                 ]

    singleQuote :: T.Text -> T.Text
    singleQuote opt = T.concat ["'", opt, "'"]

    target :: [T.Text]
    target = [host, T.unlines command]


-- | One message produced while running a remote operation.
--
-- @t@ is the payload type of stdout chunks and @e@ the payload type of
-- stderr chunks ('remoteOp' uses 'S.ByteString' for both).
data OpOutputType t e
  = StdOut t          -- ^ A chunk read from the process' stdout.
  | StdErr e          -- ^ A chunk read from the process' stderr.
  | StdOutEnd         -- ^ Marker: the stdout reader has finished.
  | StdErrEnd         -- ^ Marker: the stderr reader has finished.
  | Ended ExitCode    -- ^ The process exit status; 'remoteOp' yields it last.
  | DebugOut T.Text   -- ^ Free-form diagnostic text (not currently emitted
                      --   by 'remoteOp').
  deriving Show


-- | Run @command@ on @host@ over ssh and stream the process' output.
--
-- Data produced by @srcConduit@ is fed to the process' stdin from a
-- background thread.  Chunks read from the process' stdout and stderr are
-- emitted downstream as 'StdOut' and 'StdErr' values, interleaved in the
-- order the two reader threads deliver them.  After both output streams
-- have been drained, the process exit status is emitted as a final 'Ended'
-- value.
--
-- @directSSH@, @host@ and @command@ are turned into ssh arguments by
-- 'sshargs'.
remoteOp :: MonadIO m
            => Bool -> T.Text
                    -> [T.Text]
                    -> Producer IO S.ByteString
                    -> Producer m (OpOutputType S.ByteString S.ByteString)
remoteOp directSSH host command srcConduit = do
  (toProcess, fromProcess, procErrors, cph) <- streamingProcess rmtProc
  mChan <- liftIO newChan
  let showStdout = fromProcess $$ CL.mapM_ (writeChan mChan . StdOut)
      showStderr = procErrors $$ CL.mapM_ (writeChan mChan . StdErr)
      runInput = srcConduit $$ toProcess
      -- Each reader signals completion with its own end marker.  'E.finally'
      -- guarantees the marker is written even if draining the stream throws;
      -- without it chanSrcMid would block on the channel forever.
      handleStdout = showStdout `E.finally` writeChan mChan StdOutEnd
      handleStderr = showStderr `E.finally` writeChan mChan StdErrEnd
  -- The Async handles are discarded, so exceptions raised in these threads
  -- are not propagated to the caller.
  _ <- liftIO $ async handleStdout
  _ <- liftIO $ async handleStderr
  _ <- liftIO $ async runInput
  exitCode <- chanSrcMid cph mChan
  yield (Ended exitCode)
    where
      rmtProc = proc "ssh" $ map T.unpack $ sshargs directSSH host command

      -- chanSrcMid forwards the chunks arriving on the channel and, once
      -- both the stdout and stderr end markers have been seen, waits for
      -- and returns the process exit code.  It does not wait for stdin to
      -- close (that is handled asynchronously), but it does completely
      -- drain the process' output and error streams before completing.
      chanSrcMid ph = chanSrcMd' ph (2::Int)
      chanSrcMd' ph 0 _ = liftIO $ waitForProcess' ph
      chanSrcMd' ph n c = do
        rd <- liftIO $ readChan c
        case rd of
          StdOutEnd -> chanSrcMd' ph (n - 1) c
          StdErrEnd -> chanSrcMd' ph (n - 1) c
          -- StdOut/StdErr chunks (and, defensively, any other message)
          -- are passed downstream unchanged.
          _ -> yield rd >> chanSrcMd' ph n c

      -- NOTE(review): any exception while waiting (including asynchronous
      -- ones) is reported as ExitSuccess; kept for compatibility, but
      -- consider narrowing the handler or reporting a failure instead.
      waitForProcess' ph =
          waitForStreamingProcess ph `E.catch` \(E.SomeException _) -> return ExitSuccess