module Crypto.PubKey.OpenSsh.Decode where
import Prelude hiding (take)
import Control.Applicative ((*>), (<|>))
import Control.Monad (void, replicateM)
import Data.ByteString.Char8 (ByteString)
import Data.Char (isControl)
import Data.Attoparsec.ByteString.Char8 (Parser, parseOnly, take, space,
isSpace, takeTill)
import Data.PEM (PEM(..), pemParseBS)
import Data.ASN1.Encoding (decodeASN1')
import Data.ASN1.Types (ASN1(IntVal, Start, End), ASN1ConstructionType(Sequence))
import Data.ASN1.BinaryEncoding (DER(..))
import Data.Serialize (Get, getBytes, runGet, getWord32be, getWord8)
import qualified Data.ByteString.Base64 as Base64
import qualified Crypto.Types.PubKey.DSA as DSA
import qualified Crypto.Types.PubKey.RSA as RSA
import Crypto.PubKey.OpenSsh.Types (OpenSshKeyType(..), OpenSshPublicKey(..),
OpenSshPrivateKey(..))
-- | Number of characters occupied by the key-type identifier at the start of
-- a public key line. Both identifiers recognised by 'readType' (@ssh-rsa@
-- and @ssh-dss@) have this length.
typeSize :: Int
typeSize = length ("ssh-rsa" :: String)
-- | Translate a textual key-type identifier into an 'OpenSshKeyType'.
--
-- @ssh-rsa@ and @ssh-dss@ are recognised; any other input makes the
-- surrounding monad 'fail' with @Invalid key type@.
readType :: Monad m => ByteString -> m OpenSshKeyType
readType name = case name of
    "ssh-rsa" -> return OpenSshKeyTypeRsa
    "ssh-dss" -> return OpenSshKeyTypeDsa
    _         -> fail "Invalid key type"
-- | Smallest positive byte count @i@ such that @2 ^ (8 * i) > n@, i.e. the
-- number of bytes needed to hold @n@ as an unsigned big-endian integer.
--
-- Any @n@ below 256 (including zero and negative values) yields 1.
calculateSize :: Integer -> Int
calculateSize n = 1 + length (takeWhile (<= n) bytePowers)
  where
    -- 256^1, 256^2, 256^3, ...: each power that still fits within @n@
    -- accounts for one additional byte.
    bytePowers = iterate (* 256) 256
-- | Read a length-prefixed big-endian integer.
--
-- The encoding is a 32-bit big-endian byte count followed by that many
-- bytes, most significant first. The bytes are interpreted as an unsigned
-- magnitude; no sign handling is performed.
--
-- A zero-length payload decodes to 0 (the previous implementation used
-- 'foldl1' and raised an exception on empty input).
getInteger :: Get Integer
getInteger = do
    size <- fmap fromIntegral getWord32be
    octets <- replicateM size getWord8
    return (fromOctets 0 octets)
  where
    -- Accumulate most-significant byte first; 'seq' keeps the accumulator
    -- evaluated so no thunk chain builds up for long inputs.
    fromOctets acc [] = acc
    fromOctets acc (o:os) =
        let acc' = acc * 256 + toInteger o
        in acc' `seq` fromOctets acc' os
-- | Decode the binary blob of an OpenSSH public key.
--
-- The blob starts with a length-prefixed key-type name (checked with
-- 'readType'), followed by the key's integers read with 'getInteger':
-- @e, n@ for RSA and @p, q, g, y@ for DSA.
--
-- The result is a function that still expects the key's comment.
getOpenSshPublicKey :: Get (ByteString -> OpenSshPublicKey)
getOpenSshPublicKey = do
    nameLength <- fmap fromIntegral getWord32be
    keyType <- readType =<< getBytes nameLength
    case keyType of
        OpenSshKeyTypeRsa -> rsaKey
        OpenSshKeyTypeDsa -> dsaKey
  where
    -- RSA: public exponent first, then the modulus.
    rsaKey = do
        publicExponent <- getInteger
        modulus <- getInteger
        let key = RSA.PublicKey (calculateSize modulus) modulus publicExponent
        return (OpenSshPublicKeyRsa key)
    -- DSA: the values arrive as p, q, g, y while 'DSA.Params' is applied
    -- to p, g, q.
    dsaKey = do
        primeP <- getInteger
        primeQ <- getInteger
        generator <- getInteger
        publicY <- getInteger
        let key = DSA.PublicKey (DSA.Params primeP generator primeQ) publicY
        return (OpenSshPublicKeyDsa key)
-- | Parser for a single OpenSSH public key line of the form
-- @\<type\> \<base64 blob\> [comment]@.
--
-- The leading type name must be one accepted by 'readType'; the key itself
-- is taken from the Base64-decoded blob via 'getOpenSshPublicKey'. The
-- comment is optional and ends at the first whitespace or control
-- character; when absent it is the empty string.
openSshPublicKeyParser :: Parser OpenSshPublicKey
openSshPublicKeyParser = do
    _ <- take typeSize >>= readType
    _ <- space
    encoded <- takeTill isSpace
    blob <- liftEither (Base64.decode encoded)
    mkKey <- liftEither (runGet getOpenSshPublicKey blob)
    fmap mkKey trailingComment
  where
    -- Turn a 'Left' error message into a parse failure.
    liftEither result = either fail return result
    -- One separating space followed by the comment, or no comment at all.
    trailingComment = (space *> takeTill endsComment) <|> return ""
    endsComment c = isSpace c || isControl c
-- | Decode an OpenSSH public key line by running 'openSshPublicKeyParser'
-- over the input; a parse failure is reported as 'Left' with a message.
decodePublic :: ByteString -> Either String OpenSshPublicKey
decodePublic input = parseOnly openSshPublicKeyParser input
-- | Decode a PEM-encoded private key.
--
-- The input must contain exactly one PEM section, named either
-- @RSA PRIVATE KEY@ or @DSA PRIVATE KEY@. Its content is decoded as DER
-- and must be a single sequence of integers in the layout matched below.
-- Every failure is reported as 'Left' with a message.
decodePrivate :: ByteString -> Either String OpenSshPrivateKey
decodePrivate bs = do
    pems <- pemParseBS bs
    case pems of
        []    -> Left "Private key not found"
        [pem] -> dispatch pem
        _     -> Left "Too many private keys"
  where
    -- Choose the decoder from the PEM section name.
    dispatch :: PEM -> Either String OpenSshPrivateKey
    dispatch pem = case pemName pem of
        "RSA PRIVATE KEY" -> parseRSA pem
        "DSA PRIVATE KEY" -> parseDSA pem
        _                 -> Left "Unknown private key type"

    -- DSA layout: version, p, q, g, y (public), x (private).
    parseDSA :: PEM -> Either String OpenSshPrivateKey
    parseDSA pem =
        either (Left . show) fromAsn1 (decodeASN1' DER (pemContent pem))
      where
        fromAsn1 [ Start Sequence
                 , IntVal _
                 , IntVal p
                 , IntVal q
                 , IntVal g
                 , IntVal y
                 , IntVal x
                 , End Sequence
                 ] =
            let params = DSA.Params
                    { DSA.params_p = p
                    , DSA.params_q = q
                    , DSA.params_g = g
                    }
                key = DSA.PrivateKey
                    { DSA.private_params = params
                    , DSA.private_x = x
                    }
            in Right (OpenSshPrivateKeyDsa key y)
        fromAsn1 _ = Left "Invalid ASN1 stream found in PEM."

    -- RSA layout: version, n, e, d, p, q, dP, dQ, qinv.
    parseRSA :: PEM -> Either String OpenSshPrivateKey
    parseRSA pem =
        either (Left . show) fromAsn1 (decodeASN1' DER (pemContent pem))
      where
        fromAsn1 [ Start Sequence
                 , IntVal _
                 , IntVal n
                 , IntVal e
                 , IntVal d
                 , IntVal p
                 , IntVal q
                 , IntVal dP
                 , IntVal dQ
                 , IntVal qinv
                 , End Sequence
                 ] =
            let pub = RSA.PublicKey
                    { RSA.public_size = calculateSize n
                    , RSA.public_n = n
                    , RSA.public_e = e
                    }
                key = RSA.PrivateKey
                    { RSA.private_pub = pub
                    , RSA.private_d = d
                    , RSA.private_p = p
                    , RSA.private_q = q
                    , RSA.private_dP = dP
                    , RSA.private_dQ = dQ
                    , RSA.private_qinv = qinv
                    }
            in Right (OpenSshPrivateKeyRsa key)
        fromAsn1 _ = Left "Invalid ASN1 stream found in PEM."