-- Copyright (c) David Amos, 2008. All rights reserved.

{-# OPTIONS_GHC -fglasgow-exts #-}

-- |A module providing a type for (commutative) multivariate polynomials, with support for various term orders.
module Math.Algebra.Commutative.MPoly where

import qualified Data.Map as M
import Data.List as L
import Control.Arrow (first, second)
import Data.Ratio (denominator)

import Math.Algebra.Field.Base
import Math.Algebra.Commutative.Monomial


-- MULTIVARIATE POLYNOMIALS

-- |Type for multivariate polynomials, represented as a list of (monomial, coefficient) terms.
-- @ord@ is a phantom type defining how terms are ordered, @r@ is the type of the ring we are working over.
-- For example, a common choice will be @MPoly Grevlex Q@, meaning polynomials over Q with the grevlex term ordering.
-- The operations in this module keep the term list with \"larger\" monomials first (see 'cmpTerm').
newtype MPoly ord r = MP [(Monomial ord,r)] deriving (Eq)
-- The Ord instance is written out by hand below; the standalone-deriving equivalent would be:
-- deriving instance (Ord (Monomial ord), Ord r) => Ord (MPoly ord r)
-- standalone deriving supported from GHC 6.8

-- |Polynomials are ordered by comparing their underlying term lists.
instance (Ord (Monomial ord), Ord r) => Ord (MPoly ord r) where
    compare p q = compare (terms p) (terms q)
        where terms (MP xs) = xs

-- |Render a polynomial as a signed sum of terms; the zero polynomial is shown as \"0\".
instance (Show r, Num r) => Show (MPoly ord r) where
    show (MP []) = "0"
    show (MP ts) = dropLeadingPlus (concatMap renderTerm ts)
        where
            -- every term is rendered with an explicit sign; a '+' on the first term is removed
            dropLeadingPlus ('+':rest) = rest
            dropLeadingPlus str = str
            renderTerm (m, coeff) = render (show coeff)
                where
                    -- coefficients that show as 1 or -1 are left implicit
                    render "1" = '+' : show m
                    render "-1" = '-' : show m
                    render str@(sign:_) = withSign ++ monomialPart
                        where withSign = if sign == '-' then str else '+' : str
                    -- after an explicit coefficient the unit monomial is not printed
                    monomialPart = if m == 1 then "" else show m


-- |Ring operations on polynomials.
-- Term lists are kept with larger monomials first and without zero coefficients.
-- Note: 'abs' and 'signum' are not defined for polynomials.
instance (Ord (Monomial ord), Num r) => Num (MPoly ord r) where
    -- both term lists are already in descending order, so addition is a merge
    MP ts + MP us = MP (mergeTerms ts us)
    negate (MP ts) = MP $ map (second negate) ts
    -- form all pairwise products, sort them into descending order, then combine like terms
    MP ts * MP us = MP $ collect $ sortBy cmpTerm $ [(g*h,c*d) | (g,c) <- ts, (h,d) <- us]
    {-
    -- The following appears to be slightly slower, perhaps because sortBy is compiled
    MP (t@(g,c):ts) * MP (u@(h,d):us) =
        let MP vs = MP ts * MP us
        in MP $ mergeTerms ((g*h,c*d):vs) $ mergeTerms [(g*h,c*d) | (h,d) <- us] [(g*h,c*d) | (g,c) <- ts]
    _ * _ = MP []
    -}
    fromInteger 0 = MP []
    fromInteger n
        -- a non-zero integer can still map to zero in the coefficient ring
        -- (eg n = p over a field of characteristic p); it must not be stored as a term.
        -- Like mergeTerms, this relies on equality being available for r.
        | k == 0 = MP []
        | otherwise = MP [(fromInteger 1, k)]
        where k = fromInteger n

-- |Compare two terms by monomial only, with the result reversed:
-- in mpolys we put \"larger\" terms first, so sorting with this gives descending order.
cmpTerm (m, _) (n, _) = invert (compare m n)
    where
        invert LT = GT
        invert GT = LT
        invert EQ = EQ

-- |Merge two term lists, each already in descending monomial order, into one descending list.
-- Terms with equal monomials have their coefficients added; a sum of zero is dropped.
mergeTerms xs@((m, a) : xs') ys@((n, b) : ys') =
    case compare m n of
        GT -> (m, a) : mergeTerms xs' ys
        LT -> (n, b) : mergeTerms xs ys'
        EQ | total == 0 -> mergeTerms xs' ys'
           | otherwise -> (m, total) : mergeTerms xs' ys'
    where total = a + b
-- at least one list is exhausted, so whatever remains is already in order
mergeTerms xs ys = xs ++ ys

-- |Combine adjacent terms with equal monomials and drop terms whose coefficient is zero.
-- The input is expected to be sorted so that like terms are adjacent.
collect (t1@(g,c):t2@(h,d):ts)
    | g == h = collect $ (g,c+d):ts
    | c == 0  = collect $ t2:ts
    | otherwise = t1 : collect (t2:ts)
-- a zero coefficient can also end up on the last remaining term
-- (eg when the final two like terms cancel), so it has to be checked here too
collect [(_,c)] | c == 0 = []
collect ts = ts

-- Fractional instance so that we can enter fractional coefficients
-- Only lets us divide by a single term (a non-zero constant or monomial), not by any other polynomials
instance (Ord (Monomial ord), Fractional r) => Fractional (MPoly ord r) where
    recip (MP [(m,c)]) = MP [(recip m, recip c)]
    -- recip (MP [(m,c)]) | m == fromInteger 1 = MP [(m, recip c)]
    recip _ = error "MPoly.recip: only supported for (non-zero) constants or monomials"
    -- Needed so that fractional literals (eg 0.5) can be used as polynomials;
    -- without it they would hit an undefined class method at runtime.
    -- A rational that maps to zero in r gives the zero polynomial rather than a zero term.
    fromRational q
        | k == 0 = MP []
        | otherwise = MP [(fromInteger 1, k)]
        where k = fromRational q

-- |Create a variable with the supplied name, as a polynomial over Q with grevlex ordering.
-- By convention, variable names should usually be a single letter followed by none, one or two digits.
var :: String -> MPoly Grevlex Q
var name = MP [(mono, 1)]
    where mono = Monomial (M.singleton name 1) -- the variable raised to the first power

-- Ready-made single-letter variables; each one is 'var' applied to its own name.
a, b, c, d, s, t, u, v, w, x, y, z :: MPoly Grevlex Q
[a, b, c, d, s, t, u, v, w, x, y, z] =
    map var ["a", "b", "c", "d", "s", "t", "u", "v", "w", "x", "y", "z"]

-- Indexed variable: x_ i is the variable whose name is \"x\" followed by show i.
x_ i = var ('x' : show i)

-- The first few indexed variables, ready-made.
x0, x1, x2, x3 :: MPoly Grevlex Q
[x0, x1, x2, x3] = map x_ [0 :: Integer .. 3]


-- Convert a polynomial to another term ordering: every monomial is converted with convertM,
-- then the terms are re-sorted so that larger monomials (under the new ordering) come first.
-- convertMP :: Ord (Monomial ord') => MPoly ord k -> MPoly ord' k
convertMP (MP ts) = MP (sortBy cmpTerm (map convertTerm ts))
    where convertTerm term = (convertM (fst term), snd term)

-- |Re-express a polynomial using the lex term ordering
toLex :: MPoly ord k -> MPoly Lex k
toLex p = convertMP p

-- |Re-express a polynomial using the glex term ordering
toGlex :: MPoly ord k -> MPoly Glex k
toGlex p = convertMP p

-- |Re-express a polynomial using the grevlex term ordering
toGrevlex :: MPoly ord k -> MPoly Grevlex k
toGrevlex p = convertMP p

-- |Re-express a polynomial using the Elim term ordering
toElim :: MPoly ord k -> MPoly Elim k
toElim p = convertMP p


-- Create a named variable as a polynomial with the lex term ordering
varLex name = toLex (var name)

-- Create a named variable as a polynomial with the Elim term ordering
varElim name = toElim (var name)


-- DIVISION ALGORITHM

-- Leading term of a polynomial, as a (monomial, coefficient) pair: the first entry of the term list.
-- The zero polynomial has no terms, so this fails with an explicit error on it.
lt (MP (t:_)) = t
lt (MP []) = error "MPoly.lt: the zero polynomial has no leading term"

-- Leading monomial: the monomial part of the leading term
lm p = fst (lt p)

-- The true degree of the polynomial (the maximum of degM over all terms), not the degree of the leading term.
-- Required for sugar strategy when computing Groebner basis.
-- The zero polynomial is given degree -1.
deg p
    | p == 0 = -1
    | otherwise = maximum [degM m | (m, _) <- terms]
    where MP terms = p

-- Multiply two terms: monomial by monomial, coefficient by coefficient
mulT (mono1, coeff1) (mono2, coeff2) = (mono, coeff)
    where
        mono = mono1 * mono2
        coeff = coeff1 * coeff2

-- Divide one term by another: monomial by monomial, coefficient by coefficient.
-- No divisibility check is made here.
divT (mono1, coeff1) (mono2, coeff2) = (mono, coeff)
    where
        mono = mono1 / mono2
        coeff = coeff1 / coeff2

-- Whether the first term divides the second; decided by dividesM on the monomials, coefficients are ignored
dividesT (mono1, _) (mono2, _) = mono1 `dividesM` mono2

-- Whether the first term divides the second and the two monomials are different; coefficients are ignored
properlyDividesT t1@(mono1, _) t2@(mono2, _) = dividesT t1 t2 && mono1 /= mono2

-- Least common multiple of two terms: lcmM of the monomials, with coefficient 1 (the input coefficients are ignored)
lcmT (mono1, _) (mono2, _) = (lcmM mono1 mono2, 1)


-- Multiply a polynomial by a single term, given as a (monomial, coefficient) pair.
-- Each term is multiplied in place, which preserves term order.
infixl 7 .*
term .* MP ts = MP [mulT term u | u <- ts]


-- Multivariate division algorithm.
-- Given f, gs, find as, r such that f = sum (zipWith (*) as gs) + r, with r not divisible by any g.
-- The result is the pair (as, r); the quotients as are listed in the same order as gs.
quotRemMP f gs = reduce f (replicate (length gs) 0, 0)
    where
        -- stop once nothing is left to divide, otherwise look for a divisor of the current leading term
        reduce h acc@(qs, r)
            | h == 0 = acc
            | otherwise = tryDivisors h gs [] qs r
        -- walk the divisors and their quotients in step; tried holds the quotients already passed, in reverse
        tryDivisors h (g:rest) tried (q:qs) r
            | lt g `dividesT` lt h =
                let t = MP [lt h `divT` lt g]
                in reduce (h - t*g) (reverse tried ++ (q + t) : qs, r)
            | otherwise = tryDivisors h rest (q:tried) qs r
        -- no divisor's leading term divides the leading term of h: move that term to the remainder
        tryDivisors h [] tried [] r = reduce h' (reverse tried, r + lead)
            where (lead, h') = splitLead h
        splitLead (MP (t:ts)) = (MP [t], MP ts)

-- Remainder of f on division by the list of polynomials gs (the quotients are discarded)
infixl 7 %%
f %% gs = snd (quotRemMP f gs)

-- Division by a single mpoly, built on quotRemMP with a one-element divisor list.

-- quotient and remainder together
divModMP f g =
    let ([q], r) = quotRemMP f [g]
    in (q, r)

-- quotient only
divMP f g = fst (divModMP f g)

-- remainder only
modMP f g = snd (quotRemMP f [g])


-- OTHER STUFF

-- Injection of field elements into the polynomial ring:
-- zero becomes the empty polynomial, anything else a single term on the unit monomial
inject c
    | c == 0 = MP []
    | otherwise = MP [(1, c)]

-- Scale a polynomial so that its leading coefficient is 1.
-- The zero polynomial, and polynomials that are already monic, are returned unchanged.
toMonic p
    | p == 0 = 0
    | lc == 1 = p
    | otherwise = MP (map (second (/ lc)) ts)
    where MP ts@((_, lc) : _) = p -- lc is the coefficient of the leading term

-- Multiply out all denominators: every coefficient is scaled by the least common multiple
-- of the denominators of all the coefficients. The coefficients must be of type Q.
toZ (MP ts) = MP $ map (second (* scale)) ts
    where
        -- strict left fold, so the running lcm is evaluated as we go rather than built up as thunks
        scale = fromInteger $ L.foldl' lcm 1 [denominator q | (_, Q q) <- ts]

-- Substitute polynomials for variables in an MPoly.
-- vts is a list of (variable, replacement) pairs, the variables being given as polynomials
-- eg subst [(x,a),(y,a+b),(z,c^2)] (x*y+z) -> a*(a+b)+c^2
subst vts (MP us) = sum (map substTerm us)
    where
        substTerm (m, c) = inject c * substMonomial m
        -- replace each variable of the monomial, raised to its exponent, and multiply the results
        substMonomial (Monomial powers) = product [replacement v ^ e | (v, e) <- M.toList powers]
        -- a variable without an entry in vts is kept as itself
        replacement v = maybe self id (L.lookup self vts)
            where self = MP [(Monomial $ M.singleton v 1, 1)]

-- The support of a polynomial: the union of supportM over the monomials of all its terms
-- (presumably the variables that occur - confirm against supportM), sorted in descending
-- order and each returned as a single-term polynomial with coefficient 1.
support (MP ts) = [MP [(m,1)] | m <- reverse $ L.sort allSupports]
    where
        -- strict left fold, so the running union is evaluated as we go rather than built up as thunks
        allSupports = L.foldl' L.union [] [supportM m | (m,_) <- ts]