-- Copyright (C) 2010 John Millikin <jmillikin@gmail.com>
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU General Public License as published by
-- the Free Software Foundation, either version 3 of the License, or
-- any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.

module Network.Protocol.SASL.GNU
        (
        -- * Library Information
          headerVersion
        , libraryVersion
        , checkVersion

        -- * SASL Contexts
        , SASL
        , runSASL
        , setCallback
        , runCallback

        -- * Mechanisms
        , Mechanism (..)
        , clientMechanisms
        , clientSupports
        , clientSuggestMechanism
        , serverMechanisms
        , serverSupports

        -- * SASL Sessions
        , Session
        , runClient
        , runServer
        , mechanismName

        -- ** Session Properties
        , Property (..)
        , setProperty
        , getProperty
        , getPropertyFast

        -- ** Session IO
        , Progress (..)
        , step
        , step64
        , encode
        , decode

        -- ** Error handling
        , Error (..)
        , catch
        , handle
        , try
        , throw

        -- * Bundled codecs
        , toBase64
        , fromBase64
        , md5
        , sha1
        , hmacMD5
        , hmacSHA1
        , nonce
        , random
        ) where

-- Imports {{{

import           Prelude hiding (catch)
import           Data.Maybe (fromMaybe)
import           Control.Applicative (Applicative, pure, (<*>), (<$>))
import qualified Control.Exception as E
import           Control.Monad (ap, when, unless, (<=<))
import           Control.Monad.Loops (unfoldrM)
import           Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Control.Monad.Trans.Reader as R
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Char8 as Char8
import           Data.Char (isDigit)
import           Data.String (IsString, fromString)
import qualified Foreign as F
import qualified Foreign.C as F
import           System.IO.Unsafe (unsafePerformIO)
import qualified Text.ParserCombinators.ReadP as P

-- }}}

-- Library Information {{{

-- | The @(major, minor, patch)@ version of @gsasl.h@ that this module was
-- compiled against.
headerVersion :: (Integer, Integer, Integer)
headerVersion =
        ( toInteger hsgsasl_VERSION_MAJOR
        , toInteger hsgsasl_VERSION_MINOR
        , toInteger hsgsasl_VERSION_PATCH
        )

-- | The @(major, minor, patch)@ version of the @libgsasl@ library that is
-- actually loaded at runtime.
--
-- Throws an 'E.ErrorCall' if the library reports no version string, or one
-- that is not exactly of the form @major.minor.patch@.
libraryVersion :: IO (Integer, Integer, Integer)
libraryVersion = io where
        io = do
                cstr <- gsasl_check_version F.nullPtr
                maybeStr <- F.maybePeek F.peekCString cstr
                -- Fail here, inside IO, instead of returning a tuple whose
                -- 'error' thunk would only explode wherever a caller later
                -- forces it.
                maybe (E.throwIO (E.ErrorCall ("Invalid version string: " ++ show maybeStr)))
                      return
                      (maybeStr >>= parseVersion)

        parseVersion str = case P.readP_to_S parser str of
                [] -> Nothing
                ((parsed, _):_) -> Just parsed

        parser = do
                majorS <- P.munch1 isDigit
                _ <- P.char '.'
                minorS <- P.munch1 isDigit
                _ <- P.char '.'
                patchS <- P.munch1 isDigit
                eof
                -- 'read' cannot fail here: munch1 isDigit only yields
                -- non-empty runs of decimal digits.
                return (read majorS, read minorS, read patchS)

        -- Succeeds only when no input remains.
        eof = do
                s <- P.look
                unless (null s) P.pfail

-- | Whether the @gsasl.h@ header this module was compiled against is
-- compatible with the loaded @libgsasl@. The C helper signals compatibility
-- by returning @1@.
checkVersion :: IO Bool
checkVersion = do
        rc <- hsgsasl_check_version
        return (rc == 1)

-- }}}

-- SASL Contexts {{{

-- Opaque handle to a @Gsasl@ library context, as filled in by @gsasl_init@.
newtype Context = Context (F.Ptr Context)

-- | Computations that run against a single @libgsasl@ context. Use 'runSASL'
-- to execute one.
newtype SASL a = SASL { unSASL :: R.ReaderT Context IO a }

instance Functor SASL where
        fmap f (SASL m) = SASL (fmap f m)

instance Applicative SASL where
        pure x = SASL (pure x)
        (<*>) = ap

instance Monad SASL where
        return x = SASL (return x)
        SASL m >>= f = SASL (m >>= \x -> unSASL (f x))

instance MonadIO SASL where
        liftIO io = SASL (liftIO io)

-- TODO: more instances

-- | Allocate a fresh @libgsasl@ context and run the computation against it.
-- The context, and any callback installed in it, is released afterwards,
-- even if the computation throws.
runSASL :: SASL a -> IO a
runSASL sasl = withContext (R.runReaderT (unSASL sasl))

-- Bracket a @libgsasl@ context:
--
-- * initialise it with @gsasl_init@ (a non-OK code is thrown as an 'Error');
-- * hand it to the action;
-- * on exit, tear it down with @gsasl_done@, then release the Haskell
--   callback hook (if any) that was stored in it.
withContext :: (Context -> IO a) -> IO a
withContext = E.bracket acquire release where
        acquire = F.alloca $ \pCtxt -> do
                rc <- gsasl_init pCtxt
                checkRC rc
                ctxPtr <- F.peek pCtxt
                return (Context ctxPtr)
        release (Context ctxPtr) = do
                -- Fetch the hook first; presumably gsasl_done invalidates
                -- the context it is stored in.
                hookPtr <- gsasl_callback_hook_get ctxPtr
                gsasl_done ctxPtr
                freeCallbackHook hookPtr

-- Fetch the raw @Gsasl@ pointer of the context the computation runs in.
getContext :: SASL (F.Ptr Context)
getContext = SASL (fmap unwrap R.ask) where
        unwrap (Context ptr) = ptr

-- 'E.bracket' lifted into 'SASL'. The acquire step receives the raw context
-- pointer, and the release step runs even if the body throws.
bracketSASL :: (F.Ptr Context -> IO a) -> (a -> IO b) -> (a -> IO c) -> SASL c
bracketSASL before after thing = getContext >>= \ctx ->
        liftIO (E.bracket (before ctx) after thing)

-- }}}

