module Propellor.Ssh where

import Propellor.Base
import Utility.UserInfo
import Utility.FileSystemEncoding

import System.PosixCompat
import Data.Time.Clock.POSIX
import Data.Hashable

-- | Build the options that enable ssh connection caching through a
-- per-host control socket (ControlMaster=auto, ControlPersist=yes).
-- The same list can be passed to both ssh and scp.
--
-- If the socket file already exists, its mtime is checked first. If the
-- file was modified within the last ten minutes, it is touched so that an
-- actively used socket stays fresh. Otherwise, the ssh master behind it is
-- asked to stop (@ssh -O stop@) and the file is removed, so that an old,
-- stale connection is not reused. mtime is used rather than atime because
-- laptops are often mounted with noatime.
sshCachingParams :: HostName -> IO [CommandParam]
sshCachingParams hn = do
	socketfile <- flip socketFile hn <$> myHomeDir
	-- The directory holding the socket may not exist yet.
	createDirectoryIfMissing True (takeDirectory socketfile)
	let params = controlParams socketfile
	existing <- catchMaybeIO (getFileStatus socketfile)
	case existing of
		Just status -> expireold params socketfile status
		Nothing -> return ()
	return params
  where
	-- Options pointing ssh at the control socket and enabling a
	-- persistent, automatically started master connection.
	controlParams sock =
		[ Param "-o", Param ("ControlPath=" ++ sock)
		, Param "-o", Param "ControlMaster=auto"
		, Param "-o", Param "ControlPersist=yes"
		]

	-- Refresh a recently modified socket, or shut down and remove a
	-- stale one.
	expireold params sock status = do
		posixnow <- getPOSIXTime
		let cutoff = fromIntegral (truncate posixnow :: Integer) - tenminutes
		if modificationTime status <= cutoff
			then do
				-- Ask the master owning this socket to exit; whether
				-- that succeeds is deliberately ignored.
				void $ boolSystem "ssh" $
					[Param "-O", Param "stop"] ++ params ++ [Param "localhost"]
				nukeFile sock
			else touchFile sock

	tenminutes :: EpochTime
	tenminutes = 600

-- | Choose the path of the ssh control socket for a host, inside the given
-- home directory.
--
-- Unix domain socket paths are limited to approximately 100 bytes, so the
-- candidates are tried from most to least descriptive, and the first one
-- that is short enough wins:
--
--   1. the full hostname plus ".sock";
--   2. the full hostname;
--   3. the first 10 characters of the hostname plus a checksum;
--   4. the checksum alone.
--
-- The checksum of the full hostname keeps different hosts from sharing a
-- socket file. If no candidate fits, a checksum-named file directly in the
-- home directory is returned without any length check.
socketFile :: FilePath -> HostName -> FilePath
socketFile home hn = selectSocketFile candidates fallback
  where
	candidates =
		[ sshdir </> (hn ++ ".sock")
		, sshdir </> hn
		, sshdir </> (take 10 hn ++ "-" ++ checksum)
		, sshdir </> checksum
		]
	fallback = home </> (".propellor-" ++ checksum)
	sshdir = home </> ".ssh" </> "propellor"
	-- At most 9 characters taken from the absolute value of the hostname's hash.
	checksum = take 9 (show (abs (hash hn)))

-- | Return the first candidate path that is short enough to be a unix
-- socket path, or the fallback when none of them is. Candidates after the
-- first valid one are never examined.
selectSocketFile :: [FilePath] -> FilePath -> FilePath
selectSocketFile candidates fallback =
	case filter valid_unix_socket_path candidates of
		(f:_) -> f
		[] -> fallback

-- | Check whether a path is short enough to be used as an ssh control
-- socket. The length is measured in bytes of the encoded path (via
-- 'decodeW8'), because the unix domain socket limit of approximately
-- 100 bytes is a byte limit.
valid_unix_socket_path :: FilePath -> Bool
valid_unix_socket_path f = bytelen + reservedbyssh < maxsocketpath
  where
	bytelen = length (decodeW8 f)
	-- Approximate size limit of a unix domain socket path.
	maxsocketpath :: Int
	maxsocketpath = 100
	-- ssh tacks on 17 or so characters when making a socket
	reservedbyssh :: Int
	reservedbyssh = 18