{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Math.NumberTheory.MoebiusInversion
( generalInversion
, totientSum
) where
import Control.Monad
import Control.Monad.ST
import Data.Proxy
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
import Math.NumberTheory.Powers.Squares
import Math.NumberTheory.Utils.FromIntegral
totientSum
:: (Integral t, G.Vector v t)
=> Proxy v
-> Word
-> t
totientSum _ 0 = 0
totientSum proxy n = generalInversion proxy (triangle . fromIntegral) n
where
triangle k = (k * (k + 1)) `quot` 2
generalInversion
:: (Num t, G.Vector v t)
=> Proxy v
-> (Word -> t)
-> Word
-> t
generalInversion proxy fun n = case n of
0 ->error "Möbius inversion only defined on positive domain"
1 -> fun 1
2 -> fun 2 - fun 1
3 -> fun 3 - 2*fun 1
_ -> runST (fastInvertST proxy (fun . intToWord) (wordToInt n))
fastInvertST
:: forall s t v.
(Num t, G.Vector v t)
=> Proxy v
-> (Int -> t)
-> Int
-> ST s t
fastInvertST _ fun n = do
let !k0 = integerSquareRoot (n `quot` 2)
!mk0 = n `quot` (2*k0+1)
kmax a m = (a `quot` m - 1) `quot` 2
small <- MG.unsafeNew (mk0 + 1) :: ST s (G.Mutable v s t)
MG.unsafeWrite small 0 0
MG.unsafeWrite small 1 $! (fun 1)
when (mk0 >= 2) $
MG.unsafeWrite small 2 $! (fun 2 - fun 1)
let calcit :: Int -> Int -> Int -> ST s (Int, Int)
calcit switch change i
| mk0 < i = return (switch,change)
| i == change = calcit (switch+1) (change + 4*switch+6) i
| otherwise = do
let mloop !acc k !m
| k < switch = kloop acc k
| otherwise = do
val <- MG.unsafeRead small m
let nxtk = kmax i (m+1)
mloop (acc - fromIntegral (k-nxtk)*val) nxtk (m+1)
kloop !acc k
| k == 0 = do
MG.unsafeWrite small i $! acc
calcit switch change (i+1)
| otherwise = do
val <- MG.unsafeRead small (i `quot` (2*k+1))
kloop (acc-val) (k-1)
mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1
(sw, ch) <- calcit 1 8 3
large <- MG.unsafeNew k0 :: ST s (G.Mutable v s t)
let calcbig :: Int -> Int -> Int -> ST s (G.Mutable v s t)
calcbig switch change j
| j == 0 = return large
| (2*j-1)*change <= n = calcbig (switch+1) (change + 4*switch+6) j
| otherwise = do
let i = n `quot` (2*j-1)
mloop !acc k m
| k < switch = kloop acc k
| otherwise = do
val <- MG.unsafeRead small m
let nxtk = kmax i (m+1)
mloop (acc - fromIntegral (k-nxtk)*val) nxtk (m+1)
kloop !acc k
| k == 0 = do
MG.unsafeWrite large (j-1) $! acc
calcbig switch change (j-1)
| otherwise = do
let m = i `quot` (2*k+1)
val <- if m <= mk0
then MG.unsafeRead small m
else MG.unsafeRead large (k*(2*j-1)+j-1)
kloop (acc-val) (k-1)
mloop (fun i - fun (i `quot` 2)) ((i-1) `quot` 2) 1
mvec <- calcbig sw ch k0
MG.unsafeRead mvec 0