-- Mechanisms {{{

-- | The name of a SASL mechanism, stored as raw bytes.
newtype Mechanism = Mechanism B.ByteString
        deriving (Show, Eq)

-- | Lets mechanism names be written as string literals (for example with
-- @OverloadedStrings@). The conversion uses the 'B.ByteString' 'IsString'
-- instance.
instance IsString Mechanism where
        fromString str = Mechanism (fromString str)

-- | A list of 'Mechanism's supported by the @libgsasl@ client.
clientMechanisms :: SASL [Mechanism]
clientMechanisms = bracketSASL fetchList gsasl_free splitMechListPtr where
        -- gsasl_client_mechlist stores a space-separated list in the
        -- out-parameter. bracketSASL releases it with gsasl_free after it
        -- has been split.
        fetchList ctx = F.alloca $ \pOut ->
                gsasl_client_mechlist ctx pOut >>= checkRC >> F.peek pOut

-- | Whether there is client-side support for a specified 'Mechanism'.
clientSupports :: Mechanism -> SASL Bool
clientSupports (Mechanism name) = getContext >>= \ctx ->
        liftIO (B.useAsCString name (fmap (== 1) . gsasl_client_support_p ctx))

-- | Given a list of 'Mechanism's, suggest which one the client should use,
-- or 'Nothing' if none of them is supported.
clientSuggestMechanism :: [Mechanism] -> SASL (Maybe Mechanism)
clientSuggestMechanism mechs = do
        ctx <- getContext
        -- The candidates are handed to libgsasl as one space-separated string.
        let joined = B.intercalate (Char8.pack " ") (map unMech mechs)
            unMech (Mechanism m) = m
        liftIO $ B.useAsCString joined $ \pJoined -> do
                suggested <- gsasl_client_suggest_mechanism ctx pJoined
                F.maybePeek (fmap Mechanism . B.packCString) suggested

-- | A list of 'Mechanism's supported by the @libgsasl@ server.
serverMechanisms :: SASL [Mechanism]
serverMechanisms = bracketSASL fetchList gsasl_free splitMechListPtr where
        -- gsasl_server_mechlist stores a space-separated list in the
        -- out-parameter. bracketSASL releases it with gsasl_free after it
        -- has been split.
        fetchList ctx = F.alloca $ \pOut ->
                gsasl_server_mechlist ctx pOut >>= checkRC >> F.peek pOut

-- | Whether there is server-side support for a specified 'Mechanism'.
serverSupports :: Mechanism -> SASL Bool
serverSupports (Mechanism name) = getContext >>= \ctx ->
        liftIO (B.useAsCString name (fmap (== 1) . gsasl_server_support_p ctx))

-- Split a NUL-terminated, space-separated mechanism list (as produced by
-- @gsasl_client_mechlist@ / @gsasl_server_mechlist@) into 'Mechanism's.
--
-- The walk state is @(p_0, p_i, i, more)@:
--
-- * @p_0@ is the start of the current word;
-- * @p_i@ is the byte being examined;
-- * @i@ is the number of bytes in the current word so far;
-- * @more@ becomes 'False' once the terminating NUL has been consumed.
--
-- Empty words (leading, trailing or repeated spaces) are skipped. Each word
-- is copied with 'B.packCStringLen', so the caller may free @ptr@ afterwards.
splitMechListPtr :: F.CString -> IO [Mechanism]
splitMechListPtr ptr = unfoldrM step' (ptr, ptr, 0, True) where
        step' (_, _, _, False) = return Nothing
        step' (p_0, p_i, i, _) = F.peek p_i >>= \chr -> let
                p_i' = F.plusPtr p_i 1
                -- Close the current word: emit it unless it is empty, then
                -- resume (or, if continue is False, stop) at the next byte.
                peek continue = if i == 0
                        then step' (p_i', p_i', 0, continue)
                        else do
                                bytes <- B.packCStringLen (p_0, i)
                                return $ Just (Mechanism bytes, (p_i', p_i', 0, continue))
                in case chr of
                        -- NUL: end of the list
                        0x00 -> peek False
                        -- ASCII space: word separator
                        0x20 -> peek True
                        _    -> step' (p_0, p_i', i + 1, True)

-- }}}

-- SASL Sessions {{{

-- Opaque handle to a @Gsasl_session@, as created by @gsasl_client_start@ or
-- @gsasl_server_start@.
newtype SessionCtx = SessionCtx (F.Ptr SessionCtx)

-- | Computations that run inside a single SASL authentication session. Use
-- 'runClient' or 'runServer' to execute one.
newtype Session a = Session { unSession :: R.ReaderT SessionCtx IO a }

instance Functor Session where
        fmap f (Session m) = Session (fmap f m)

instance Applicative Session where
        pure x = Session (pure x)
        (<*>) = ap

instance Monad Session where
        return x = Session (return x)
        Session m >>= f = Session (m >>= \x -> unSession (f x))

instance MonadIO Session where
        liftIO io = Session (liftIO io)

-- Shape shared by @gsasl_client_start@ and @gsasl_server_start@. The
-- arguments are the context, the mechanism name, and an out-parameter that
-- receives the new session handle.
type SessionProc = F.Ptr Context -> F.CString -> F.Ptr (F.Ptr SessionCtx) -> IO F.CInt

-- Start a session with the given mechanism, run the 'Session' computation in
-- it, and finish the session afterwards.
--
-- A non-OK code from starting the session, or an 'Error' thrown (as
-- 'SASLException') by the computation, is returned as 'Left'. Any other
-- exception propagates.
runSession :: SessionProc -> Mechanism -> Session a -> SASL (Either Error a)
runSession start (Mechanism mech) session = bracketSASL open close body where
        open ctx = E.handle asLeft $
                B.useAsCString mech $ \pMech ->
                F.alloca $ \pSessionCtx -> do
                        start ctx pMech pSessionCtx >>= checkRC
                        sctx <- F.peek pSessionCtx
                        return (Right (SessionCtx sctx))

        -- Only a successfully started session needs gsasl_finish.
        close = either (const (return ())) (\(SessionCtx ptr) -> gsasl_finish ptr)

        body = either (return . Left) run

        run sctx = E.handle asLeft (Right <$> R.runReaderT (unSession session) sctx)

        asLeft (SASLException err) = return (Left err)

-- | Run a session using the @libgsasl@ client.
--
-- An 'Error' raised while starting the session, or thrown by the
-- computation, is returned as 'Left'. The session is always finished
-- afterwards.
runClient :: Mechanism -> Session a -> SASL (Either Error a)
runClient mech session = runSession gsasl_client_start mech session

-- | Run a session using the @libgsasl@ server.
--
-- An 'Error' raised while starting the session, or thrown by the
-- computation, is returned as 'Left'. The session is always finished
-- afterwards.
runServer :: Mechanism -> Session a -> SASL (Either Error a)
runServer mech session = runSession gsasl_server_start mech session

-- Fetch the raw @Gsasl_session@ pointer of the session being run.
getSessionContext :: Session (F.Ptr SessionCtx)
getSessionContext = Session (fmap (\(SessionCtx sctx) -> sctx) R.ask)

-- | The name of the session's SASL mechanism.
--
-- Throws 'UnknownMechanism' (catchable with 'catch') if @libgsasl@ reports no
-- mechanism name for the session.
mechanismName :: Session Mechanism
mechanismName = do
        sctx <- getSessionContext
        cstr <- liftIO $ gsasl_mechanism_name sctx
        -- gsasl_mechanism_name may return NULL. Packing a NULL pointer would
        -- crash the process, so report it as an 'Error' instead. The string
        -- is copied and not freed, matching the library-owned result.
        if cstr == F.nullPtr
                then throw UnknownMechanism
                else liftIO $ Mechanism <$> B.packCString cstr

-- 'E.bracket' lifted into 'Session'. The acquire step receives the raw
-- session pointer, and the release step runs even if the body throws.
bracketSession :: (F.Ptr SessionCtx -> IO a) -> (a -> IO b) -> (a -> IO c) -> Session c
bracketSession before after thing = getSessionContext >>= \sctx ->
        liftIO (E.bracket (before sctx) after thing)

-- }}}

