{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnboxedTuples #-}
module Numeric.QuoteQuot
(
quoteQuot
, quoteRem
, quoteQuotRem
, astQuot
, AST(..)
, interpretAST
) where
import Prelude
import Data.Bits
import GHC.Exts
import Language.Haskell.TH
-- | Quote integer division ('quot') of a 'Word' by a divisor @d@ that is
-- fixed at compile time. The generated function follows the expression tree
-- produced by 'astQuot'. It uses only multiply-high/multiply-low, addition,
-- subtraction, shifts and comparisons, never a division.
--
-- > quot10 :: Word -> Word
-- > quot10 = $$(quoteQuot 10)
--
-- Running the splice with @d == 0@ fails with
-- @error "divisor must be positive"@.
quoteQuot :: Word -> Q (TExp (Word -> Word))
quoteQuot d = compile (astQuot d)
  where
    -- Each node becomes a function on the original argument. The child
    -- nodes are compiled first, and the node's own operation is composed
    -- after them.
    compile :: AST Word -> Q (TExp (Word -> Word))
    compile node = case node of
      Arg ->
        [|| id ||]
      Shr sub n ->
        [|| (`shiftR` n) . $$(compile sub) ||]
      Shl sub n ->
        [|| (`shiftL` n) . $$(compile sub) ||]
      MulHi sub (W# m) ->
        -- Keep only the upper word of the full two-word product.
        [|| (\(W# w) -> let !(# hi, _ #) = timesWord2# w m in W# hi) . $$(compile sub) ||]
      MulLo sub m ->
        [|| (* m) . $$(compile sub) ||]
      Add l r ->
        [|| \w -> $$(compile l) w + $$(compile r) w ||]
      Sub l r ->
        [|| \w -> $$(compile l) w - $$(compile r) w ||]
      CmpGE sub (W# m) ->
        -- Primitive comparisons return 1#/0# as Int#; reinterpret as Word.
        [|| (\(W# w) -> W# (int2Word# (w `geWord#` m))) . $$(compile sub) ||]
      CmpLT sub (W# m) ->
        [|| (\(W# w) -> W# (int2Word# (w `ltWord#` m))) . $$(compile sub) ||]
-- | Quote the remainder ('rem') of a 'Word' divided by a divisor @d@ fixed
-- at compile time. This is the second component of the pair produced by
-- 'quoteQuotRem'.
quoteRem :: Word -> Q (TExp (Word -> Word))
quoteRem d = [|| \w -> snd ($$(quoteQuotRem d) w) ||]
-- | Quote 'quotRem' of a 'Word' by a divisor @d@ fixed at compile time.
-- The quotient @q@ comes from 'quoteQuot'. The remainder is recovered as
-- @w - d * q@, so the generated code contains no division.
quoteQuotRem :: Word -> Q (TExp (Word -> (Word, Word)))
quoteQuotRem d =
  [|| \w ->
        let q = $$(quoteQuot d) w
            r = w - d * q
        in (q, r) ||]
-- | Expression tree for a division-free computation on a single argument of
-- type @a@. 'interpretAST' gives the reference meaning of every node, and
-- 'quoteQuot' compiles a tree into Template Haskell code.
data AST a
  = Arg
    -- ^ The argument itself.
  | MulHi (AST a) a
    -- ^ Upper half of the double-width product of the subtree and the
    -- constant: @(x * k) `shiftR` finiteBitSize k@, computed without
    -- overflow.
  | MulLo (AST a) a
    -- ^ The product @x * k@ computed at type @a@.
  | Add (AST a) (AST a)
    -- ^ @x + y@.
  | Sub (AST a) (AST a)
    -- ^ @x - y@.
  | Shl (AST a) Int
    -- ^ @x `shiftL` k@.
  | Shr (AST a) Int
    -- ^ @x `shiftR` k@.
  | CmpGE (AST a) a
    -- ^ @1@ if @x >= k@, otherwise @0@.
  | CmpLT (AST a) a
    -- ^ @1@ if @x < k@, otherwise @0@.
  deriving (Show)
-- | Reference interpreter for 'AST'. It evaluates the tree on a concrete
-- argument @n@ using ordinary 'Num', 'Ord' and 'Bits' operations at type
-- @a@. 'MulHi' is computed exactly through 'Integer'. This is handy for
-- checking that @interpretAST (astQuot d) n == n `quot` d@.
interpretAST :: (Integral a, FiniteBits a) => AST a -> (a -> a)
interpretAST ast n = eval ast
  where
    eval node = case node of
      Arg -> n
      MulHi x k ->
        -- Exact double-width product; keep only the bits above the width.
        let wide = toInteger (eval x) * toInteger k
        in fromInteger (wide `shiftR` finiteBitSize k)
      MulLo x k -> eval x * k
      Add x y -> eval x + eval y
      Sub x y -> eval x - eval y
      Shl x s -> eval x `shiftL` s
      Shr x s -> eval x `shiftR` s
      CmpGE x k -> indicator (eval x >= k)
      CmpLT x k -> indicator (eval x < k)

    -- Encode a comparison result as 1 or 0.
    indicator b = if b then 1 else 0
-- | Build an 'AST' equivalent to @(`quot` k)@ for a positive constant @k@.
-- The tree uses only multiplications, shifts, additions, subtractions and
-- comparisons. Signed and unsigned fixed-width types are both supported,
-- and the type's signedness selects the algorithm. Non-positive divisors
-- are rejected with 'error'.
--
-- >>> astQuot (10 :: Data.Word.Word8)
-- Shr (MulHi Arg 205) 3
astQuot :: (Integral a, FiniteBits a) => a -> AST a
astQuot k =
  if isSigned k
    then signedQuot k
    else unsignedQuot k
-- | Division-free 'quot' by a constant for unsigned types.
--
-- Notation: @w@ is the bit width, and @d = o * 2^tz@ with @o@ odd. The
-- cases are checked in this order:
--
-- * @d == 1@: the argument itself.
-- * @o == 1@ (a power of two): a single right shift by @tz@.
-- * @d >= 2^(w-1)@: the quotient can only be 0 or 1, so one comparison.
-- * Otherwise divide by @o@ with a multiply-high by the magic constant
--   @m = 2^(w+s) `quot` o + 1@, then shift out @s + tz@ bits. Here @s@ is
--   the smallest shift with @o - (2^(w+s) `mod` o) < 2^s@.
--   If @o < 2^s@, @m@ does not fit in @w@ bits: 'fromInteger' truncates it,
--   and the @((n - hi) >> 1) + hi@ sequence restores the lost top bit
--   (cf. Hacker's Delight, ch. 10).
--
-- Calls 'error' for signed types or a zero divisor.
unsignedQuot :: (Integral a, FiniteBits a) => a -> AST a
unsignedQuot d
  | isSigned d
  = error "unsignedQuot works for unsigned types only"
  | d == 0
  = error "divisor must be positive"
  | d == 1
  = Arg
  | oddPart == 1
  = shr Arg twos
  | d >= 1 `shiftL` (width - 1)
  = CmpGE Arg d
  | oddPart >= 1 `shiftL` s
  = shr hi (s + twos)
  | otherwise
  -- ((n - hi) >> 1) + hi equals (n + hi) >> 1, but it cannot overflow.
  = shr (Add (shr (Sub Arg hi) 1) hi) (s - 1 + twos)
  where
    width = finiteBitSize d
    twos = countTrailingZeros d
    oddPart = d `shiftR` twos
    s = findShift (fromInteger ((1 `shiftL` width) `rem` toInteger oddPart)) 0
    magic = fromInteger ((1 `shiftL` (width + s)) `quot` toInteger oddPart + 1)
    hi = MulHi Arg magic
    -- r tracks 2^(width + i) `mod` oddPart as i grows.
    findShift r i
      | oddPart - r < 1 `shiftL` i = i
      | otherwise = findShift ((r `shiftL` 1) `rem` oddPart) (i + 1)
-- | Division-free 'quot' (rounding toward zero) by a positive constant for
-- signed types.
--
-- Notation: @w@ is the bit width, @d = o * 2^tz@ with @o@ odd, and @neg@ is
-- @1@ for a negative argument and @0@ otherwise. The cases are checked in
-- this order:
--
-- * @d == 1@: the argument itself.
-- * @o == 1@ (a power of two): @(n + neg * (d - 1)) >> tz@. The bias makes
--   the arithmetic right shift round toward zero.
-- * @d >= 2^(w-2)@: the quotient lies in @[-1, 1]@, so it equals
--   @(n >= d) - (n <= -d)@.
-- * Otherwise multiply-high by @m = 2^(w+s) `quot` o + 1@, shift right by
--   @s + tz@, and add @neg@ so that negative quotients round toward zero.
--   Here @s@ is the smallest shift with @o - (2^(w+s) `mod` o) < 2^(s+1)@.
--   If @m@ wraps to a negative value under 'fromInteger', the argument is
--   added to the multiply-high result to compensate (cf. Hacker's Delight,
--   ch. 10).
--
-- Calls 'error' for unsigned types or a non-positive divisor.
signedQuot :: (Integral a, FiniteBits a) => a -> AST a
signedQuot d
  | not (isSigned d)
  = error "signedQuot works for signed types only"
  | d <= 0
  = error "divisor must be positive"
  | d == 1
  = Arg
  | oddPart == 1
  = shr (Add Arg (MulLo neg (d - 1))) twos
  | d >= 1 `shiftL` (width - 2)
  -- "n < 1 - d" is the same test as "n <= -d".
  = Sub (CmpGE Arg d) (CmpLT Arg (1 - d))
  | magic >= 0
  = Add (shr hi (s + twos)) neg
  | otherwise
  = Add (shr (Add Arg hi) (s + twos)) neg
  where
    width = finiteBitSize d
    twos = countTrailingZeros d
    oddPart = d `shiftR` twos
    s = findShift (fromInteger ((1 `shiftL` width) `rem` toInteger oddPart)) 0
    magic = fromInteger ((1 `shiftL` (width + s)) `quot` toInteger oddPart + 1)
    hi = MulHi Arg magic
    -- 1 for a negative argument, 0 otherwise.
    neg = CmpLT Arg 0
    -- r tracks 2^(width + i) `mod` oddPart as i grows.
    findShift r i
      | oddPart - r < 1 `shiftL` (i + 1) = i
      | otherwise = findShift ((r `shiftL` 1) `rem` oddPart) (i + 1)
-- | Smart constructor for 'Shr'. A shift by zero bits is dropped, and the
-- subtree is returned unchanged.
shr :: AST a -> Int -> AST a
shr x n
  | n == 0 = x
  | otherwise = Shr x n