{-# LANGUAGE MultiWayIf        #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies      #-}
-- | Key types and a decoder for OpenSSH private key files
-- (@-----BEGIN OPENSSH PRIVATE KEY-----@ / @openssh-key-v1@ format).
module Network.SSH.Key
    ( KeyPair (..)
    , newKeyPair
    , PublicKey (..)
    , decodePrivateKeyFile
    , toPublicKey
    ) where

import           Control.Applicative     (many, (<|>))
import           Control.Monad           (replicateM, void, when)
import           Control.Monad.Fail      (MonadFail)
import qualified Crypto.Cipher.AES       as Cipher
import qualified Crypto.Cipher.Types     as Cipher
import           Crypto.Error
-- import qualified Crypto.KDF.BCryptPBKDF as BCryptPBKDF
import qualified Crypto.PubKey.Ed25519   as Ed25519
import qualified Crypto.PubKey.RSA       as RSA
import           Data.Bits
import qualified Data.ByteArray          as BA
import qualified Data.ByteArray.Parse    as BP
import qualified Data.ByteString         as BS
import           Data.String
import           Data.Word

import           Network.SSH.Name

-- | A private/public key pair. Ed25519 is currently the only supported
-- algorithm.
data KeyPair
    = KeyPairEd25519 Ed25519.PublicKey Ed25519.SecretKey
    deriving (Eq, Show)

-- | A public key. 'PublicKeyOther' carries only an algorithm 'Name'
-- (presumably for algorithms not modelled here — confirm with callers).
data PublicKey
    = PublicKeyEd25519 Ed25519.PublicKey
    | PublicKeyRSA RSA.PublicKey
    | PublicKeyOther Name
    deriving (Eq, Show)

instance HasName PublicKey where
    name PublicKeyEd25519 {} = Name "ssh-ed25519"
    name PublicKeyRSA {}     = Name "ssh-rsa"
    name (PublicKeyOther n)  = n

-- | Generate a fresh random Ed25519 key pair.
newKeyPair :: IO KeyPair
newKeyPair =
    (\sk -> KeyPairEd25519 (Ed25519.toPublic sk) sk) <$> Ed25519.generateSecretKey

-- | Extract the public half of a key pair.
toPublicKey :: KeyPair -> PublicKey
toPublicKey (KeyPairEd25519 pk _) = PublicKeyEd25519 pk

-- | Decode an OpenSSH private key file into its key pairs and their comments.
--
-- Fails (via 'MonadFail') on malformed input, on an unsupported cipher or key
-- derivation function, or on an unsupported key algorithm.
--
-- The @passphrase@ is currently ignored: the only key derivation function
-- accepted is @none@ (bcrypt support is disabled), so passphrase-protected
-- files are rejected.
decodePrivateKeyFile ::
    ( MonadFail m, BA.ByteArray input, BA.ByteArrayAccess passphrase, BA.ByteArray comment )
    => passphrase -> input -> m [(KeyPair, comment)]
decodePrivateKeyFile passphrase =
    finish . BP.parse (parsePrivateKeyFile passphrase) . BA.convert
  where
    finish (BP.ParseOK _ a) = pure a
    finish (BP.ParseFail e) = fail e
    -- The whole input was supplied up front: signal end-of-input.
    finish (BP.ParseMore c) = finish (c Nothing)

-- | Parser for the armoured file: BEGIN marker, base64 body, END marker,
-- optional surrounding whitespace and nothing else. The decoded body is then
-- parsed as an @openssh-key-v1@ structure.
parsePrivateKeyFile ::
    ( BA.ByteArrayAccess passphrase, BA.ByteArray comment )
    => passphrase -> BP.Parser BS.ByteString [(KeyPair, comment)]
parsePrivateKeyFile _passphrase = do
    BP.bytes "-----BEGIN OPENSSH PRIVATE KEY-----"
    void $ many space
    bs <- parseBase64
    void $ many space
    BP.bytes "-----END OPENSSH PRIVATE KEY-----"
    void $ many space
    -- Reject trailing garbage after the END marker.
    BP.hasMore >>= flip when syntaxError
    case BP.parse parseKeys bs of
        BP.ParseOK _ keys -> pure keys
        BP.ParseFail e    -> fail e
        BP.ParseMore _    -> syntaxError
  where
    syntaxError :: BP.Parser ba a
    syntaxError = fail "Syntax error"

    -- Base64 decoder written as a small state machine. Whitespace is allowed
    -- between any two characters; '=' padding terminates the input.
    parseBase64 :: (BA.ByteArray ba) => BP.Parser ba ba
    parseBase64 = s0 []
      where
        -- Initial state and final state. `xs` holds decoded bytes in reverse.
        s0 xs = (char >>= s1 xs) <|> (space1 >> s0 xs) <|> pure (BA.pack $ reverse xs)
        -- One character read (i). Three more characters or whitespace expected.
        s1 xs i = (char >>= s2 xs i) <|> (space1 >> s1 xs i)
        -- Two characters read (i and j). Either '==' or space or two more characters expected.
        s2 xs i j = r2 xs i j <|> (char >>= s3 xs i j) <|> (space1 >> s2 xs i j)
        -- Three characters read (i, j and k). Either a '=' or space or one more character expected.
        s3 xs i j k = r3 xs i j k <|> (char >>= s4 xs i j k) <|> (space1 >> s3 xs i j k)
        -- Four characters read (i, j, k and l). Emit three bytes and go back to s0.
        s4 xs i j k l = s0 $ byte3 : byte2 : byte1 : xs
          where
            byte1 = ( i `shiftL` 2) + (j `shiftR` 4)
            byte2 = ((j .&. 15) `shiftL` 4) + (k `shiftR` 2)
            byte3 = ((k .&. 3) `shiftL` 6) + l
        -- Read two '=' chars as finalizer. Only valid from state s2.
        r2 xs i j = padding >> padding >> pure (BA.pack $ reverse $ byte1 : xs)
          where
            byte1 = (i `shiftL` 2) + (j `shiftR` 4)
        -- Read one '=' char as finalizer. Only valid from state s3.
        r3 xs i j k = padding >> pure (BA.pack $ reverse $ byte2 : byte1 : xs)
          where
            byte1 = (i `shiftL` 2) + (j `shiftR` 4)
            byte2 = ((j .&. 15) `shiftL` 4) + (k `shiftR` 2)

    -- One base64 alphabet character, returned as its 6-bit value (0..63).
    char :: (BA.ByteArray ba) => BP.Parser ba Word8
    char = BP.anyByte >>= \c -> if
        | c >= fe 'A' && c <= fe 'Z' -> pure (c - fe 'A')
        | c >= fe 'a' && c <= fe 'z' -> pure (c - fe 'a' + 26)
        | c >= fe '0' && c <= fe '9' -> pure (c - fe '0' + 52)
        | c == fe '+'                -> pure 62
        | c == fe '/'                -> pure 63
        | otherwise                  -> fail ""

    padding :: (BA.ByteArray ba) => BP.Parser ba ()
    padding = BP.byte 61 -- 61 == fromEnum '='

    fe :: Char -> Word8
    fe = fromIntegral . fromEnum

    -- Exactly one whitespace byte (space, LF, CR or tab).
    space :: (BA.ByteArray ba) => BP.Parser ba ()
    space = BP.anyByte >>= \c -> if
        | c == fe ' '  -> pure ()
        | c == fe '\n' -> pure ()
        | c == fe '\r' -> pure ()
        | c == fe '\t' -> pure ()
        | otherwise    -> fail ""

    -- One or more whitespace bytes.
    space1 :: (BA.ByteArray ba) => BP.Parser ba ()
    space1 = space >> many space >> pure ()

    -- Big-endian unsigned 32-bit integer. Each byte is shifted LEFT into its
    -- position (the first byte read is the most significant).
    getWord32be :: BA.ByteArray ba => BP.Parser ba Word32
    getWord32be = do
        x0 <- fromIntegral <$> BP.anyByte
        x1 <- fromIntegral <$> BP.anyByte
        x2 <- fromIntegral <$> BP.anyByte
        x3 <- fromIntegral <$> BP.anyByte
        pure $ shiftL x0 24 .|. shiftL x1 16 .|. shiftL x2 8 .|. x3

    -- A string: 32-bit big-endian length followed by that many bytes.
    getString :: BA.ByteArray ba => BP.Parser ba ba
    getString = BP.take =<< (fromIntegral <$> getWord32be)

    -- The decoded "openssh-key-v1" structure: magic, cipher name, KDF name,
    -- KDF options, key count, one public key blob per key, then the
    -- (possibly encrypted) private key section.
    parseKeys :: (BA.ByteArray input, IsString input, Show input, BA.ByteArray comment)
              => BP.Parser input [(KeyPair, comment)]
    parseKeys = do
        BP.bytes "openssh-key-v1\NUL"
        cipherAlgo <- getString
        kdfAlgo    <- getString
        -- Length prefix of the KDF options string. Its contents are only read
        -- by a KDF that needs them (none is currently enabled).
        BP.skip 4
        deriveKey <- case kdfAlgo of
            "none" -> pure $ \_ -> CryptoFailed CryptoError_KeySizeInvalid
            {- -- This is currently not included in cryptonite.
               -- Re-enable if my PR has been merged.
            "bcrypt" -> do
                salt <- getString
                rounds <- fromIntegral <$> getWord32be
                pure $ \case
                    Cipher.KeySizeFixed len ->
                        CryptoPassed $ BCryptPBKDF.generate
                            (BCryptPBKDF.Parameters rounds len)
                            (BA.convert passphrase :: BA.Bytes)
                            salt
                    _ -> undefined -- impossible
            -}
            _ -> fail $ "Unsupported key derivation function " ++ show (BA.convert kdfAlgo :: BA.Bytes)
        numberOfKeys <- fromIntegral <$> getWord32be
        -- One public key blob per key precedes the private section (not used).
        _publicKeysRaw <- replicateM numberOfKeys getString
        privateKeysRawEncrypted <- getString
        privateKeysRawDecrypted <- BA.convert <$> case cipherAlgo of
            "none"       -> pure privateKeysRawEncrypted
            "aes256-cbc" -> decryptAes256 Cipher.cbcDecrypt deriveKey privateKeysRawEncrypted
            "aes256-ctr" -> decryptAes256 Cipher.ctrCombine deriveKey privateKeysRawEncrypted
            _            -> fail $ "Unsupported cipher " ++ show cipherAlgo
        case BP.parse (parsePrivateKeys numberOfKeys) privateKeysRawDecrypted of
            BP.ParseOK _ keys -> pure keys
            BP.ParseFail e    -> fail e
            BP.ParseMore _    -> syntaxError

    -- Decrypt `encrypted` with AES-256 using the block cipher `mode`
    -- (Cipher.cbcDecrypt or Cipher.ctrCombine). `deriveKey` is asked for
    -- key-size + block-size bytes; the first part is the key, the rest the IV.
    -- Any crypto failure is reported through `fail`.
    decryptAes256 mode deriveKey encrypted =
        case result of
            CryptoPassed a -> pure a
            CryptoFailed e -> fail (show e)
      where
        -- Only used for its type; never evaluated.
        aes256 = undefined :: Cipher.AES256
        ivSize = Cipher.blockSize aes256
        result = case Cipher.cipherKeySize aes256 of
            Cipher.KeySizeFixed keySize -> do
                keyIV <- deriveKey $ Cipher.KeySizeFixed (keySize + ivSize)
                let key = BA.take keySize keyIV :: BA.ScrubbedBytes
                case Cipher.makeIV (BA.drop keySize keyIV) of
                    Nothing -> CryptoFailed CryptoError_IvSizeInvalid
                    Just iv -> do
                        cipher <- Cipher.cipherInit key :: CryptoFailable Cipher.AES256
                        pure $ mode cipher iv encrypted
            _ -> CryptoFailed CryptoError_KeySizeInvalid

    -- The decrypted private section: two check integers (which must match,
    -- otherwise decryption is assumed to have failed), then `count` entries of
    -- (algorithm, key material, comment). Trailing padding is not inspected.
    parsePrivateKeys :: (BA.ByteArray comment) => Int -> BP.Parser BA.ScrubbedBytes [(KeyPair, comment)]
    parsePrivateKeys count = do
        check1 <- getWord32be
        check2 <- getWord32be
        when (check1 /= check2) (fail "Unsuccessful decryption")
        replicateM count $ do
            key <- getString >>= \algo -> case algo of
                "ssh-ed25519" -> do
                    BP.skip 3
                    BP.byte 32 -- length field (is always 32 for ssh-ed25519)
                    BP.skip Ed25519.publicKeySize
                    BP.skip 3
                    BP.byte 64 -- length field (is always 64 for ssh-ed25519)
                    -- The 64-byte field is the secret key followed by the public key.
                    secretKeyRaw <- BP.take 32
                    publicKeyRaw <- BP.take 32
                    let keyPair = KeyPairEd25519
                            <$> Ed25519.publicKey publicKeyRaw
                            <*> Ed25519.secretKey secretKeyRaw
                    case keyPair of
                        CryptoPassed a -> pure a
                        CryptoFailed _ -> fail $ "Invalid " ++ show (BA.convert algo :: BA.Bytes) ++ " key"
                _ -> fail $ "Unsupported algorithm " ++ show (BA.convert algo :: BA.Bytes)
            comment <- BA.convert <$> getString
            pure (key, comment)