{-# LANGUAGE ForeignFunctionInterface #-}

-- |
-- Module      : Crypto.Cipher.XSalsa
-- License     : BSD-style
-- Maintainer  : Brandon Hamilton <brandon.hamilton@gmail.com>
-- Stability   : stable
-- Portability : good
--
-- Implementation of the XSalsa20 algorithm
-- <https://cr.yp.to/snuffle/xsalsa-20081128.pdf>.
-- XSalsa20 is based on the Salsa20 algorithm with a 256-bit key, extended to
-- accept a 192-bit nonce.
module Crypto.Cipher.XSalsa (
    initialize,
    derive,
    combine,
    generate,
    State,
) where

import Crypto.Cipher.Salsa hiding (initialize)
import Crypto.Internal.ByteArray (ByteArrayAccess)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Foreign.Ptr

-- | Initialize a new XSalsa context from the number of rounds, the key and
-- the associated nonce.
--
-- The arguments are validated before the C initializer runs. An invalid
-- argument makes the function call 'error'; the checks are performed in this
-- order: key length, nonce length, number of rounds.
initialize
    :: (ByteArrayAccess key, ByteArrayAccess nonce)
    => Int
    -- ^ number of rounds (8,12,20)
    -> key
    -- ^ the key (256 bits)
    -> nonce
    -- ^ the nonce (192 bits)
    -> State
    -- ^ the initial XSalsa state
initialize nbRounds key nonce
    | kLen /= 32 = error "XSalsa: key length should be 256 bits"
    | nonceLen /= 24 = error "XSalsa: nonce length should be 192 bits"
    | nbRounds `notElem` [8, 12, 20] = error "XSalsa: rounds should be 8, 12 or 20"
    | otherwise = unsafeDoIO $ do
        -- Allocate the 132-byte state buffer and let the C code fill it in.
        -- NOTE(review): 132 is presumably the size of the C context
        -- structure; confirm against the C header.
        stBytes <- B.alloc 132 $ \stPtr ->
            B.withByteArray nonce $ \noncePtr ->
                B.withByteArray key $ \keyPtr ->
                    ccrypton_xsalsa_init stPtr nbRounds kLen keyPtr nonceLen noncePtr
        return $ State stBytes
  where
    -- Lengths are measured in bytes: 32 bytes = 256 bits, 24 bytes = 192 bits.
    kLen :: Int
    kLen = B.length key
    nonceLen :: Int
    nonceLen = B.length nonce

-- | Use an already initialized context and new nonce material to derive another
-- XSalsa context.
--
-- This allows a multi-level cascade where a first key @k1@ and nonce @n1@ are
-- used to get @HState(k1,n1)@, and this value is then used as key @k2@ to build
-- @XSalsa(k2,n2)@.  Function 'initialize' is to be called with the first 192
-- bits of @n1|n2@, and the call to 'derive' should add the remaining 128 bits.
--
-- The output context always uses the same number of rounds as the input
-- context.
--
-- The function calls 'error' when the supplied nonce is not exactly 16 bytes
-- long.
derive
    :: ByteArrayAccess nonce
    => State
    -- ^ base XSalsa state
    -> nonce
    -- ^ the remainder nonce (128 bits)
    -> State
    -- ^ the new XSalsa state
derive (State baseBytes) nonce
    | nonceLen /= 16 = error "XSalsa: nonce length should be 128 bits"
    | otherwise = unsafeDoIO $ do
        -- The C routine is handed the buffer produced by 'B.copy' from the
        -- base state, not the base state's own bytes.
        derivedBytes <- B.copy baseBytes $ \stPtr ->
            B.withByteArray nonce $ \noncePtr ->
                ccrypton_xsalsa_derive stPtr nonceLen noncePtr
        return $ State derivedBytes
  where
    -- Length in bytes: 16 bytes = 128 bits.
    nonceLen :: Int
    nonceLen = B.length nonce

-- | FFI binding to the C function @crypton_xsalsa_init@.
--
-- Arguments, in the order 'initialize' passes them: pointer to the state
-- buffer, number of rounds, key length in bytes, pointer to the key, nonce
-- length in bytes, pointer to the nonce.
--
-- NOTE(review): the integer arguments are marshalled as Haskell 'Int';
-- confirm that this matches the parameter types of the C prototype.
foreign import ccall "crypton_xsalsa_init"
    ccrypton_xsalsa_init
        :: Ptr State -> Int -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()

-- | FFI binding to the C function @crypton_xsalsa_derive@.
--
-- Arguments, in the order 'derive' passes them: pointer to the state buffer
-- (a copy of the base state), nonce length in bytes, pointer to the nonce.
--
-- NOTE(review): the length argument is marshalled as Haskell 'Int'; confirm
-- that this matches the parameter type of the C prototype.
foreign import ccall "crypton_xsalsa_derive"
    ccrypton_xsalsa_derive :: Ptr State -> Int -> Ptr Word8 -> IO ()