-- Error handling {{{

-- | Errors reported by @libgsasl@. Each constructor maps to one numeric
-- @libgsasl@ return code. 'show' yields the library's own human-readable
-- description of the error (via @gsasl_strerror@), not the constructor name.
data Error
        = UnknownMechanism
        | MechanismCalledTooManyTimes
        | MallocError
        | Base64Error
        | CryptoError
        | SASLPrepError
        | MechanismParseError
        | AuthenticationError
        | IntegrityError
        | NoClientCode
        | NoServerCode
        | NoCallback
        | NoAnonymousToken
        | NoAuthID
        | NoAuthzID
        | NoPassword
        | NoPasscode
        | NoPIN
        | NoService
        | NoHostname

        -- GSSAPI-related codes
        | GSSAPI_ReleaseBufferError
        | GSSAPI_ImportNameError
        | GSSAPI_InitSecContextError
        | GSSAPI_AcceptSecContextError
        | GSSAPI_UnwrapError
        | GSSAPI_WrapError
        | GSSAPI_AquireCredError
        | GSSAPI_DisplayNameError
        | GSSAPI_UnsupportedProtectionError
        | GSSAPI_EncapsulateTokenError
        | GSSAPI_DecapsulateTokenError
        | GSSAPI_InquireMechForSASLNameError
        | GSSAPI_TestOIDSetMemberError
        | GSSAPI_ReleaseOIDSetError

        -- Kerberos V5-related codes
        | KerberosV5_InitError
        | KerberosV5_InternalError

        -- SecurID-related codes
        | SecurID_ServerNeedAdditionalPasscode
        | SecurID_ServerNeedNewPIN

-- 'show' renders the libgsasl description of the error rather than the
-- constructor name.
instance Show Error where
        show err = strError err

-- | The human-readable explanation @libgsasl@ gives for an 'Error'.
--
-- This string can be used to output a diagnostic message to the user.
strError :: Error -> String
strError err = unsafePerformIO $ do
        -- NOTE(review): unsafePerformIO assumes gsasl_strerror returns a
        -- constant string for a given code, so the lookup is pure. Confirm.
        cstr <- gsasl_strerror (cFromError err)
        F.peekCString cstr

-- Internal exception carrying an 'Error' through Haskell code. It is turned
-- back into an 'Either' by runClient/runServer/'try', or handled with 'catch'.
newtype SASLException = SASLException Error
        deriving Show

instance E.Exception SASLException

-- The numeric @libgsasl@ return code for each 'Error' (inverse of 'cToError').
cFromError :: Error -> F.CInt
cFromError UnknownMechanism = 2
cFromError MechanismCalledTooManyTimes = 3
cFromError MallocError = 7
cFromError Base64Error = 8
cFromError CryptoError = 9
cFromError SASLPrepError = 29
cFromError MechanismParseError = 30
cFromError AuthenticationError = 31
cFromError IntegrityError = 33
cFromError NoClientCode = 35
cFromError NoServerCode = 36
cFromError NoCallback = 51
cFromError NoAnonymousToken = 52
cFromError NoAuthID = 53
cFromError NoAuthzID = 54
cFromError NoPassword = 55
cFromError NoPasscode = 56
cFromError NoPIN = 57
cFromError NoService = 58
cFromError NoHostname = 59
cFromError GSSAPI_ReleaseBufferError = 37
cFromError GSSAPI_ImportNameError = 38
cFromError GSSAPI_InitSecContextError = 39
cFromError GSSAPI_AcceptSecContextError = 40
cFromError GSSAPI_UnwrapError = 41
cFromError GSSAPI_WrapError = 42
cFromError GSSAPI_AquireCredError = 43
cFromError GSSAPI_DisplayNameError = 44
cFromError GSSAPI_UnsupportedProtectionError = 45
cFromError GSSAPI_EncapsulateTokenError = 60
cFromError GSSAPI_DecapsulateTokenError = 61
cFromError GSSAPI_InquireMechForSASLNameError = 62
cFromError GSSAPI_TestOIDSetMemberError = 63
cFromError GSSAPI_ReleaseOIDSetError = 64
cFromError KerberosV5_InitError = 46
cFromError KerberosV5_InternalError = 47
cFromError SecurID_ServerNeedAdditionalPasscode = 48
cFromError SecurID_ServerNeedNewPIN = 49

-- Decode a non-OK @libgsasl@ return code into an 'Error' (inverse of
-- 'cFromError'). Codes missing from the table abort with 'error'.
cToError :: F.CInt -> Error
cToError x = maybe unknown id (lookup x codeTable) where
        unknown = error $ "Unknown GNU SASL return code: " ++ show x
        codeTable =
                [ (2, UnknownMechanism)
                , (3, MechanismCalledTooManyTimes)
                , (7, MallocError)
                , (8, Base64Error)
                , (9, CryptoError)
                , (29, SASLPrepError)
                , (30, MechanismParseError)
                , (31, AuthenticationError)
                , (33, IntegrityError)
                , (35, NoClientCode)
                , (36, NoServerCode)
                , (51, NoCallback)
                , (52, NoAnonymousToken)
                , (53, NoAuthID)
                , (54, NoAuthzID)
                , (55, NoPassword)
                , (56, NoPasscode)
                , (57, NoPIN)
                , (58, NoService)
                , (59, NoHostname)
                , (37, GSSAPI_ReleaseBufferError)
                , (38, GSSAPI_ImportNameError)
                , (39, GSSAPI_InitSecContextError)
                , (40, GSSAPI_AcceptSecContextError)
                , (41, GSSAPI_UnwrapError)
                , (42, GSSAPI_WrapError)
                , (43, GSSAPI_AquireCredError)
                , (44, GSSAPI_DisplayNameError)
                , (45, GSSAPI_UnsupportedProtectionError)
                , (60, GSSAPI_EncapsulateTokenError)
                , (61, GSSAPI_DecapsulateTokenError)
                , (62, GSSAPI_InquireMechForSASLNameError)
                , (63, GSSAPI_TestOIDSetMemberError)
                , (64, GSSAPI_ReleaseOIDSetError)
                , (46, KerberosV5_InitError)
                , (47, KerberosV5_InternalError)
                , (48, SecurID_ServerNeedAdditionalPasscode)
                , (49, SecurID_ServerNeedNewPIN)
                ]

-- | Abort the current 'Session' computation with the given 'Error'.
throw :: Error -> Session a
throw err = liftIO (E.throwIO (SASLException err))

-- | Run a 'Session' computation. If it fails with an 'Error' (raised by
-- 'throw' or by a failing @libgsasl@ call), run the handler in the same
-- session instead.
catch :: Session a -> (Error -> Session a) -> Session a
catch m f = Session $ R.ReaderT $ \sctx ->
        E.catch (R.runReaderT (unSession m) sctx)
                (\(SASLException err) -> R.runReaderT (unSession (f err)) sctx)

-- | 'catch' with its arguments flipped.
handle :: (Error -> Session a) -> Session a -> Session a
handle f m = catch m f

-- | Run a 'Session' computation and return any 'Error' it fails with as
-- 'Left', instead of propagating it.
try :: Session a -> Session (Either Error a)
try m = catch (Right <$> m) (\err -> return (Left err))

-- }}}

