-- |
-- Module:      Math.NumberTheory.ArithmeticFunctions.Class
-- Copyright:   (c) 2016 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Generic type for arithmetic functions over arbitrary unique
-- factorisation domains.
--

{-# LANGUAGE CPP                 #-}
{-# LANGUAGE GADTs               #-}

module Math.NumberTheory.ArithmeticFunctions.Class
  ( ArithmeticFunction(..)
  , runFunction
  , runFunctionOnFactors
  ) where

import Control.Applicative
#if __GLASGOW_HASKELL__ < 803
import Data.Semigroup
#endif

import Math.NumberTheory.Primes

-- | A typical arithmetic function operates on the canonical factorisation of
-- a number into prime powers and consists of two rules. The first rule
-- determines the values of the function on prime powers. The second rule
-- determines how to combine these values into the final result.
--
-- In the following definition the first argument is the function on prime
-- powers (it is given a prime and a 'Word' exponent), the 'Monoid' instance
-- of its result type determines the rule of combination (typically
-- 'Data.Semigroup.Product' or 'Data.Semigroup.Sum'), and the second argument
-- is convenient for unwrapping (typically 'Data.Semigroup.getProduct' or
-- 'Data.Semigroup.getSum').
--
-- The intermediate monoid @m@ does not appear in the result type, so it is
-- hidden from users of an 'ArithmeticFunction'.
data ArithmeticFunction n a where
  ArithmeticFunction
    :: Monoid m
    => (Prime n -> Word -> m)
    -> (m -> a)
    -> ArithmeticFunction n a

-- | Convert to a function: the argument is factorised with 'factorise' and
-- the arithmetic function is evaluated on the resulting factorisation via
-- 'runFunctionOnFactors'. The value on 0 is undefined.
runFunction :: UniqueFactorisation n => ArithmeticFunction n a -> n -> a
runFunction f = runFunctionOnFactors f . factorise

-- | Convert to a function on prime factorisation.
--
-- Each @(prime, exponent)@ pair is mapped into the intermediate monoid, the
-- results are combined with 'mconcat', and the combined value is unwrapped by
-- the second component of the 'ArithmeticFunction'.
runFunctionOnFactors :: ArithmeticFunction n a -> [(Prime n, Word)] -> a
runFunctionOnFactors (ArithmeticFunction f g)
  = g
  . mconcat
  . map (uncurry f)

-- | Mapping post-composes with the unwrapping function; the rule on prime
-- powers and the intermediate monoid are left untouched.
instance Functor (ArithmeticFunction n) where
  fmap f (ArithmeticFunction g h) = ArithmeticFunction g (f . h)

-- | 'pure' ignores the factorisation entirely; '<*>' pairs the intermediate
-- monoids of both operands, so a single pass over the factorisation feeds
-- both functions.
instance Applicative (ArithmeticFunction n) where
  -- The unit monoid carries no information; the result is always @x@.
  pure x
    = ArithmeticFunction (\_ _ -> ()) (const x)
  -- Accumulate both functions side by side in the product monoid, then
  -- apply the unwrapped function to the unwrapped argument.
  (ArithmeticFunction f1 g1) <*> (ArithmeticFunction f2 g2)
    = ArithmeticFunction (\p k -> (f1 p k, f2 p k)) (\(a1, a2) -> g1 a1 (g2 a2))

-- | Results are combined pointwise with the 'Semigroup' operation of the
-- result type.
instance Semigroup a => Semigroup (ArithmeticFunction n a) where
  (<>) = liftA2 (<>)

-- | The identity is the constant function returning 'mempty'.
instance Monoid a => Monoid (ArithmeticFunction n a) where
  mempty  = pure mempty
#if __GLASGOW_HASKELL__ < 803
  -- Lift 'mappend' of the result type directly on compilers matched by
  -- this CPP guard.
  mappend = liftA2 mappend
#else
  -- Otherwise reuse the 'Semigroup' operation so both instances agree.
  mappend = (<>)
#endif

-- | Factorisation is expensive, so it is better to avoid doing it twice.
-- Write @'runFunction' (f + g) n@ instead of
-- @'runFunction' f n + 'runFunction' g n@.
--
-- All operations act pointwise on the results: literals become constant
-- functions, unary operations are mapped over the result, and binary
-- operations combine the results of both operands.
instance Num a => Num (ArithmeticFunction n a) where
  fromInteger = pure . fromInteger
  negate = fmap negate
  signum = fmap signum
  abs    = fmap abs
  (+) = liftA2 (+)
  (-) = liftA2 (-)
  (*) = liftA2 (*)

-- | Pointwise 'Fractional' operations on the results: rational literals
-- become constant functions, 'recip' is mapped over the result and '/'
-- combines the results of both operands.
instance Fractional a => Fractional (ArithmeticFunction n a) where
  fromRational = pure . fromRational
  recip = fmap recip
  (/) = liftA2 (/)

-- | Pointwise 'Floating' operations on the results: 'pi' is a constant
-- function and every unary operation is mapped over the result.
instance Floating a => Floating (ArithmeticFunction n a) where
  pi    = pure pi
  exp   = fmap exp
  log   = fmap log
  sin   = fmap sin
  cos   = fmap cos
  asin  = fmap asin
  acos  = fmap acos
  atan  = fmap atan
  sinh  = fmap sinh
  cosh  = fmap cosh
  asinh = fmap asinh
  acosh = fmap acosh
  atanh = fmap atanh