{-# OPTIONS_GHC -O2 -fglasgow-exts #-}

{- |
  Module      :  Encoding.UTF8.Light
  Copyright   :  (c) Matt Morrow 2008
  License     :  BSD3

  Maintainer  :  Matt Morrow <mjm2002@gmail.com>
  Stability   :  provisional
  Portability :  portable
-}

module Encoding.UTF8.Light (
    -- * Conversion class
    UTF8(..)
    -- * Sequence lengths
  , lenUTF8
  , lenUTF16
  , countUTF8
    -- * Pure encoding and decoding
  , decodeUTF8
  , encodeUTF8
  , encodeUTF8'
  , withUTF8
    -- * Output
  , putUTF8
  , putUTF8Ln
  , hPutUTF8
  , hPutUTF8Ln
    -- * Files
  , readUTF8File
  , writeUTF8File
  , appendUTF8File
    -- * Handle input
  , hGetUTF8Line
  , hGetUTF8Contents
  , hGetUTF8
  , hGetUTF8NonBlocking
    -- * Upside-down text
  , flipUTF8
  , unflipUTF8
  , flipTab
  , unflipTab
    -- * Debugging helpers
  , showHex
  , toBits
  , fromBits
) where

import Data.Bits
import Data.Word(Word8,Word16,Word32)
import Data.List(foldl')
import Data.Monoid(Monoid(..))
import Data.ByteString(ByteString)
import qualified Data.ByteString as B
import Data.ByteString.Internal(c2w,w2c)
import Data.ByteString.Unsafe
import Data.Char(chr,ord)
import System.IO(Handle)

-- | Shorthand for 'fromIntegral', used for the many integral
--  conversions in this module.
fi :: (Num b, Integral a) => a -> b
fi n = fromIntegral n

-- | Types that can be converted to and from a 'ByteString'.
--
-- Instances: @ByteString@, @String@, @[Int]@
class UTF8 a where
  -- | Turn a value into its 'ByteString' form.
  encode :: a -> ByteString
  -- | Build a value from a 'ByteString'.
  decode :: ByteString -> a

-- | A 'ByteString' is passed through untouched in both directions;
--  no validation of its contents is performed.
instance UTF8 ByteString where
  encode bytes = bytes
  decode bytes = bytes

-- | A list of 'Int's is treated as a list of code points and goes
--  through 'encodeUTF8' / 'decodeUTF8'.
instance UTF8 [Int] where
  encode codePoints = encodeUTF8 codePoints
  decode bytes = decodeUTF8 bytes

-- | A 'String' is converted char-by-char with 'ord' / 'chr' and then
--  handled by the @[Int]@ instance. Note that 'chr' raises an error
--  for an 'Int' that is not a valid 'Char'.
instance UTF8 String where
  encode str = encode (map ord str)
  decode bytes = map chr (decode bytes)

-- | Encode a value and hand the resulting bytes to a continuation.
withUTF8 :: (UTF8 a) => a -> (ByteString -> b) -> b
withUTF8 value consume = consume $ encode value

-- | Encode a value and write the bytes with 'B.putStr'.
putUTF8 :: (UTF8 a) => a -> IO ()
putUTF8 value = B.putStr (encode value)

-- | Encode a value and write the bytes followed by a newline byte
--  (@0x0a@) with 'B.putStr'.
--
--  The newline is appended before writing so the output goes out in
--  one call. @Data.ByteString.putStrLn@ is deliberately avoided: it
--  is deprecated and not available in newer @bytestring@ releases.
putUTF8Ln :: (UTF8 a) => a -> IO ()
putUTF8Ln value = withUTF8 value $ \bytes ->
  B.putStr (bytes `B.snoc` 0x0a)

-- | Encode a value and write the bytes to the handle with 'B.hPut'.
hPutUTF8 :: (UTF8 a) => Handle -> a -> IO ()
hPutUTF8 h value = B.hPut h (encode value)

-- | Encode a value and write the bytes followed by a newline byte
--  (@0x0a@) to the handle with 'B.hPut'.
--
--  The newline is appended before writing so the output goes out in
--  one call. @Data.ByteString.hPutStrLn@ is deliberately avoided: it
--  is deprecated and not available in newer @bytestring@ releases.
hPutUTF8Ln :: (UTF8 a) => Handle -> a -> IO ()
hPutUTF8Ln h value = withUTF8 value $ \bytes ->
  B.hPut h (bytes `B.snoc` 0x0a)

-- | Read a file with 'B.readFile' and decode its bytes.
readUTF8File :: (UTF8 a) => FilePath -> IO a
readUTF8File path = do
  bytes <- B.readFile path
  return (decode bytes)

-- | Encode a value and write the bytes to a file with 'B.writeFile'.
writeUTF8File :: (UTF8 a) => FilePath -> a -> IO ()
writeUTF8File path value = B.writeFile path (encode value)

-- | Encode a value and append the bytes to a file with 'B.appendFile'.
appendUTF8File :: (UTF8 a) => FilePath -> a -> IO ()
appendUTF8File path value = B.appendFile path (encode value)

-- | Read one line from the handle with 'B.hGetLine' and decode it.
hGetUTF8Line :: (UTF8 a) => Handle -> IO a
hGetUTF8Line h = do
  line <- B.hGetLine h
  return (decode line)

-- | Read from the handle with 'B.hGetContents' and decode the bytes.
hGetUTF8Contents :: (UTF8 a) => Handle -> IO a
hGetUTF8Contents h = do
  bytes <- B.hGetContents h
  return (decode bytes)

-- | Read from the handle with 'B.hGet', passing on the given byte
--  count, and decode the result.
--
--  The count is in bytes, not characters: make sure the read cannot
--  stop in the middle of a multi-byte UTF8 char.
hGetUTF8 :: (UTF8 a) => Handle -> Int -> IO a
hGetUTF8 h count = do
  bytes <- B.hGet h count
  return (decode bytes)

-- | Like 'hGetUTF8' but reads with 'B.hGetNonBlocking'. The same
--  caveat applies: the count is in bytes, so a multi-byte UTF8 char
--  may be cut in two.
hGetUTF8NonBlocking :: (UTF8 a) => Handle -> Int -> IO a
hGetUTF8NonBlocking h count = do
  bytes <- B.hGetNonBlocking h count
  return (decode bytes)

-- | Length in Word8s of the UTF8 sequence introduced by a lead byte.
--
--  Returns 0 for a byte that does not start a sequence: one of the
--  form @10xxxxxx@ or @11111xxx@.
lenUTF8 :: Word8 -> Int
lenUTF8 lead
  | lead < 0x80 = 1  -- 0xxxxxxx
  | lead < 0xc0 = 0  -- 10xxxxxx
  | lead < 0xe0 = 2  -- 110xxxxx
  | lead < 0xf0 = 3  -- 1110xxxx
  | lead < 0xf8 = 4  -- 11110xxx
  | otherwise   = 0  -- 11111xxx

-- | Length in Word16s of the UTF16 sequence introduced by a code unit.
--
--  Units in @0xd800..0xdbff@ give 2, units in @0xdc00..0xdfff@ give 0,
--  and every other unit gives 1.
lenUTF16 :: Word16 -> Int
lenUTF16 unit
  | 0xd800 <= unit && unit < 0xdc00 = 2
  | 0xdc00 <= unit && unit < 0xe000 = 0
  | otherwise                       = 1

-- | Lengths in Word8s of the successive UTF8 sequences in a string,
--  as reported by 'lenUTF8' for each lead byte.
--
--  The walk stops at the end of the input or at the first byte for
--  which 'lenUTF8' gives 0. Only lead bytes are inspected, so the last
--  length is reported even if the input ends before that sequence is
--  complete.
countUTF8 :: ByteString -> [Int]
countUTF8 bytes = walk 0
  where
    total :: Int
    total = B.length bytes

    walk :: Int -> [Int]
    walk pos
      | total <= pos = []
      | width == 0   = []
      | otherwise    = next `seq` (width : walk next)
      where
        width = lenUTF8 (unsafeIndex bytes pos)
        next  = pos + width

-- | Encode a list of code points into one 'ByteString' by packing
--  the concatenated output of @encodeUTF8'@.
encodeUTF8 :: [Int] -> ByteString
encodeUTF8 codePoints = B.pack (concat (encodeUTF8' codePoints))

-- | Encode each code point separately, giving one list of bytes per
--  input element (in the same order).
--
--  Values in @0..0x1fffff@ are encoded in 1 to 4 bytes. Anything else
--  (negative, or above @0x1fffff@) yields an empty list for that
--  element. Surrogates and values above @0x10ffff@ are not rejected.
--
--  Range comparisons are used rather than 32-bit masks so that, where
--  'Int' is wider than 32 bits, a value with high bits set is rejected
--  instead of being encoded from its low bits only.
encodeUTF8' :: [Int] -> [[Word8]]
encodeUTF8' = map encodeOne
  where
    encodeOne :: Int -> [Word8]
    encodeOne x
      | x < 0        = []
      | x < 0x80     = [fromIntegral x]
      | x < 0x800    = [lead 0xc0 6, cont 0]
      | x < 0x10000  = [lead 0xe0 12, cont 6, cont 0]
      | x < 0x200000 = [lead 0xf0 18, cont 12, cont 6, cont 0]
      | otherwise    = []
      where
        -- Lead byte: length tag OR'd with the topmost payload bits.
        lead :: Int -> Int -> Word8
        lead tag n = fromIntegral ((x `shiftR` n) .|. tag)

        -- Continuation byte: 10xxxxxx carrying six payload bits.
        cont :: Int -> Word8
        cont n = fromIntegral (((x `shiftR` n) .&. 0x3f) .|. 0x80)

-- | Decode a UTF8 'ByteString' into a list of code points.
--
--  Decoding stops (and the list ends) at the first lead byte for which
--  'lenUTF8' gives 0, or when the input ends before a multi-byte
--  sequence is complete. Continuation bytes are not validated: only
--  their low six bits are used.
--
--  The lead byte is masked according to the sequence length
--  (@0x1f@ / @0x0f@ / @0x07@ for 2 / 3 / 4 bytes) so that none of the
--  length-tag bits leak into the code point.
decodeUTF8 :: ByteString -> [Int]
decodeUTF8 s = go 0
  where
    len :: Int
    len = B.length s

    -- Low six payload bits of the byte at offset j.
    cont :: Int -> Int
    cont j = fi (unsafeIndex s j .&. 0x3f)

    go :: Int -> [Int]
    go i
      | len <= i  = []
      | otherwise =
          let c1 = unsafeIndex s i
              n  = lenUTF8 c1
              i' = i + n
              -- Payload bits of the lead byte, selected by the mask.
              lead mask = fi (c1 .&. mask) :: Int
              -- Force the next offset before consing, then continue.
              emit cp = i' `seq` (cp : go i')
          in case n of
               -- Sequence would run past the end of the input.
               _ | len < i' -> []
               1 -> emit (fi c1)
               2 -> emit ((lead 0x1f `shiftL` 6)
                            .|. cont (i+1))
               3 -> emit ((lead 0x0f `shiftL` 12)
                            .|. (cont (i+1) `shiftL` 6)
                            .|. cont (i+2))
               4 -> emit ((lead 0x07 `shiftL` 18)
                            .|. (cont (i+1) `shiftL` 12)
                            .|. (cont (i+2) `shiftL` 6)
                            .|. cont (i+3))
               -- n == 0: not a lead byte.
               _ -> []

-----------------------------------------------------------------------------

-- misc debug stuff

-- | The eight bits of a byte as a list of 0s and 1s, most significant
--  bit first.
toBits :: Word8 -> [Word8]
toBits w8 = [ (w8 `shiftR` k) .&. 0x01 | k <- [7,6,5,4,3,2,1,0] ]

-- | Assemble a byte from a list of bits, most significant first.
--
--  The last element is shifted by 0, the one before it by 1, and so
--  on up to a shift of 7; if the list has more than eight elements
--  only the last eight are used. All shifted values are OR'd together.
fromBits :: [Word8] -> Word8
fromBits bits = foldl' (.|.) 0 (zipWith shiftL (reverse bits) [0..7])

-- | The sixteen lowercase hex digits as bytes, indexed by value.
hexTab :: ByteString
hexTab = B.pack (map c2w "0123456789abcdef")

-- | Render the low 32 bits of an 'Int' as @0x@ followed by exactly
--  eight lowercase hex digits.
showHex :: Int -> String
showHex i = "0x" ++ map digitAt [28,24,20,16,12,8,4,0]
  where
    -- Hex digit for the four bits starting at bit offset n.
    digitAt :: Int -> Char
    digitAt n = w2c (unsafeIndex hexTab ((i `shiftR` n) .&. 0xf))

-----------------------------------------------------------------------------

-- now, for fun...

{- |
Turn text upside down: encode the value, run the bytes through
@flipString@ with 'flipTab', and decode the result.

> ghci> putUTF8Ln $ flipUTF8 "[?np_bs!]"
> [¡sq‾bu¿]
-}
flipUTF8 :: (UTF8 a) => a -> a
flipUTF8 value = decode (flipString flipTab (encode value))

{- |
The reverse direction of 'flipUTF8': the same pipeline, but using
'unflipTab' as the lookup table.

> ghci> putUTF8Ln $ (unflipUTF8 . flipUTF8) "[?np_bs!]"
> [?np_bs!]
-}
unflipUTF8 :: (UTF8 a) => a -> a
unflipUTF8 value = decode (flipString unflipTab (encode value))

-- | Decode the bytes to code points, translate each one through the
--  lookup table, reverse the result and encode it again.
--
--  A code point with no entry in the table is replaced by a space
--  (@' '@); it is not dropped. Possibly it's more desirable to just
--  be id on such chars?
flipString :: [(Int,Int)] -> ByteString -> ByteString
flipString tab bytes = encode flipped
  where
    codePoints :: [Int]
    codePoints = decode bytes

    flipped :: String
    flipped = reverse [ maybe ' ' chr (lookup cp tab) | cp <- codePoints ]

-- | 'flipTab' with the two components of every pair swapped.
unflipTab :: [(Int,Int)]
unflipTab = map swapPair flipTab
  where
    swapPair :: (Int,Int) -> (Int,Int)
    swapPair pair = (snd pair, fst pair)

-- | Lookup table used by 'flipUTF8': pairs of (code point of a plain
--  char, code point of its upside-down replacement).
flipTab :: [(Int,Int)]
flipTab = [ (ord plain, flipped) | (plain, flipped) <- table ]
  where
    table :: [(Char,Int)]
    table =
      -- letters
      [ ('a', 0x250)
      , ('b', ord 'q')
      , ('c', 0x254)
      , ('d', ord 'p')
      , ('e', 0x1dd)
      , ('f', 0x25f)
      , ('g', 0x183)
      , ('h', 0x265)
      , ('i', 0x131)
      , ('j', 0x27e)
      , ('k', 0x29e)
      , ('l', ord 'l')
      , ('m', 0x26f)
      , ('n', ord 'u')
      , ('o', ord 'o')
      , ('p', ord 'b')
      , ('q', ord 'd')
      , ('r', 0x279)
      , ('s', ord 's')
      , ('t', 0x287)
      , ('u', ord 'n')
      , ('v', 0x28c)
      , ('w', 0x28d)
      , ('x', ord 'x')
      , ('y', 0x28e)
      , ('z', ord 'z')
      -- punctuation and brackets
      , ('.', 0x2d9)
      , ('[', ord ']')
      , (']', ord '[')
      , ('{', ord '}')
      , ('}', ord '{')
      , ('<', ord '>')
      , ('>', ord '<')
      , ('?', 0xbf)
      , ('!', 0xa1)
      , ('\'', ord ',')
      , ('_', 0x203e)
      , (';', 0x061b)
      ]

{-
ghci> mapM_ print . zip (fmap show [0..9] ++ fmap (:[]) ['a'..'f']) . fmap (drop 4 . toBits) $ [0..15]
("0",[0,0,0,0])
("1",[0,0,0,1])
("2",[0,0,1,0])
("3",[0,0,1,1])
("4",[0,1,0,0])
("5",[0,1,0,1])
("6",[0,1,1,0])
("7",[0,1,1,1])
("8",[1,0,0,0])
("9",[1,0,0,1])
("a",[1,0,1,0])
("b",[1,0,1,1])
("c",[1,1,0,0])
("d",[1,1,0,1])
("e",[1,1,1,0])
("f",[1,1,1,1])

class (Num a) => Bits a where
  (.&.) :: a -> a -> a
  (.|.) :: a -> a -> a
  xor :: a -> a -> a
  complement :: a -> a
  shift :: a -> Int -> a
  rotate :: a -> Int -> a
  bit :: Int -> a
  setBit :: a -> Int -> a
  clearBit :: a -> Int -> a
  complementBit :: a -> Int -> a
  testBit :: a -> Int -> Bool
  bitSize :: a -> Int
  isSigned :: a -> Bool
  shiftL :: a -> Int -> a
  shiftR :: a -> Int -> a
  rotateL :: a -> Int -> a
  rotateR :: a -> Int -> a

uncheckedIShiftL#   :: Int# -> Int# -> Int#
uncheckedIShiftRA#  :: Int# -> Int# -> Int#
uncheckedIShiftRL#  :: Int# -> Int# -> Int#
uncheckedShiftL#    :: Word# -> Int# -> Word#
uncheckedShiftRL#   :: Word# -> Int# -> Word#
-}