-- Session Properties {{{

-- | Values a SASL session can store or request, plus the @Validate*@
-- requests. All of these are passed to the session callback, which receives
-- the 'Property' it is expected to handle.
data Property
        = PropertyAuthID
        | PropertyAuthzID
        | PropertyPassword
        | PropertyAnonymousToken
        | PropertyService
        | PropertyHostname
        | PropertyGSSAPIDisplayName
        | PropertyPasscode
        | PropertySuggestedPIN
        | PropertyPIN
        | PropertyRealm
        | PropertyDigestMD5HashedPassword
        | PropertyQOPS
        | PropertyQOP
        | PropertyScramIter
        | PropertyScramSalt
        | PropertyScramSaltedPassword

        | ValidateSimple
        | ValidateExternal
        | ValidateAnonymous
        | ValidateGSSAPI
        | ValidateSecurID
        deriving (Show, Eq)

-- The numeric @libgsasl@ code for each 'Property' (inverse of 'cToProperty').
cFromProperty :: Property -> F.CInt
cFromProperty PropertyAuthID = 1
cFromProperty PropertyAuthzID = 2
cFromProperty PropertyPassword = 3
cFromProperty PropertyAnonymousToken = 4
cFromProperty PropertyService = 5
cFromProperty PropertyHostname = 6
cFromProperty PropertyGSSAPIDisplayName = 7
cFromProperty PropertyPasscode = 8
cFromProperty PropertySuggestedPIN = 9
cFromProperty PropertyPIN = 10
cFromProperty PropertyRealm = 11
cFromProperty PropertyDigestMD5HashedPassword = 12
cFromProperty PropertyQOPS = 13
cFromProperty PropertyQOP = 14
cFromProperty PropertyScramIter = 15
cFromProperty PropertyScramSalt = 16
cFromProperty PropertyScramSaltedPassword = 17
cFromProperty ValidateSimple = 500
cFromProperty ValidateExternal = 501
cFromProperty ValidateAnonymous = 502
cFromProperty ValidateGSSAPI = 503
cFromProperty ValidateSecurID = 504

-- Decode a @libgsasl@ property code. Codes missing from the table abort
-- with 'error'.
cToProperty :: F.CInt -> Property
cToProperty x = maybe unknown id (lookup x propertyCodes) where
        unknown = error $ "Unknown GNU SASL property code: " ++ show x
        propertyCodes =
                [ (1, PropertyAuthID)
                , (2, PropertyAuthzID)
                , (3, PropertyPassword)
                , (4, PropertyAnonymousToken)
                , (5, PropertyService)
                , (6, PropertyHostname)
                , (7, PropertyGSSAPIDisplayName)
                , (8, PropertyPasscode)
                , (9, PropertySuggestedPIN)
                , (10, PropertyPIN)
                , (11, PropertyRealm)
                , (12, PropertyDigestMD5HashedPassword)
                , (13, PropertyQOPS)
                , (14, PropertyQOP)
                , (15, PropertyScramIter)
                , (16, PropertyScramSalt)
                , (17, PropertyScramSaltedPassword)
                , (500, ValidateSimple)
                , (501, ValidateExternal)
                , (502, ValidateAnonymous)
                , (503, ValidateGSSAPI)
                , (504, ValidateSecurID)
                ]

-- | Store a value in the session for the given property.
--
-- The bytes are copied into a temporary @NUL@-terminated C string before
-- being handed to @libgsasl@, so callers must not terminate them
-- themselves. Because the value travels as a C string, an embedded @NUL@
-- byte truncates it at that point.
setProperty :: Property -> B.ByteString -> Session ()
setProperty prop bytes = do
        sctx <- getSessionContext
        liftIO $ B.useAsCString bytes $ \cstr ->
                gsasl_property_set sctx (cFromProperty prop) cstr

-- | Retrieve the data stored in the session for the given property,
-- possibly invoking the current callback to obtain a value.
--
-- Returns 'Nothing' when no value is available. In that case, if the
-- callback raised a Haskell exception, that exception is re-thrown here.
getProperty :: Property -> Session (Maybe B.ByteString)
getProperty prop = getSessionContext >>= \sctx -> liftIO $ do
        cstr <- gsasl_property_get sctx (cFromProperty prop)
        if cstr == F.nullPtr
                then checkCallbackException sctx >> return Nothing
                else fmap Just (B.packCString cstr)

-- | Retrieve the data stored in the session for the given property,
-- without invoking the current callback. Returns 'Nothing' when nothing is
-- stored.
getPropertyFast :: Property -> Session (Maybe B.ByteString)
getPropertyFast prop = do
        sctx <- getSessionContext
        cstr <- liftIO (gsasl_property_fast sctx (cFromProperty prop))
        liftIO (F.maybePeek B.packCString cstr)

-- }}}

