-- |
-- Module      : Machine
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Architecture-dependent utilities to read and write unaligned machine words
-- in little-endian order.  On hosts without safe unaligned access, the
-- machine word is a single byte.
--
{-# LANGUAGE CPP #-}
module Machine
    ( WordM, WordLE, assertMultM, fromLE, toLE, wordBits, wordBytes
    ) where

#include "MachDeps.h"

-- Architectures known to accept unaligned loads and stores (list taken
-- from `bytestring`).  On ARM the C compiler must additionally advertise
-- unaligned-access support.  When this section applies, the word type
-- chosen below is the native machine word; otherwise the module falls back
-- to single bytes.
#if defined(i386_HOST_ARCH) || defined(x86_64_HOST_ARCH)          \
    || ((defined(arm_HOST_ARCH) || defined(aarch64_HOST_ARCH))    \
         && defined(__ARM_FEATURE_UNALIGNED))                     \
    || defined(powerpc_HOST_ARCH) || defined(powerpc64_HOST_ARCH) \
    || defined(powerpc64le_HOST_ARCH)
#define MLKEM_ALLOW_UNALIGNED_OP 1

-- The little-endian conversion in `basement` and `memory` is skipped at
-- compile time only on AMD/Intel.  This flag also short-circuits it on ARM
-- hosts that are not big-endian, so the conversion becomes a plain
-- constructor wrap/unwrap.  NOTE(review): the big-endian test assumes the
-- endianness macro is provided through MachDeps.h; confirm for the
-- supported GHC versions.
#if (defined(arm_HOST_ARCH) || defined(aarch64_HOST_ARCH)) \
    && !defined(WORDS_BIGENDIAN)
#define MLKEM_FORCE_LITTLE_ENDIAN_ARCH 1
#endif

#endif

import Control.Exception (assert)

#ifdef MLKEM_ALLOW_UNALIGNED_OP
import qualified Data.Memory.Endian as B
#endif

import Data.Bits
import Data.Word

#if defined(MLKEM_ALLOW_UNALIGNED_OP)

-- | Word type used for bulk memory access: the native GHC word width
-- (64 or 32 bits), since this host tolerates unaligned loads and stores.
#if WORD_SIZE_IN_BITS == 64
type WordM = Word64
#else
type WordM = Word32
#endif

-- | A @WordM@ as it is laid out in memory, in little-endian byte order.
type WordLE = B.LE WordM

-- | Numeric value of a word held in little-endian byte order.
fromLE :: WordLE -> WordM
#if defined(MLKEM_FORCE_LITTLE_ENDIAN_ARCH)
-- known little-endian host: strip the wrapper, no byte swap
fromLE (B.LE w) = w
#else
-- defer to @memory@, which swaps bytes if necessary
fromLE w = B.fromLE w
#endif

-- | Convert a word to its little-endian in-memory representation.
toLE :: WordM -> WordLE
#if defined(MLKEM_FORCE_LITTLE_ENDIAN_ARCH)
-- known little-endian host: add the wrapper, no byte swap
toLE w = B.LE w
#else
-- defer to @memory@, which swaps bytes if necessary
toLE w = B.toLE w
#endif

#else

-- Unaligned access is not safe on this host, so memory is handled one byte
-- at a time.  A single byte has no byte order: the little-endian view is the
-- byte itself and both conversions are the identity.

-- | Word type used for memory access: a single byte on this host.
type WordM = Word8

-- | Little-endian view of a @WordM@; identical to the byte itself.
type WordLE = Word8

-- | Identity: a single byte needs no byte swapping.
fromLE :: WordLE -> WordM
fromLE b = b

-- | Identity: a single byte needs no byte swapping.
toLE :: WordM -> WordLE
toLE b = b

#endif

-- | Width of @WordM@ in bits: 8, 32 or 64 depending on the build target.
wordBits :: Int
wordBits = finiteBitSize probe
  where
    -- only the type matters; finiteBitSize ignores the value
    probe = zeroBits :: WordM

-- | Width of @WordM@ in bytes (the bit width divided by eight).
wordBytes :: Int
wordBytes = wordBits `quot` 8

-- | @assertMultM n x@ returns @x@ after asserting, through @assert@, that
-- @n@ is a whole multiple of @wordBytes@.  The check is a bit mask, which is
-- valid because @wordBytes@ is always a power of two (1, 4 or 8).
assertMultM :: Int -> a -> a
assertMultM n x = assert (leftover == 0) x
  where
    leftover = n .&. (wordBytes - 1)
