{-# LANGUAGE CApiFFI #-}
{-# LANGUAGE ForeignFunctionInterface #-}

module Crypto.Secp256k1.Recovery
  ( recoverPubKey,
  )
where

import qualified Crypto.Secp256k1.Internal.BaseOps as Secp
import qualified Crypto.Secp256k1.Internal.Context as Secp
import qualified Crypto.Secp256k1.Internal.ForeignTypes as Secp
import Data.ByteString (packCStringLen)
import qualified Data.ByteString as BS
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Foreign
import Foreign.C.Types

-- | Parse a compact ECDSA signature (64 bytes + recovery id) into an opaque
-- recoverable-signature object.
--
-- Raw binding to libsecp256k1's
-- @secp256k1_ecdsa_recoverable_signature_parse_compact@.
--
-- Returns: 1 when the signature could be parsed, 0 otherwise
-- Args:    ctx:     a secp256k1 context object
-- Out:     sig:     a pointer to a signature object (RecSig65); callers in
--                   this module pass a freshly allocated 65-byte buffer
-- In:      input64: a pointer to a 64-byte compact signature (r || s)
--          recid:   the recovery id (0, 1, 2 or 3)
--
-- NOTE: libsecp256k1 validates @recid@ as an API argument. A value outside
-- 0..3 triggers the context's illegal-argument callback (which aborts by
-- default) rather than producing a 0 return, so callers must range-check
-- the recovery id before calling this.
foreign import capi unsafe "secp256k1_recovery.h secp256k1_ecdsa_recoverable_signature_parse_compact"
  secp256k1_ecdsa_recoverable_signature_parse_compact ::
    Ptr Secp.LCtx ->
    Ptr Secp.RecSig65 ->
    Ptr CUChar ->
    CInt ->
    IO Secp.Ret

-- | Recover an ECDSA public key from a recoverable signature and the
-- 32-byte digest it was computed over.
--
-- Raw binding to libsecp256k1's @secp256k1_ecdsa_recover@.
--
-- Returns: 1: public key successfully recovered (which guarantees a correct signature).
--          0: otherwise.
-- Args:    ctx:       pointer to a context object, initialized for verification
-- Out:     pubkey:    pointer to the recovered public key. This is the opaque
--                     64-byte internal representation, not a serialised key;
--                     pass it to 'Secp.ecPubKeySerialize' to obtain bytes.
-- In:      sig:       pointer to initialized signature that supports pubkey
--                     recovery (e.g. one filled by
--                     'secp256k1_ecdsa_recoverable_signature_parse_compact')
--          msg32:     the 32-byte message hash assumed to be signed
foreign import capi unsafe "secp256k1_recovery.h secp256k1_ecdsa_recover"
  secp256k1_ecdsa_recover ::
    Ptr Secp.LCtx ->
    Ptr Secp.PubKey64 ->
    Ptr Secp.RecSig65 ->
    Ptr Secp.Msg32 ->
    IO Secp.Ret

-- | Recover the signer's public key from an Ethereum-style recoverable
-- ECDSA signature.
--
-- Arguments, in order:
--
-- * @message@: the 32-byte digest that was signed (typically Keccak-256).
-- * @sig@: 65 bytes laid out as @r (32) || s (32) || v (1)@. The trailing
--   @v@ byte may be a raw recovery id (0-3) or the Ethereum legacy form
--   (27-30, i.e. recovery id + 27).
--
-- Returns @Just@ the key as serialised by 'Secp.ecPubKeySerialize' with
-- 'Secp.uncompressed' (65 bytes). Returns 'Nothing' in any of these cases:
--
-- * either input has the wrong length;
-- * @v@ does not map to a recovery id in @[0, 3]@;
-- * libsecp256k1 fails to parse the signature, recover the key, or
--   serialise it.
recoverPubKey :: BS.ByteString -> BS.ByteString -> IO (Maybe BS.ByteString)
recoverPubKey message sig
  | BS.length sig /= 65 = pure Nothing
  | BS.length message /= 32 = pure Nothing
  | otherwise =
      -- The length guard above makes index 64 safe.
      maybe (pure Nothing) recoverWith (toRecoveryId (BS.index sig 64))
  where
    -- Parse -> recover -> serialise, stopping at the first libsecp256k1
    -- call that reports failure.
    recoverWith recId =
      -- Opaque recoverable-signature object (65 bytes).
      allocaBytes 65 $ \recSigPtr ->
        -- Opaque public-key object (64-byte internal representation).
        allocaBytes 64 $ \pubKeyPtr ->
          -- r || s: the first 64 bytes of the signature.
          unsafeUseAsCStringLen (BS.take 64 sig) $ \(compactPtr, _) ->
            unsafeUseAsCStringLen message $ \(msgPtr, _) ->
              Secp.withContext $ \(Secp.Ctx ctxFPtr) ->
                withForeignPtr ctxFPtr $ \ctxPtr -> do
                  parsed <-
                    secp256k1_ecdsa_recoverable_signature_parse_compact
                      ctxPtr
                      (castPtr recSigPtr)
                      (castPtr compactPtr)
                      recId
                  if not (Secp.isSuccess parsed)
                    then pure Nothing
                    else do
                      recovered <-
                        secp256k1_ecdsa_recover
                          ctxPtr
                          (castPtr pubKeyPtr)
                          (castPtr recSigPtr)
                          (castPtr msgPtr)
                      if not (Secp.isSuccess recovered)
                        then pure Nothing
                        else serialise ctxPtr pubKeyPtr

    -- Serialise the opaque public key in uncompressed form and copy the
    -- bytes out into a Haskell-managed ByteString.
    serialise ctxPtr pubKeyPtr =
      allocaBytes 65 $ \outPtr ->
        alloca $ \sizePtr -> do
          -- In: capacity of outPtr. Out: number of bytes actually written.
          poke sizePtr 65
          written <-
            Secp.ecPubKeySerialize
              ctxPtr
              outPtr
              sizePtr
              (castPtr pubKeyPtr)
              Secp.uncompressed
          if not (Secp.isSuccess written)
            then pure Nothing
            else do
              size <- peek sizePtr
              Just <$> packCStringLen (castPtr outPtr, fromIntegral size)

-- | Map the trailing @v@ byte of a signature to a libsecp256k1 recovery id.
--
-- Accepts raw ids (0-3) and the Ethereum legacy offset form (27-30). Every
-- other byte yields 'Nothing'. This check must happen before calling into
-- libsecp256k1, because its parser treats a recovery id outside @[0, 3]@
-- as an API misuse (illegal-argument callback, which aborts by default)
-- rather than as an ordinary parse failure.
toRecoveryId :: Word8 -> Maybe CInt
toRecoveryId vByte
  | rid <= 3 = Just (fromIntegral rid)
  | otherwise = Nothing
  where
    -- Strip the Ethereum +27 offset; the branch avoids Word8 underflow.
    rid = if vByte >= 27 then vByte - 27 else vByte