-- Callbacks {{{

-- C-level shape of a @libgsasl@ callback. The arguments are the context, the
-- session (NULL for session-less calls, see 'callbackImpl') and the
-- property code; the result is a return code.
type CallbackFn = F.Ptr Context -> F.Ptr SessionCtx -> F.CInt -> IO F.CInt
-- Stored, behind a 'F.StablePtr', as the context's callback hook. It keeps
-- the wrapped C function pointer (so it can be freed later) and the Haskell
-- callback itself (so 'runCallback' can invoke it directly).
data CallbackHook = CallbackHook (F.FunPtr CallbackFn) (Property -> Session Progress)

-- Wrap a Haskell callback for @libgsasl@. The result pairs:
--
-- * a pointer to a 'CallbackHook' (a 'F.StablePtr' cast to a plain
--   pointer), for @gsasl_callback_hook_set@;
-- * the C function pointer to register with @gsasl_callback_set@.
--
-- Both are released together by 'freeCallbackHook'. If setup fails partway,
-- whatever was already allocated is freed.
newCallbackHook :: (Property -> Session Progress) -> IO (F.Ptr CallbackHook, F.FunPtr CallbackFn)
newCallbackHook cb =
        E.bracketOnError (wrapCallbackImpl (callbackImpl cb)) F.freeHaskellFunPtr $ \funPtr ->
        E.bracketOnError (F.newStablePtr (CallbackHook funPtr cb)) F.freeStablePtr $ \stablePtr ->
        return (F.castPtr (F.castStablePtrToPtr stablePtr), funPtr)

-- Release a hook created by 'newCallbackHook': free its stable pointer and
-- its wrapped C function pointer. A NULL pointer (no callback installed) is
-- ignored.
freeCallbackHook :: F.Ptr CallbackHook -> IO ()
freeCallbackHook ptr
        | ptr == F.nullPtr = return ()
        | otherwise = do
                let stablePtr = F.castPtrToStablePtr (F.castPtr ptr)
                CallbackHook funPtr _ <- F.deRefStablePtr stablePtr
                F.freeStablePtr stablePtr
                F.freeHaskellFunPtr funPtr

-- The C-callable body behind every callback registered with @libgsasl@.
--
-- * Session request: the user's Haskell callback runs in that session, and
--   its 'Progress' becomes the C return code.
-- * Thrown 'Error': translated to its @libgsasl@ code.
-- * Any other Haskell exception: stashed in the session hook, to be
--   re-raised later by 'checkCallbackException', and @-1@ is returned.
-- * Session-less request (@libgsasl@ passes a NULL session): not supported
--   by this binding, so it is answered with the 'NoCallback' code.
callbackImpl :: (Property -> Session Progress) -> CallbackFn
callbackImpl cb _ sctx cProp = let
        -- Global (session-less) callbacks are not implemented. Report that no
        -- callback is available rather than raising an exception, which
        -- would otherwise reach onException with a NULL session.
        globalIO = return (cFromError NoCallback)

        sessionIO = do
                let session = cb $ cToProperty cProp
                cFromProgress <$> R.runReaderT (unSession session) (SessionCtx sctx)

        onError :: SASLException -> IO F.CInt
        onError (SASLException err) = return $ cFromError err

        onException :: E.SomeException -> IO F.CInt
        onException exc
                -- Never hand a NULL session to gsasl_session_hook_set.
                | sctx == F.nullPtr = return (-1)
                | otherwise = do
                        -- A bit ugly; session hooks aren't used anywhere else in
                        -- the binding, so the exception is stashed here.
                        stablePtr <- F.newStablePtr exc
                        gsasl_session_hook_set sctx $ F.castStablePtrToPtr stablePtr

                        -- standard libgsasl return codes are all >= 0, so using -1
                        -- provides an easy way to determine later whether the
                        -- exception came from Haskell code.
                        return (-1)

        catchErrors io = E.catches io [E.Handler onError, E.Handler onException]

        in catchErrors $ if sctx == F.nullPtr then globalIO else sessionIO

-- GHC "wrapper" import: turns a Haskell 'CallbackFn' into a C function
-- pointer. Each pointer created here must eventually be released with
-- 'F.freeHaskellFunPtr' (done by 'freeCallbackHook').
foreign import ccall "wrapper"
        wrapCallbackImpl :: CallbackFn -> IO (F.FunPtr CallbackFn)

-- Check whether a Haskell callback running in this session stashed an
-- exception (see 'callbackImpl'). If so, clear the stash, free it, and
-- re-throw the exception; otherwise do nothing.
checkCallbackException :: F.Ptr SessionCtx -> IO ()
checkCallbackException sctx = do
        hook <- gsasl_session_hook_get sctx
        when (hook /= F.nullPtr) $ do
                let stable = F.castPtrToStablePtr hook
                exc <- F.deRefStablePtr stable
                -- Clear the hook before freeing the StablePtr. Otherwise a
                -- later check in the same session (e.g. after the caller
                -- handled this exception and kept stepping) would
                -- dereference and free the stale pointer again.
                gsasl_session_hook_set sctx F.nullPtr
                F.freeStablePtr stable
                E.throwIO (exc :: E.SomeException)

-- | Set the current SASL callback. The callback will be used by mechanisms
-- to discover various parameters, such as usernames and passwords.
--
-- Any previously installed callback is released once the new one is in
-- place.
setCallback :: (Property -> Session Progress) -> SASL ()
setCallback cb = do
        ctx <- getContext
        liftIO $ do
                oldHook <- gsasl_callback_hook_get ctx
                -- Build and install the new hook before freeing the old one.
                -- Freeing first would leave the context pointing at a freed
                -- hook if newCallbackHook threw, and withContext would then
                -- free that hook a second time.
                (hook, cbPtr) <- newCallbackHook cb
                gsasl_callback_hook_set ctx hook
                gsasl_callback_set ctx cbPtr
                freeCallbackHook oldHook

-- | Run the current callback; the property indicates what action the
-- callback is expected to perform.
--
-- Throws 'NoCallback' if no callback has been installed with 'setCallback'.
runCallback :: Property -> Session Progress
runCallback prop = do
        -- This is a bit evil: the first field of Gsasl_session is a *pointer*
        -- to its Gsasl context, so the context is read out of the session
        -- (assuming the layout never changes). Merely casting the session
        -- pointer would treat the session struct itself as a context and read
        -- an unrelated field as the callback hook.
        sctx <- getSessionContext
        ctx <- liftIO $ F.peek (F.castPtr sctx :: F.Ptr (F.Ptr Context))
        hookPtr <- liftIO $ gsasl_callback_hook_get ctx
        when (hookPtr == F.nullPtr) $ throw NoCallback
        hook <- liftIO $ F.deRefStablePtr $ F.castPtrToStablePtr hookPtr
        let (CallbackHook _ cb) = hook
        cb prop

-- }}}

