{-# OPTIONS_GHC -Wno-orphans #-}

module Data.Function.FastMemo.Char () where

import Data.Bits (complement, countLeadingZeros, shiftL, shiftR, (.&.), (.|.))
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.UTF8 as UTF8
import Data.Char (chr)
import Data.Function.FastMemo.Class (Memoizable (..))
import Data.Function.FastMemo.Util (memoizeFixedLen)
import Data.Function.FastMemo.Word ()
import Data.List (foldl')
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NonEmpty
import Data.Word (Word8)

-- Chars are memoized through their UTF-8 encoding so that ASCII characters,
-- which encode to a single byte, need only one Vector lookup.
instance Memoizable Char where
  memoize :: forall b. (Char -> b) -> Char -> b
  memoize f =
    -- The code-point table is bound outside the returned lambda, so every
    -- application of the function returned by 'memoize f' shares it.
    let table :: CodePoint -> b
        table = memoize (f . codePointToChar)
     in \c -> table (charToCodePoint c)

-- | The UTF-8 encoding of a single 'Char', one byte per element.
--
-- The head is the lead byte and the tail holds the bytes that follow it.
-- 'NonEmpty' reflects that every encoded character has at least one byte.
newtype CodePoint = CodePoint
  { getCodePoint :: NonEmpty Word8
  }

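-- In UTF-8, the lead byte of an encoded code point determines how many
-- continuation bytes follow it. Memoizing on the lead byte first lets each
-- branch use a fixed-length table sized for exactly the remaining bytes,
-- which keeps the number of lookups down.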
instance Memoizable CodePoint where
  memoize :: forall b. (CodePoint -> b) -> CodePoint -> b
  memoize f = \(CodePoint (leadByte :| rest)) -> byLead leadByte rest
    where
      -- Outer table, keyed on the lead byte; bound outside the lambda so it
      -- is shared by every lookup through the returned function.
      byLead :: Word8 -> [Word8] -> b
      byLead = memoize tailTable

      -- For a given lead byte, a table over exactly 'extraBytes lead'
      -- trailing bytes that reassembles the full 'CodePoint' before calling f.
      tailTable :: Word8 -> [Word8] -> b
      tailTable lead =
        memoizeFixedLen (extraBytes lead) (f . CodePoint . (lead :|))

-- | Number of bytes that follow the given UTF-8 lead byte.
--
-- For a multi-byte lead byte, the run of leading one bits equals the total
-- sequence length, so one fewer bytes remain. A byte with no leading ones
-- (ASCII) stands alone, and the result is clamped to zero for it.
extraBytes :: Word8 -> Int
extraBytes lead = max 0 (countLeadingOnes lead - 1)

-- | Count the consecutive 1 bits at the most-significant end of a byte,
-- computed as the leading zeros of its bitwise complement.
countLeadingOnes :: Word8 -> Int
countLeadingOnes byte = countLeadingZeros (complement byte)

-- | Encode a 'Char' as its UTF-8 byte sequence.
--
-- 'NonEmpty.fromList' relies on 'UTF8.fromString' producing at least one
-- byte for any single character.
charToCodePoint :: Char -> CodePoint
charToCodePoint c = CodePoint (NonEmpty.fromList encoded)
  where
    encoded = ByteString.unpack (UTF8.fromString [c])

-- | Decode a 'CodePoint' back into the 'Char' it encodes.
--
-- The scalar value is rebuilt directly from the UTF-8 bit layout:
--
-- * the lead byte contributes the bits after its run of leading ones and
--   the separating zero;
-- * each following byte contributes its low six bits.
--
-- No validation is done, and no partial list function is used.
-- A validating decoder may map code points it rejects (e.g. lone surrogates)
-- to U+FFFD, which would make 'memoize' call @f@ on the wrong character.
-- NOTE(review): exact round-tripping of such characters assumes
-- 'charToCodePoint' lays them out in the ordinary bit pattern; confirm
-- against the utf8-string version in use.
--
-- Inputs are expected to come from 'charToCodePoint'. A sequence that
-- decodes above U+10FFFF makes 'chr' throw.
codePointToChar :: CodePoint -> Char
codePointToChar cp =
  case getCodePoint cp of
    lead :| continuation ->
      chr (foldl' appendPayload (leadPayload lead) continuation)
  where
    -- Payload bits of the lead byte: mask off the prefix of 1s and the 0
    -- separator. 'shiftR' on 'Word8' by 8 or more yields 0, so this is
    -- total for every byte.
    leadPayload :: Word8 -> Int
    leadPayload b =
      fromIntegral (b .&. (0xFF `shiftR` (countLeadingOnes b + 1)))

    -- Continuation bytes are 10xxxxxx; append their six payload bits.
    appendPayload :: Int -> Word8 -> Int
    appendPayload acc b = (acc `shiftL` 6) .|. fromIntegral (b .&. 0x3F)