{-# LANGUAGE CPP, ForeignFunctionInterface #-}

module Ssh ( grabSSH, runSSH, getSSH, copySSH, copySSHs, SSHCmd(..) ) where

import Prelude hiding ( lookup, catch )

import System.Exit ( ExitCode(..) )
import System.Environment ( getEnv )
#ifndef WIN32
import System.Posix.Process ( getProcessID )
#else
import Darcs.Utils ( showHexLen )
import Data.Bits ( (.&.) )
import System.Random ( randomIO )
#endif
import System.IO ( Handle, hPutStr, hPutStrLn, hGetLine, hGetContents, hClose, hFlush )
import System.IO.Unsafe ( unsafePerformIO )
import System.Directory ( doesFileExist, createDirectoryIfMissing )
import Control.Monad ( when )
import System.Process ( runInteractiveProcess )

import Data.Map ( Map, empty, insert, lookup )
import Data.IORef ( IORef, newIORef, readIORef, modifyIORef )

import Darcs.SignalHandler ( catchNonSignal )
import Darcs.Utils ( withCurrentDirectory, breakCommand, prettyException, catchall )
import Darcs.Global ( atexit, sshControlMasterDisabled, darcsdir, withDebugMode )
import Darcs.Lock ( withTemp, withOpenTemp, tempdir_loc, removeFileMayNotExist )
import Exec ( exec, Redirects, Redirect(..), )
import Progress ( withoutProgress, debugMessage, debugFail, progressList )

import qualified Data.ByteString as B (ByteString, hGet, length, writeFile, readFile)
import qualified Data.ByteString.Char8 as BC (unpack)

#include "impossible.h"

-- | Process-wide cache of transfer-mode connections, keyed by the address
--   as normalised by 'cleanrepourl'.  A @Just conn@ entry is a live
--   connection to reuse; a @Nothing@ entry records that connecting failed
--   (or the connection was severed), so callers go straight to the
--   fallback path.
--
--   'unsafePerformIO' is used only to allocate one global 'IORef'; the
--   NOINLINE pragma stops GHC from duplicating the 'newIORef' call, so
--   every use site shares the same reference.
{-# NOINLINE sshConnections #-}
sshConnections :: IORef (Map String (Maybe Connection))
sshConnections = unsafePerformIO (newIORef empty)

-- | An open @darcs transfer-mode@ session running over ssh.
data Connection = C
  { inp :: !Handle          -- ^ stdin of the remote process; requests are written here
  , out :: !Handle          -- ^ stdout of the remote process; replies are read here
  , err :: !Handle          -- ^ stderr of the remote process; read when reporting failures
  , deb :: String -> IO ()  -- ^ debug logger whose messages are tagged with the remote host
  }

-- | Run an action against a cached or freshly started transfer-mode
--   connection to the host in address @x@.  If none can be had, run the
--   fallback action instead.
--
--   The cache ('sshConnections') is keyed by @cleanrepourl x@:
--
--   * a cached live connection is reused;
--   * a cached failure (@Nothing@) goes straight to the fallback;
--   * otherwise @ssh HOST darcs transfer-mode --repodir DIR@ is started and
--     must greet with the expected banner line.  Success is cached.  Any
--     non-signal exception is logged, the address is marked as severed,
--     and the fallback runs.
--
--   Everything runs inside 'withoutProgress'.
--
--   NOTE(review): when the connection attempt fails, the handles returned
--   by 'runInteractiveProcess' are never closed and the ssh process is not
--   reaped.
withSSHConnection :: String -> (Connection -> IO a) -> IO a -> IO a
withSSHConnection x withconnection withoutconnection =
    withoutProgress $ do
      known <- readIORef sshConnections
      case lookup repoKey (known :: Map String (Maybe Connection)) of
        Just (Just conn) -> withconnection conn
        Just Nothing     -> withoutconnection
        Nothing          -> do
          attempt <- openConnection `catchNonSignal` abandon
          maybe withoutconnection withconnection attempt
  where
    hostPart = takeWhile (/= ':') x
    repoKey  = cleanrepourl x

    -- Start the remote transfer-mode server and check its greeting.
    openConnection = do
      (ssh, baseArgs) <- getSSHOnly SSH
      let fullArgs = baseArgs ++ [hostPart, "darcs", "transfer-mode", "--repodir", cleanrepodir x]
      debugMessage $ "ssh " ++ unwords fullArgs
      (hIn, hOut, hErr, _) <- runInteractiveProcess ssh fullArgs Nothing Nothing
      greeting <- hGetLine hOut
      when (greeting /= "Hello user, I am darcs transfer mode") $
        debugFail "Couldn't start darcs transfer-mode on server"
      let conn = C { inp = hIn
                   , out = hOut
                   , err = hErr
                   , deb = \s -> debugMessage ("with ssh (transfer-mode) " ++ hostPart ++ ": " ++ s)
                   }
      modifyIORef sshConnections (insert repoKey (Just conn))
      return (Just conn)

    -- Record the failure so later calls skip straight to the fallback.
    abandon ex = do
      debugMessage $ "Failed to start ssh connection:\n    " ++ prettyException ex
      severSSHConnection x
      debugMessage $ unlines
        [ "NOTE: the server may be running a version of darcs prior to 2.0.0."
        , ""
        , "Installing darcs 2 on the server will speed up ssh-based commands."
        ]
      return Nothing

-- | Mark the transfer-mode connection for address @x@ as unusable by
--   caching @Nothing@ under its 'cleanrepourl' key, so later calls use the
--   fallback path.
severSSHConnection :: String -> IO ()
severSSHConnection x = do
    let key = cleanrepourl x
    debugMessage ("Severing ssh failed connection to " ++ x)
    modifyIORef sshConnections (insert key Nothing)

-- | Truncate a string just before the first occurrence of
--   @darcsdir ++ "/"@.  A string that does not contain it is returned
--   unchanged.
cleanrepourl :: String -> String
cleanrepourl = go
  where
    marker = darcsdir ++ "/"
    go rest | take (length marker) rest == marker = []
    go (ch:rest) = ch : go rest
    go []        = []

-- | The directory part of a @host:dir@ address, meaning everything after
--   the first colon, truncated at the darcs directory as 'cleanrepourl'
--   does.
cleanrepodir :: String -> String
cleanrepodir addr = cleanrepourl (drop 1 fromColon)
  where (_, fromColon) = break (== ':') addr

-- | Fetch one file over an established transfer-mode 'Connection'.
--
--   The address @x@ is expected to look like @host:path@, with the darcs
--   directory (@darcsdir ++ "/"@) somewhere in the path.  Everything up to
--   and including that component is dropped, and the remainder is
--   requested with a @get@ command.  An address without that component is
--   treated as a bug.
--
--   Replies, as read below:
--
--   * @got FILE@, then a line holding the byte count, then exactly that
--     many bytes;
--   * @error FILE@, then a line holding a 'show'n error message.
--
--   On any unexpected reply, including a body shorter than the announced
--   length, the connection is severed (so later fetches take the scp/sftp
--   fallback).  'debugFail' is then called with the remote stderr
--   appended.
grabSSH :: String -> Connection -> IO B.ByteString
grabSSH x c = do
  let dir = drop 1 $ dropWhile (/= ':') x
      dd = darcsdir ++ "/"
      clean zzz | take (length dd) zzz == dd = drop (length dd) zzz
      clean (_:zs) = clean zs
      clean "" = bug $ "Buggy path in grabSSH: " ++ x
      file = clean dir
      failwith e = do severSSHConnection x
                      eee <- hGetContents (err c) -- ratify hGetContents: it's okay
                                                  -- here because we're only grabbing
                                                  -- stderr, and we're also about to
                                                  -- throw the contents.
                      debugFail $ e ++ " grabbing ssh file " ++ x ++ "\n" ++ eee
  deb c $ "get " ++ file
  hPutStrLn (inp c) $ "get " ++ file
  hFlush (inp c)
  reply <- hGetLine (out c)
  if reply == "got " ++ file
     then do showlen <- hGetLine (out c)
             case reads showlen of
               [(len, "")] | len >= 0 -> do
                 contents <- B.hGet (out c) len
                 -- B.hGet returns fewer bytes than requested when the
                 -- stream hits EOF.  Treat that as a failed transfer
                 -- rather than returning a silently truncated file.
                 if B.length contents == len
                    then return contents
                    else failwith "Got truncated data"
               _ -> failwith "Couldn't get length"
     else if reply == "error " ++ file
          then do e <- hGetLine (out c)
                  case reads e of
                    (msg, _):_ -> debugFail $ "Error reading file remotely:\n" ++ msg
                    [] -> failwith "An error occurred"
          else failwith "Error"

-- | How ssh-family subprocesses should treat stderr: pass it through
--   ('AsIs') in debug mode, otherwise discard it ('Null').
sshStdErrMode :: IO Redirect
sshStdErrMode = withDebugMode pick
  where
    pick debugging
      | debugging = return AsIs
      | otherwise = return Null

-- | Fetch the remote file @uRaw@ (a @host:path@ address) into the local
--   file @f@.
--
--   A transfer-mode connection is used when available.  Otherwise scp is
--   run, and 'debugFail' is called if scp exits unsuccessfully.
copySSH :: String -> FilePath -> IO ()
copySSH uRaw f = withSSHConnection uRaw viaTransferMode viaScp
  where
    viaTransferMode conn = grabSSH uRaw conn >>= B.writeFile f

    viaScp = do
      errMode <- sshStdErrMode
      status <- runSSH SCP u [] [u, f] (AsIs, AsIs, errMode)
      when (status /= ExitSuccess) $
        debugFail $ "(scp) failed to fetch: " ++ u

    -- '$' in filenames is troublesome for scp, so it is backslash-escaped
    -- before the address is handed to scp.
    u = escape_dollar uRaw

    escape_dollar :: String -> String
    escape_dollar = concatMap (\ch -> if ch == '$' then "\\$" else [ch])

-- | Fetch several files from one remote directory into local directory @d@.
--
--   Each name in @ns@ is fetched from @u ++ "/" ++ name@ and written under
--   the same name inside @d@.  With a transfer-mode connection the files
--   are grabbed one by one, with a progress report.  Otherwise a batch of
--   @cd@/@get@ commands is fed to sftp.  sftp's stdout is captured in a
--   temporary file and included in the 'debugFail' message if sftp exits
--   unsuccessfully.
copySSHs :: String -> [String] -> FilePath -> IO ()
copySSHs u ns d = withSSHConnection u viaTransferMode viaSftp
  where
    viaTransferMode conn =
      withCurrentDirectory d $
        mapM_ (\n -> grabSSH (u ++ "/" ++ n) conn >>= B.writeFile n)
              (progressList "Copying via ssh" ns)

    (host, fromColon) = break (== ':') u
    path = drop 1 fromColon
    batch = "cd " ++ path ++ "\n" ++ unlines [ "get " ++ n | n <- ns ]

    -- The error report lists at most five names; the rest are summarised
    -- as a count.
    shownFiles = case splitAt 5 ns of
                   (firstFive, rest@(_:_)) ->
                     firstFive ++ ["and " ++ show (length rest) ++ " more"]
                   _ -> ns

    hint = [ "sftp doesn't expand ~, use path/ instead of ~/path/"
           | take 1 path == "~" ]

    viaSftp =
      withCurrentDirectory d $ withOpenTemp $ \(th, tn) ->
        withTemp $ \sftpoutput -> do
          hPutStr th batch
          hClose th
          errMode <- sshStdErrMode
          status <- runSSH SFTP u [] [host] (File tn, File sftpoutput, errMode)
          when (status /= ExitSuccess) $ do
            outputPS <- B.readFile sftpoutput
            debugFail $ unlines $
              [ "(sftp) failed to fetch files."
              , "source directory: " ++ path
              , "source files:" ] ++ shownFiles ++
              [ "sftp output:", BC.unpack outputPS ] ++
              hint

-- ---------------------------------------------------------------------
-- older ssh helper functions
-- ---------------------------------------------------------------------

-- | The ssh-family programs darcs may invoke.
data SSHCmd = SSH | SCP | SFTP

-- | 'show' yields each command's default program name.  It is the fallback
--   used when no DARCS_* override variable is set.
instance Show SSHCmd where
  show cmd = case cmd of
    SSH  -> "ssh"
    SCP  -> "scp"
    SFTP -> "sftp"

-- | Run an ssh-family command for @remoteAddr@.  The final argument list is
--   @preArgs@, then the arguments from 'getSSH', then @postArgs@.  Returns
--   the command's exit code.
runSSH :: SSHCmd -> String -> [String] -> [String] -> Redirects -> IO ExitCode
runSSH cmd remoteAddr preArgs postArgs redirs =
    getSSH cmd remoteAddr >>= \(program, sshArgs) ->
      exec program (concat [preArgs, sshArgs, postArgs]) redirs

-- | Return the program and arguments for an ssh-family command, extended
--   with control-master options when they are enabled.
--
--   Unless 'sshControlMasterDisabled' is set, a control master is launched
--   for the address if its ControlPath file does not exist yet.  The
--   @-o ControlPath=...@ option is added only if that file exists
--   afterwards.  scp additionally gets @-q@.  See 'getSSHOnly' for the
--   base command.
getSSH :: SSHCmd -> String -- ^ remote path
       -> IO (String, [String])
getSSH cmd remoteAddr = do
    (ssh, ssh_args) <- getSSHOnly cmd
    cm_args <- if sshControlMasterDisabled then return [] else controlMasterArgs
    return (ssh, quietFlag ++ ssh_args ++ cm_args)
  where
    controlMasterArgs = do
      cmPath <- controlMasterPath remoteAddr
      alreadyRunning <- doesFileExist cmPath
      when (not alreadyRunning) $ launchSSHControlMaster remoteAddr
      available <- doesFileExist cmPath
      return [ "-o ControlPath=" ++ cmPath | available ]

    -- (p)scp is the only one that recognises -q; sftp, (p)sftp and plink
    -- do not.
    quietFlag = case cmd of
                  SCP -> ["-q"]
                  _   -> []

-- | Return the program and arguments for an ssh-family command, without
--   any control-master options.
--
--   The command line is taken from DARCS_SSH, DARCS_SCP or DARCS_SFTP (per
--   @cmd@) and split by 'breakCommand'.  If the variable cannot be read,
--   it falls back to the command's default name ('show' of @cmd@).  If
--   SSH_PORT is set, the port option that command expects is appended.
getSSHOnly :: SSHCmd -> IO (String, [String])
getSSHOnly cmd = do
    command <- getEnv envVar `catchall` return (show cmd)
    portArgs <- (portFlag `fmap` getEnv "SSH_PORT") `catchall` return []
    let (ssh, extraArgs) = breakCommand command
    return (ssh, extraArgs ++ portArgs)
  where
    (envVar, portFlag) = case cmd of
      SSH  -> ("DARCS_SSH",  \p -> ["-p", p])
      SCP  -> ("DARCS_SCP",  \p -> ["-P", p])
      SFTP -> ("DARCS_SFTP", \p -> ["-oPort=" ++ p])

-- | True if the configured ssh client supports the ControlMaster
--   (connection multiplexing) feature, as probed by 'hasSSHControlMasterIO'.
--
--   The probe runs through 'unsafePerformIO'.  The NOINLINE pragma keeps
--   this a single shared top-level value, so the probe spawns ssh at most
--   once per process.  Without it GHC is free to inline the
--   'unsafePerformIO' call at use sites and rerun the probe.
{-# NOINLINE hasSSHControlMaster #-}
hasSSHControlMaster :: Bool
hasSSHControlMaster = unsafePerformIO hasSSHControlMasterIO

-- | Probe whether the configured ssh client has the ControlMaster feature.
--
-- 'hasSSHControlMaster' evaluates this through 'unsafePerformIO', so it may
-- run at any point.  It must not rely on any state, not even the current
-- directory.
hasSSHControlMasterIO :: IO Bool
hasSSHControlMasterIO = do
  (ssh, _) <- getSSHOnly SSH
  -- A ControlMaster-capable ssh recognises the -O flag, rejects the
  -- nonsense command and exits with status 255.  An ssh without the
  -- feature prints its help message and exits with 1.  Anything other than
  -- 255 is taken as "no".
  status <- exec ssh ["-O", "an_invalid_command"] (Null, Null, Null)
  return (status == ExitFailure 255)

-- | Start an ssh control master for the host part of @rawAddr@ in the
--   background, but only when 'hasSSHControlMaster' is True.
--
--   Any stale ControlPath file is removed first.  An 'atexit' hook is
--   registered to tell the master to exit when darcs exits.  The call does
--   not wait on the background master.
launchSSHControlMaster :: String -> IO ()
launchSSHControlMaster rawAddr = when hasSSHControlMaster startMaster
  where
    host = takeWhile (/= ':') rawAddr
    startMaster = do
      (ssh, sshArgs) <- getSSHOnly SSH
      socketPath <- controlMasterPath host
      removeFileMayNotExist socketPath
      -- -f : put ssh in the background once it succeeds in logging you in
      -- -M : act as the control master for host
      -- -N : don't run any commands
      -- -S : use socketPath as the ControlPath (same as -oControlPath=)
      _ <- exec ssh (sshArgs ++ [host, "-S", socketPath, "-N", "-f", "-M"]) (Null, Null, AsIs)
      atexit $ exitSSHControlMaster host
      return ()

-- | Ask the ssh control master for @addr@ to shut down, by sending
--   @-O exit@ over its ControlPath.  All output is discarded and the exit
--   code is ignored.
exitSSHControlMaster :: String -> IO ()
exitSSHControlMaster addr = do
    (ssh, sshArgs) <- getSSHOnly SSH
    socketPath <- controlMasterPath addr
    let request = sshArgs ++ [addr, "-S", socketPath, "-O", "exit"]
    _ <- exec ssh request (Null, Null, Null)
    return ()

-- | Return the ssh ControlPath for a remote address, creating its parent
--   directory if needed.
--
--   Only the host part is used: @foo\@bar.com:file@ becomes
--   @foo\@bar.com@.  The path is @$HOME/.darcs/darcs-ssh/HOST@ followed by
--   'controlMasterSuffix'.  'tempdir_loc' replaces @$HOME/.darcs@ when
--   HOME cannot be read.
controlMasterPath :: String -- ^ remote path (foo\@bar.com:file is ok; the file part will be stripped)
                  -> IO FilePath
controlMasterPath rawAddr = do
  let addr = takeWhile (/= ':') rawAddr
  tmp <- (fmap (/// ".darcs") $ getEnv "HOME") `catchall` tempdir_loc
  let tmpDarcsSsh = tmp /// "darcs-ssh"
  createDirectoryIfMissing True tmpDarcsSsh
  return $ tmpDarcsSsh /// addr ++ controlMasterSuffix

-- | Per-process suffix appended to ControlPath names: the process ID, or on
--   Windows six random hex digits.
--
--   It is computed once and cached (hence NOINLINE).  'getSSH',
--   'launchSSHControlMaster' and 'exitSSHControlMaster' each call
--   'controlMasterPath' separately and must agree on the same socket.  With
--   a random suffix recomputed per call, they would not.
{-# NOINLINE controlMasterSuffix #-}
controlMasterSuffix :: String
controlMasterSuffix = unsafePerformIO $
#ifdef WIN32
  do r <- randomIO
     return (showHexLen 6 (r .&. 0xFFFFFF :: Int))
#else
  show `fmap` getProcessID
#endif

-- | Join two path components with a single @/@.  No normalisation is done,
--   so a trailing slash on @d@ yields a doubled separator.
(///) :: FilePath -> FilePath -> FilePath
d /// f = concat [d, "/", f]