-- Session IO {{{

-- | The outcome of an authentication 'step'.
data Progress = Complete | NeedsMore
        deriving (Show, Eq)

-- The @libgsasl@ return code for each 'Progress' value: 0 for 'Complete' and
-- 1 for 'NeedsMore' (the same codes 'checkStepRC' decodes).
cFromProgress :: Progress -> F.CInt
cFromProgress Complete = 0
cFromProgress NeedsMore = 1

-- | Perform one step of SASL authentication. This reads data from the other
-- end, processes it (potentially running the callback), and returns data
-- to be sent back.
--
-- Also returns 'NeedsMore' if authentication is not yet complete, or
-- 'Complete' once it is. Any other @libgsasl@ result is thrown as an
-- 'Error'. If a Haskell callback threw during a step that did not return 0,
-- that exception is re-thrown instead.
step :: B.ByteString -> Session (B.ByteString, Progress)
step input = bracketSession runStep releaseOutput copyOutput where
        -- Call gsasl_step. On success the output buffer it filled in is
        -- returned, and releaseOutput frees it with gsasl_free.
        runStep sctx =
                B.unsafeUseAsCStringLen input $ \(inPtr, inLen) ->
                F.alloca $ \outPtrPtr ->
                F.alloca $ \outLenPtr -> do
                        rc <- gsasl_step sctx inPtr (fromIntegral inLen) outPtrPtr outLenPtr
                        when (rc /= 0) (checkCallbackException sctx)
                        progress <- checkStepRC rc
                        outLen <- F.peek outLenPtr
                        outPtr <- F.peek outPtrPtr
                        return (outPtr, outLen, progress)

        releaseOutput (outPtr, _, _) = gsasl_free outPtr

        copyOutput (outPtr, outLen, progress) = do
                bytes <- B.packCStringLen (outPtr, fromIntegral outLen)
                return (bytes, progress)

-- | Like 'step', but the input is base64-decoded and the output
-- base64-encoded (via @gsasl_step64@).
--
-- The input is passed as a @NUL@-terminated C string. Errors and callback
-- exceptions are reported as in 'step'.
step64 :: B.ByteString -> Session (B.ByteString, Progress)
step64 input = bracketSession runStep releaseOutput copyOutput where
        -- Call gsasl_step64. On success the output string it filled in is
        -- returned, and releaseOutput frees it with gsasl_free.
        runStep sctx =
                B.useAsCString input $ \inPtr ->
                F.alloca $ \outPtrPtr -> do
                        rc <- gsasl_step64 sctx inPtr outPtrPtr
                        when (rc /= 0) (checkCallbackException sctx)
                        progress <- checkStepRC rc
                        outPtr <- F.peek outPtrPtr
                        return (outPtr, progress)

        releaseOutput (outPtr, _) = gsasl_free outPtr

        copyOutput (outPtr, progress) = do
                bytes <- B.packCString outPtr
                return (bytes, progress)

-- Interpret the return code of gsasl_step / gsasl_step64: 0 is 'Complete',
-- 1 is 'NeedsMore', and anything else is thrown as an 'Error'.
checkStepRC :: F.CInt -> IO Progress
checkStepRC 0 = return Complete
checkStepRC 1 = return NeedsMore
checkStepRC x = E.throwIO (SASLException (cToError x))

-- Shared implementation of 'encode' and 'decode'.
--
-- The libgsasl function receives the session, the input bytes and their
-- length, and fills in an output buffer and its length. The buffer is
-- copied into a 'B.ByteString' and then released with gsasl_free. A non-OK
-- return code is thrown as an 'Error'; if a callback stashed an exception,
-- that exception is re-thrown first.
encodeDecodeHelper :: (F.Storable a, Integral a, Num t) =>
           (F.Ptr SessionCtx -> F.Ptr F.CChar -> t -> F.Ptr (F.Ptr F.CChar) -> F.Ptr a -> IO F.CInt)
        -> B.ByteString
        -> Session B.ByteString
encodeDecodeHelper f input = do
        sctx <- getSessionContext
        liftIO $
                B.unsafeUseAsCStringLen input $ \(cstr, cstrLen) ->
                F.alloca $ \pOutput ->
                F.alloca $ \pOutputLen -> do
                        rc <- f sctx cstr (fromIntegral cstrLen) pOutput pOutputLen
                        when (rc /= 0) $ checkCallbackException sctx
                        checkRC rc
                        output <- F.peek pOutput
                        outputLen <- fromIntegral <$> F.peek pOutputLen
                        -- Free the libgsasl buffer even if copying it fails
                        -- (e.g. heap exhaustion or an asynchronous exception).
                        B.packCStringLen (output, outputLen) `E.finally` gsasl_free output

-- | Encode data according to the negotiated SASL mechanism. This might mean
-- the data is integrity or privacy protected.
--
-- A failure reported by @libgsasl@ is thrown as an 'Error'.
encode :: B.ByteString -> Session B.ByteString
encode input = encodeDecodeHelper gsasl_encode input

-- | Decode data according to the negotiated SASL mechanism. This might mean
-- the data is integrity or privacy protected.
--
-- A failure reported by @libgsasl@ is thrown as an 'Error'.
decode :: B.ByteString -> Session B.ByteString
decode input = encodeDecodeHelper gsasl_decode input

-- }}}

