{-# LANGUAGE ScopedTypeVariables #-}
-- | Arithmetic functions, each described by its value on prime powers and a
-- monoid that combines those values.
module Math.NumberTheory.ArithmeticFunctions.Standard
  ( -- * Multiplicative functions
    multiplicative
  , divisors, divisorsA
  , divisorsList, divisorsListA
  , divisorsSmall, divisorsSmallA
  , divisorCount, tau, tauA
  , sigma, sigmaA
  , totient, totientA
  , jordan, jordanA
  , ramanujan, ramanujanA
  , moebius, moebiusA, Moebius(..), runMoebius
  , liouville, liouvilleA
    -- * Additive functions
  , additive
  , smallOmega, smallOmegaA
  , bigOmega, bigOmegaA
    -- * Misc
  , carmichael, carmichaelA
  , expMangoldt, expMangoldtA
  , isNFree, isNFreeA, nFrees, nFreesBlock
  ) where
import Data.Coerce
import Data.IntSet (IntSet)
import qualified Data.IntSet as IS
import Data.Set (Set)
import qualified Data.Set as S
import Data.Semigroup
import Math.NumberTheory.ArithmeticFunctions.Class
import Math.NumberTheory.ArithmeticFunctions.Moebius
import Math.NumberTheory.ArithmeticFunctions.NFreedom (nFrees, nFreesBlock)
import Math.NumberTheory.Primes
import Math.NumberTheory.Utils.FromIntegral
import Numeric.Natural
-- | Build an arithmetic function from its values on prime powers.
-- Each value of @f@ is wrapped in 'Product' and the final result is
-- unwrapped with 'getProduct', so the per-prime-power values are multiplied.
multiplicative :: Num a => (Prime n -> Word -> a) -> ArithmeticFunction n a
multiplicative f = ArithmeticFunction (\p k -> Product (f p k)) getProduct
-- | Evaluate 'divisorsA' at the given argument.
divisors :: (UniqueFactorisation n, Ord n) => n -> Set n
divisors n = runFunction divisorsA n
{-# SPECIALIZE divisors :: Natural -> Set Natural #-}
{-# SPECIALIZE divisors :: Integer -> Set Integer #-}
-- | The set of divisors of an argument. Every prime power @p^k@ contributes
-- the set built by 'divisorsHelper'; these sets are combined through the
-- 'SetProduct' semigroup, and @1@ is inserted into the final result.
divisorsA :: (UniqueFactorisation n, Ord n) => ArithmeticFunction n (Set n)
divisorsA = ArithmeticFunction
  (\p k -> SetProduct (divisorsHelper (unPrime p) k))
  (\(SetProduct ds) -> S.insert 1 ds)
-- | The set @{p, p^2, .., p^a}@; empty when the exponent @a@ is 0.
-- The list handed to 'S.fromDistinctAscList' is in order of increasing
-- exponent, so it is strictly ascending only when @p > 1@.
divisorsHelper :: Num n => n -> Word -> Set n
divisorsHelper p a = case a of
  0 -> S.empty
  1 -> S.singleton p
  _ -> S.fromDistinctAscList (p : p * p : [p ^ e | e <- [3 .. wordToInt a]])
{-# INLINE divisorsHelper #-}
-- | Evaluate 'divisorsListA' at the given argument.
divisorsList :: UniqueFactorisation n => n -> [n]
divisorsList n = runFunction divisorsListA n
-- | The divisors of an argument as a list (not sorted in general). Every
-- prime power contributes the list built by 'divisorsListHelper'; the lists
-- are combined through the 'ListProduct' semigroup and @1@ is prepended.
divisorsListA :: UniqueFactorisation n => ArithmeticFunction n [n]
divisorsListA = ArithmeticFunction
  (\p k -> ListProduct (divisorsListHelper (unPrime p) k))
  (\(ListProduct ds) -> 1 : ds)
-- | The list @[p, p^2, .., p^a]@; empty when the exponent @a@ is 0.
divisorsListHelper :: Num n => n -> Word -> [n]
divisorsListHelper p a = case a of
  0 -> []
  1 -> [p]
  _ -> p : p * p : [p ^ e | e <- [3 .. wordToInt a]]
{-# INLINE divisorsListHelper #-}
-- | Evaluate 'divisorsSmallA' at the given argument.
divisorsSmall :: Int -> IntSet
divisorsSmall n = runFunction divisorsSmallA n
-- | Counterpart of 'divisorsA' restricted to 'Int' arguments, collecting the
-- divisors in an 'IntSet' via the 'IntSetProduct' semigroup.
divisorsSmallA :: ArithmeticFunction Int IntSet
divisorsSmallA = ArithmeticFunction
  (\p k -> IntSetProduct (divisorsHelperSmall (unPrime p) k))
  (\(IntSetProduct ds) -> IS.insert 1 ds)
-- | The set @{p, p^2, .., p^a}@ as an 'IntSet'; empty when the exponent
-- @a@ is 0. The list handed to 'IS.fromDistinctAscList' is in order of
-- increasing exponent.
divisorsHelperSmall :: Int -> Word -> IntSet
divisorsHelperSmall p a = case a of
  0 -> IS.empty
  1 -> IS.singleton p
  _ -> IS.fromDistinctAscList (p : p * p : [p ^ e | e <- [3 .. wordToInt a]])
{-# INLINE divisorsHelperSmall #-}
-- | Synonym for 'tau'.
divisorCount :: (UniqueFactorisation n, Num a) => n -> a
divisorCount n = tau n
-- | Evaluate 'tauA' at the given argument.
tau :: (UniqueFactorisation n, Num a) => n -> a
tau n = runFunction tauA n
-- | The number of divisors: a prime power with exponent @k@ contributes the
-- factor @k + 1@, regardless of the prime.
tauA :: Num a => ArithmeticFunction n a
tauA = multiplicative (\_ k -> fromIntegral (succ k))
-- | Evaluate 'sigmaA' for the given power at the given argument.
sigma :: (UniqueFactorisation n, Integral n) => Word -> n -> n
sigma a = runFunction (sigmaA a)
-- | The sum of the @a@-th powers of the divisors of an argument.
-- Power 0 is delegated to 'tauA'; power 1 feeds the prime itself to
-- 'sigmaHelper'; any other power feeds @p^a@ instead.
sigmaA :: (UniqueFactorisation n, Integral n) => Word -> ArithmeticFunction n n
sigmaA a = case a of
  0 -> tauA
  1 -> multiplicative (\p -> sigmaHelper (unPrime p))
  _ -> multiplicative (\p -> sigmaHelper (unPrime p ^ wordToInt a))
-- | The geometric sum @1 + pa + pa^2 + .. + pa^k@. Exponents 1 and 2 are
-- written out; otherwise the closed form @(pa^(k+1) - 1) `quot` (pa - 1)@
-- is used.
sigmaHelper :: Integral n => n -> Word -> n
sigmaHelper pa k
  | k == 1    = pa + 1
  | k == 2    = pa * pa + pa + 1
  | otherwise = (pa ^ wordToInt (k + 1) - 1) `quot` (pa - 1)
{-# INLINE sigmaHelper #-}
-- | Evaluate 'totientA' at the given argument.
totient :: UniqueFactorisation n => n -> n
totient n = runFunction totientA n
-- | The totient function: a prime power @p^k@ contributes
-- @(p - 1) * p^(k - 1)@, computed by 'jordanHelper'.
totientA :: UniqueFactorisation n => ArithmeticFunction n n
totientA = multiplicative (\p k -> jordanHelper (unPrime p) k)
-- | Evaluate 'jordanA' for the given power at the given argument.
jordan :: UniqueFactorisation n => Word -> n -> n
jordan a = runFunction (jordanA a)
-- | The Jordan function of order @a@. Order 0 maps every prime power to 0;
-- order 1 is 'totientA'; any other order applies 'jordanHelper' to @p^a@.
jordanA :: UniqueFactorisation n => Word -> ArithmeticFunction n n
jordanA a = case a of
  0 -> multiplicative (const (const 0))
  1 -> totientA
  _ -> multiplicative (\p -> jordanHelper (unPrime p ^ wordToInt a))
-- | Computes @(pa - 1) * pa^(k - 1)@, with exponents 1 and 2 written out.
-- The exponent @k@ is a 'Word', so @k - 1@ is only meaningful for @k >= 1@.
jordanHelper :: Num n => n -> Word -> n
jordanHelper pa k = case k of
  1 -> pa - 1
  2 -> (pa - 1) * pa
  _ -> (pa - 1) * pa ^ wordToInt (k - 1)
{-# INLINE jordanHelper #-}
-- | Evaluate 'ramanujanA' at the given argument.
ramanujan :: Integer -> Integer
ramanujan n = runFunction ramanujanA n
-- | The multiplicative function whose value on a prime power is given by
-- 'ramanujanHelper'.
ramanujanA :: ArithmeticFunction Integer Integer
ramanujanA = multiplicative (\p k -> ramanujanHelper (unPrime p) k)
-- | Value of the function behind 'ramanujanA' on the prime power @p^k@
-- (first argument @p@, second argument @k@).
ramanujanHelper :: Integer -> Word -> Integer
-- Exponent 0 yields 1, the neutral element of multiplication.
ramanujanHelper _ 0 = 1
-- The value at 2 is hard-coded. NOTE(review): presumably because the general
-- formula below doubles the half-range sum, which for p = 2 would count the
-- middle term k = 1 twice -- confirm.
ramanujanHelper 2 1 = -24
-- Exponent 1: closed formula in terms of divisor sums. Here
-- sigmaHelper (p ^ 11) 1 is p^11 + 1 and sigmaHelper (p ^ 5) 1 is p^5 + 1.
-- The summand sigma 5 k * sigma 5 (p - k) is symmetric under k -> p - k, so
-- the sum over k = 1 .. p `quot` 2 is doubled to cover k = 1 .. p - 1 for odd p.
ramanujanHelper p 1 = (65 * sigmaHelper (p ^ (11 :: Int)) 1 + 691 * sigmaHelper (p ^ (5 :: Int)) 1 - 691 * 252 * 2 * sum [sigma 5 k * sigma 5 (p-k) | k <- [1..(p `quot` 2)]]) `quot` 756
-- Higher exponents: sum over j = 0 .. k `quot` 2 of
-- (-p^11)^j * tp^(k - 2 * j) * binomials !! j, where tp is the value at p^1.
ramanujanHelper p k = sum $ zipWith3 (\a b c -> a * b * c) paPowers tpPowers binomials
  where pa = p ^ (11 :: Int)
        -- The value at the prime itself.
        tp = ramanujanHelper p 1
        -- 1, -pa, pa^2, -pa^3, ... (infinite; truncated by zipWith3).
        paPowers = iterate (* (-pa)) 1
        -- Entry j is the binomial coefficient C(k' - j, j); each step
        -- multiplies by (k' - 2j)(k' - 2j - 1) and divides by (k' - j)(j + 1).
        binomials = scanl (\acc j -> acc * (k' - 2 * j) * (k' - 2 * j - 1) `quot` (k' - j) `quot` (j + 1)) 1 [0 .. k' `quot` 2 - 1]
        -- The exponent as an Integer, so it can be mixed with the accumulator.
        k' = fromIntegral k
        -- tp^k, tp^(k - 2), ... down to tp^1 (odd k) or tp^0 (even k),
        -- one entry per binomial coefficient.
        tpPowers = reverse $ take (length binomials) $ iterate (* tp^(2::Int)) (if even k then 1 else tp)
{-# INLINE ramanujanHelper #-}
-- | Evaluate 'moebiusA' at the given argument.
moebius :: UniqueFactorisation n => n -> Moebius
moebius n = runFunction moebiusA n
-- | Moebius function. The prime is ignored; only the exponent matters:
-- exponent 1 gives 'MoebiusN', exponent 0 gives 'MoebiusP' and any larger
-- exponent gives 'MoebiusZ'. The per-prime values are combined by the
-- 'Moebius' monoid and returned unchanged.
moebiusA :: ArithmeticFunction n Moebius
moebiusA = ArithmeticFunction
  (\_ k -> case k of
      1 -> MoebiusN
      0 -> MoebiusP
      _ -> MoebiusZ)
  id
-- | Evaluate 'liouvilleA' at the given argument.
liouville :: (UniqueFactorisation n, Num a) => n -> a
liouville n = runFunction liouvilleA n
-- | Liouville function: each prime power contributes the parity of its
-- exponent, parities are combined with 'Xor', and 'runXor' turns the result
-- into @1@ (even total) or @-1@ (odd total).
liouvilleA :: Num a => ArithmeticFunction n a
liouvilleA = ArithmeticFunction (\_ k -> Xor (odd k)) runXor
-- | Evaluate 'carmichaelA' at the given argument.
carmichael :: (UniqueFactorisation n, Integral n) => n -> n
carmichael n = runFunction carmichaelA n
{-# SPECIALIZE carmichael :: Int     -> Int #-}
{-# SPECIALIZE carmichael :: Word    -> Word #-}
{-# SPECIALIZE carmichael :: Integer -> Integer #-}
{-# SPECIALIZE carmichael :: Natural -> Natural #-}
-- | Carmichael function: the per-prime-power values are combined by least
-- common multiple (the 'LCM' semigroup).
carmichaelA :: (UniqueFactorisation n, Integral n) => ArithmeticFunction n n
carmichaelA = ArithmeticFunction (\p k -> LCM (lambda (unPrime p) k)) getLCM
  where
    -- Value on the prime power p^k. Powers of two are special: 2 gives 1,
    -- 4 gives 2 and 2^k gives 2^(k - 2) otherwise. Every other prime gives
    -- (p - 1) * p^(k - 1).
    lambda p k
      | p == 2 = case k of
          1 -> 1
          2 -> 2
          _ -> 2 ^ wordToInt (k - 2)
      | otherwise = case k of
          1 -> p - 1
          2 -> (p - 1) * p
          _ -> (p - 1) * p ^ wordToInt (k - 1)
-- | Build an arithmetic function from its values on prime powers.
-- Each value of @f@ is wrapped in 'Sum' and the final result is unwrapped
-- with 'getSum', so the per-prime-power values are added.
additive :: Num a => (Prime n -> Word -> a) -> ArithmeticFunction n a
additive f = ArithmeticFunction (\p k -> Sum (f p k)) getSum
-- | Evaluate 'smallOmegaA' at the given argument.
smallOmega :: (UniqueFactorisation n, Num a) => n -> a
smallOmega n = runFunction smallOmegaA n
-- | Counts prime powers: every prime power contributes 1 to the sum,
-- whatever the prime and the exponent.
smallOmegaA :: Num a => ArithmeticFunction n a
smallOmegaA = additive (const (const 1))
-- | Evaluate 'bigOmegaA' at the given argument.
bigOmega :: UniqueFactorisation n => n -> Word
bigOmega n = runFunction bigOmegaA n
-- | Sums the exponents: every prime power contributes its exponent.
bigOmegaA :: ArithmeticFunction n Word
bigOmegaA = additive (\_ k -> k)
-- | Evaluate 'expMangoldtA' at the given argument.
expMangoldt :: UniqueFactorisation n => n -> n
expMangoldt n = runFunction expMangoldtA n
-- | Every prime power contributes its prime (the exponent is ignored),
-- wrapped in 'MangoldtOne'. After combining in the 'Mangoldt' monoid,
-- 'runMangoldt' returns that prime when exactly one prime occurred and
-- 1 otherwise.
expMangoldtA :: UniqueFactorisation n => ArithmeticFunction n n
expMangoldtA = ArithmeticFunction (\p _ -> MangoldtOne (unPrime p)) runMangoldt
-- | Accumulator that tracks how many values have been combined:
-- none ('MangoldtZero'), exactly one ('MangoldtOne') or more than one
-- ('MangoldtMany').
data Mangoldt a
  = MangoldtZero
  | MangoldtOne a
  | MangoldtMany

-- | Extract the stored value when exactly one was seen; otherwise return 1.
runMangoldt :: Num a => Mangoldt a -> a
runMangoldt (MangoldtOne q) = q
runMangoldt _               = 1

-- | 'MangoldtZero' is neutral on either side; combining any two other
-- values gives 'MangoldtMany'.
instance Semigroup (Mangoldt a) where
  l <> r = case (l, r) of
    (MangoldtZero, _) -> r
    (_, MangoldtZero) -> l
    _                 -> MangoldtMany

instance Monoid (Mangoldt a) where
  mempty  = MangoldtZero
  mappend = (<>)
-- | Evaluate 'isNFreeA' for the given bound at the given argument.
isNFree :: UniqueFactorisation n => Word -> n -> Bool
isNFree = runFunction . isNFreeA
-- | Checks that every prime power in the factorisation has an exponent
-- strictly below @n@; the individual checks are combined with 'All'.
isNFreeA :: Word -> ArithmeticFunction n Bool
isNFreeA n = ArithmeticFunction (const (All . (< n))) getAll
-- | Wrapper whose semigroup operation is the least common multiple,
-- with 1 as the neutral element.
newtype LCM a = LCM { getLCM :: a }

instance Integral a => Semigroup (LCM a) where
  LCM x <> LCM y = LCM (lcm x y)

instance Integral a => Monoid (LCM a) where
  mempty  = LCM 1
  mappend = (<>)
-- | Boolean wrapper whose semigroup operation is exclusive or,
-- with 'False' as the neutral element.
newtype Xor = Xor { _getXor :: Bool }

-- | Interpret the flag as a sign: 'False' gives @1@, 'True' gives @-1@.
runXor :: Num a => Xor -> a
runXor (Xor flipped)
  | flipped   = -1
  | otherwise = 1

instance Semigroup Xor where
  Xor a <> Xor b = Xor (a /= b)

instance Monoid Xor where
  mempty  = Xor False
  mappend = (<>)
-- | Set wrapper whose semigroup operation yields both operands together
-- with all pairwise products of their elements. The empty set is neutral.
newtype SetProduct a = SetProduct { getSetProduct :: Set a }

instance (Num a, Ord a) => Semigroup (SetProduct a) where
  SetProduct xs <> SetProduct ys = SetProduct (S.unions [xs, ys, cross])
    where
      -- Every element of ys multiplied by every element of xs.
      cross = foldMap (\x -> S.mapMonotonic (* x) ys) xs

instance (Num a, Ord a) => Monoid (SetProduct a) where
  mempty  = SetProduct mempty
  mappend = (<>)
-- | List wrapper whose semigroup operation yields the left operand, then
-- the right operand, then all pairwise products (grouped by the elements
-- of the left operand). The empty list is neutral.
newtype ListProduct a = ListProduct { getListProduct :: [a] }

instance Num a => Semigroup (ListProduct a) where
  ListProduct xs <> ListProduct ys =
    ListProduct (xs ++ ys ++ [y * x | x <- xs, y <- ys])

instance Num a => Monoid (ListProduct a) where
  mempty  = ListProduct mempty
  mappend = (<>)
-- | 'IntSet' wrapper whose semigroup operation yields both operands
-- together with all pairwise products of their elements. The empty set is
-- neutral.
newtype IntSetProduct = IntSetProduct { getIntSetProduct :: IntSet }

instance Semigroup IntSetProduct where
  IntSetProduct xs <> IntSetProduct ys =
    IntSetProduct (IS.unions (xs : ys : [IS.map (* x) ys | x <- IS.toAscList xs]))

instance Monoid IntSetProduct where
  mempty  = IntSetProduct mempty
  mappend = (<>)