-- Bundled codecs {{{

-- Shared implementation of 'toBase64' and 'fromBase64'.
--
-- The libgsasl codec fills in a newly allocated output buffer and its
-- length. The result is copied into a 'B.ByteString', and the buffer is
-- always released with gsasl_free. A non-OK return code is thrown (from
-- pure code) as an 'Error' when the result is forced.
base64Helper :: (F.Storable a, Integral a, Num t) =>
           (F.Ptr F.CChar -> t -> F.Ptr (F.Ptr F.CChar) -> F.Ptr a -> IO F.CInt)
        -> B.ByteString
        -> B.ByteString
base64Helper f input = unsafePerformIO $
        B.unsafeUseAsCStringLen input $ \(pIn, inLen) ->
        F.alloca $ \pOut ->
        F.alloca $ \pOutLen -> do
        f pIn (fromIntegral inLen) pOut pOutLen >>= checkRC
        outLen <- F.peek pOutLen
        outPtr <- F.peek pOut
        -- The caller owns libgsasl's output buffer: copy it, then free it.
        B.packCStringLen (outPtr, fromIntegral outLen) `E.finally` gsasl_free outPtr

-- | Encode bytes as base64 using @libgsasl@'s codec.
toBase64 :: B.ByteString -> B.ByteString
toBase64 input = base64Helper gsasl_base64_to input

-- | Decode base64 using @libgsasl@'s codec.
--
-- A non-OK result from the library (presumably including malformed input)
-- is thrown as an exception when the result is forced.
fromBase64 :: B.ByteString -> B.ByteString
fromBase64 input = base64Helper gsasl_base64_from input

-- Run a libgsasl digest function. On success it stores, in its
-- out-parameter, a pointer to a newly allocated buffer of @len@ bytes. The
-- digest is copied into a 'B.ByteString' and the buffer is always released
-- with gsasl_free. A non-OK return code is thrown as an 'Error'.
allocatedDigest :: Int -> (F.Ptr F.CString -> IO F.CInt) -> IO B.ByteString
allocatedDigest len compute = F.alloca $ \pOut -> do
        compute pOut >>= checkRC
        -- Read the buffer libgsasl allocated. Pre-filling the out-parameter
        -- is pointless: the library overwrites it.
        outBuf <- F.peek pOut
        B.packCStringLen (outBuf, len) `E.finally` gsasl_free outBuf

-- | The 16-byte MD5 digest of the input.
md5 :: B.ByteString -> B.ByteString
md5 input = unsafePerformIO $
        B.unsafeUseAsCStringLen input $ \(pIn, inLen) ->
        allocatedDigest 16 (gsasl_md5 pIn (fromIntegral inLen))

-- | The 20-byte SHA-1 digest of the input.
sha1 :: B.ByteString -> B.ByteString
sha1 input = unsafePerformIO $
        B.unsafeUseAsCStringLen input $ \(pIn, inLen) ->
        allocatedDigest 20 (gsasl_sha1 pIn (fromIntegral inLen))

-- | The 16-byte HMAC-MD5 of the input data under the given key.
hmacMD5 :: B.ByteString -- ^ Key
        -> B.ByteString -- ^ Input data
        -> B.ByteString
hmacMD5 key input = unsafePerformIO $
        B.unsafeUseAsCStringLen key $ \(pKey, keyLen) ->
        B.unsafeUseAsCStringLen input $ \(pIn, inLen) ->
        allocatedDigest 16 (gsasl_hmac_md5 pKey (fromIntegral keyLen) pIn (fromIntegral inLen))

-- | The 20-byte HMAC-SHA1 of the input data under the given key.
hmacSHA1 :: B.ByteString -- ^ Key
         -> B.ByteString -- ^ Input data
         -> B.ByteString
hmacSHA1 key input = unsafePerformIO $
        B.unsafeUseAsCStringLen key $ \(pKey, keyLen) ->
        B.unsafeUseAsCStringLen input $ \(pIn, inLen) ->
        allocatedDigest 20 (gsasl_hmac_sha1 pKey (fromIntegral keyLen) pIn (fromIntegral inLen))

-- | Returns unpredictable data of a given size (in bytes).
--
-- A non-OK @libgsasl@ return code is thrown as an 'Error'.
nonce :: Integer -> IO B.ByteString
nonce size = F.allocaBytes len fill where
        len = fromInteger size
        fill buf = do
                gsasl_nonce buf (fromIntegral len) >>= checkRC
                B.packCStringLen (buf, len)

-- | Returns cryptographically strong random data of a given size (in
-- bytes).
--
-- A non-OK @libgsasl@ return code is thrown as an 'Error'.
random :: Integer -> IO B.ByteString
random size = F.allocaBytes len fill where
        len = fromInteger size
        fill buf = do
                gsasl_random buf (fromIntegral len) >>= checkRC
                B.packCStringLen (buf, len)


-- }}}

-- Miscellaneous {{{

-- Succeed on @libgsasl@'s 0 (OK) code; throw the corresponding 'Error' for
-- any other code.
checkRC :: F.CInt -> IO ()
checkRC rc
        | rc == 0 = return ()
        | otherwise = E.throwIO (SASLException (cToError rc))

-- }}}

-- FFI imports {{{

-- The hsgsasl_* symbols are not libgsasl functions. They presumably come
-- from this package's C shim, exposing gsasl.h version information as
-- values. TODO confirm against the C source.
foreign import ccall "hsgsasl_VERSION_MAJOR"
        hsgsasl_VERSION_MAJOR :: F.CInt

foreign import ccall "hsgsasl_VERSION_MINOR"
        hsgsasl_VERSION_MINOR :: F.CInt

foreign import ccall "hsgsasl_VERSION_PATCH"
        hsgsasl_VERSION_PATCH :: F.CInt

foreign import ccall "hsgsasl_check_version"
        hsgsasl_check_version :: IO F.CInt

-- Context lifecycle and callback registration.
foreign import ccall "gsasl.h gsasl_init"
        gsasl_init :: F.Ptr (F.Ptr Context) -> IO F.CInt

foreign import ccall "gsasl.h gsasl_done"
        gsasl_done :: F.Ptr Context -> IO ()

foreign import ccall "gsasl.h gsasl_check_version"
        gsasl_check_version :: F.CString -> IO F.CString

foreign import ccall "gsasl.h gsasl_callback_set"
        gsasl_callback_set :: F.Ptr Context -> F.FunPtr CallbackFn -> IO ()

foreign import ccall "gsasl.h gsasl_callback_hook_get"
        gsasl_callback_hook_get :: F.Ptr Context -> IO (F.Ptr a)

foreign import ccall "gsasl.h gsasl_callback_hook_set"
        gsasl_callback_hook_set :: F.Ptr Context -> F.Ptr a -> IO ()

foreign import ccall "gsasl.h gsasl_session_hook_get"
        gsasl_session_hook_get :: F.Ptr SessionCtx -> IO (F.Ptr a)

foreign import ccall "gsasl.h gsasl_session_hook_set"
        gsasl_session_hook_set :: F.Ptr SessionCtx -> F.Ptr a -> IO ()

foreign import ccall "gsasl.h gsasl_property_set"
        gsasl_property_set :: F.Ptr SessionCtx -> F.CInt -> F.CString -> IO ()

-- Explicitly 'safe': this may invoke the registered callback, which
-- re-enters Haskell (see getProperty).
foreign import ccall safe "gsasl.h gsasl_property_get"
        gsasl_property_get :: F.Ptr SessionCtx -> F.CInt -> IO F.CString

foreign import ccall "gsasl.h gsasl_property_fast"
        gsasl_property_fast :: F.Ptr SessionCtx -> F.CInt -> IO F.CString

foreign import ccall "gsasl.h gsasl_client_mechlist"
        gsasl_client_mechlist :: F.Ptr Context -> F.Ptr F.CString -> IO F.CInt

foreign import ccall "gsasl.h gsasl_client_support_p"
        gsasl_client_support_p :: F.Ptr Context -> F.CString -> IO F.CInt

foreign import ccall "gsasl.h gsasl_client_suggest_mechanism"
        gsasl_client_suggest_mechanism :: F.Ptr Context -> F.CString -> IO F.CString

foreign import ccall "gsasl.h gsasl_server_mechlist"
        gsasl_server_mechlist :: F.Ptr Context -> F.Ptr F.CString -> IO F.CInt

foreign import ccall "gsasl.h gsasl_server_support_p"
        gsasl_server_support_p :: F.Ptr Context -> F.CString -> IO F.CInt

-- Session lifecycle and data exchange. These are marked 'safe',
-- presumably because mechanisms may call back into Haskell while they run.
foreign import ccall safe "gsasl.h gsasl_client_start"
        gsasl_client_start :: SessionProc

foreign import ccall safe "gsasl.h gsasl_server_start"
        gsasl_server_start :: SessionProc

foreign import ccall safe "gsasl.h gsasl_step"
        gsasl_step :: F.Ptr SessionCtx -> F.CString -> F.CSize -> F.Ptr F.CString -> F.Ptr F.CSize -> IO F.CInt

foreign import ccall safe "gsasl.h gsasl_step64"
        gsasl_step64 :: F.Ptr SessionCtx -> F.CString -> F.Ptr F.CString -> IO F.CInt

foreign import ccall safe "gsasl.h gsasl_finish"
        gsasl_finish :: F.Ptr SessionCtx -> IO ()

foreign import ccall safe "gsasl.h gsasl_encode"
        gsasl_encode :: F.Ptr SessionCtx -> F.CString -> F.CSize -> F.Ptr F.CString -> F.Ptr F.CSize -> IO F.CInt

foreign import ccall safe "gsasl.h gsasl_decode"
        gsasl_decode :: F.Ptr SessionCtx -> F.CString -> F.CSize -> F.Ptr F.CString -> F.Ptr F.CSize -> IO F.CInt

foreign import ccall "gsasl.h gsasl_mechanism_name"
        gsasl_mechanism_name :: F.Ptr SessionCtx -> IO F.CString

foreign import ccall "gsasl.h gsasl_strerror"
        gsasl_strerror :: F.CInt -> IO F.CString

-- Bundled codecs. Per the libgsasl manual, the base64 and digest functions
-- store a newly allocated buffer in their F.Ptr F.CString out-parameter,
-- which the caller must release (gsasl_free).
foreign import ccall "gsasl.h gsasl_base64_to"
        gsasl_base64_to :: F.CString -> F.CSize -> F.Ptr F.CString -> F.Ptr F.CSize -> IO F.CInt

foreign import ccall "gsasl.h gsasl_base64_from"
        gsasl_base64_from :: F.CString -> F.CSize -> F.Ptr F.CString -> F.Ptr F.CSize -> IO F.CInt

foreign import ccall "gsasl.h gsasl_md5"
        gsasl_md5 :: F.CString -> F.CSize -> F.Ptr F.CString -> IO F.CInt

foreign import ccall "gsasl.h gsasl_sha1"
        gsasl_sha1 :: F.CString -> F.CSize -> F.Ptr F.CString -> IO F.CInt

foreign import ccall "gsasl.h gsasl_hmac_md5"
        gsasl_hmac_md5 :: F.CString -> F.CSize -> F.CString -> F.CSize -> F.Ptr F.CString -> IO F.CInt

foreign import ccall "gsasl.h gsasl_hmac_sha1"
        gsasl_hmac_sha1 :: F.CString -> F.CSize -> F.CString -> F.CSize -> F.Ptr F.CString -> IO F.CInt

-- These two fill a caller-provided buffer of the given size.
foreign import ccall "gsasl.h gsasl_nonce"
        gsasl_nonce :: F.CString -> F.CSize -> IO F.CInt

foreign import ccall "gsasl.h gsasl_random"
        gsasl_random :: F.CString -> F.CSize -> IO F.CInt

-- Deallocator for memory handed out by libgsasl.
foreign import ccall "gsasl.h gsasl_free"
        gsasl_free :: F.Ptr a -> IO ()

-- }}}