module Math.Algebra.NonCommutative.NCPoly where
import Data.List as L
import Math.Algebra.Field.Base
-- | A non-commutative monomial: a word in the variables, kept in order, so
-- @M [a,b]@ and @M [b,a]@ are different monomials.
newtype Monomial v = M [v] deriving (Eq)

-- | Graded lexicographic order: shorter words come first, and words of equal
-- length are compared lexicographically.
instance Ord v => Ord (Monomial v) where
    compare (M xs) (M ys) = compare (length xs,xs) (length ys,ys)

-- | Runs of a repeated variable are shown as powers (e.g. @x^2y@) and the empty
-- word is shown as @1@. Double quotes are stripped from each variable's 'show',
-- so 'String' variables print without quotes.
instance (Eq v, Show v) => Show (Monomial v) where
    show (M xs) | null xs = "1"
                | otherwise = concatMap showPower (L.group xs)
        where showPower [v] = showVar v
              showPower vs@(v:_) = showVar v ++ "^" ++ show (length vs)
              showVar v = filter (/= '"') (show v)

-- | Monomials multiply by concatenation. Only '*' and the literal @1@ (the
-- empty word) are supported; any other integer literal is an error.
instance (Eq v, Show v) => Num (Monomial v) where
    M xs * M ys = M (xs ++ ys)
    fromInteger 1 = M []
    fromInteger n = error ("Monomial.fromInteger: only 1 is a monomial, got " ++ show n)
-- | Non-commutative monomial division. @divM a b@ finds the leftmost
-- occurrence of the word @b@ inside the word @a@. If found, it returns the
-- left and right cofactors @(l, r)@ with @a == l * b * r@; otherwise 'Nothing'.
-- The empty monomial @1@ divides every monomial, including @1@ itself.
divM :: Eq v => Monomial v -> Monomial v -> Maybe (Monomial v, Monomial v)
divM (M a) (M b) = go [] a
    where
    -- ls is the (reversed) prefix of a skipped so far; rs is what remains.
    -- The empty remainder is tested too, so that divM 1 1 succeeds.
    go ls rs = case L.stripPrefix b rs of
        Just rest -> Just (M (reverse ls), M rest)
        Nothing -> case rs of
            r:rs' -> go (r:ls) rs'
            [] -> Nothing
-- | A non-commutative polynomial with coefficients in @r@ and variables @v@,
-- stored as a list of (monomial, coefficient) terms. The derived 'Ord'
-- compares two polynomials by comparing their term lists.
newtype NPoly r v = NP [(Monomial v,r)] deriving (Eq, Ord)
-- | Renders the terms in stored order. A coefficient of 1 or -1 is written as
-- just its sign in front of the monomial. A coefficient whose text contains an
-- inner '+' or '-' is parenthesised. The zero polynomial is shown as @0@.
instance (Show r, Eq v, Show v) => Show (NPoly r v) where
    show (NP []) = "0"
    show (NP ts) = stripPlus (concatMap renderTerm ts)
        where
        -- the first term should not carry an explicit '+'
        stripPlus ('+':rest) = rest
        stripPlus s = s
        renderTerm (m,a) = case show a of
            "1" -> '+' : show m
            "-1" -> '-' : show m
            s -> renderCoeff s ++ (if m == 1 then "" else show m)
        -- signs inside brackets are left alone
        renderCoeff s@(c:cs)
            | any (`elem` "+-") cs = "+(" ++ s ++ ")"
            | c == '-' = s
            | otherwise = '+' : s
-- | Ring operations on non-commutative polynomials. Term lists are kept in
-- decreasing monomial order (see 'cmpTerm'). Addition merges the two lists;
-- multiplication sorts all pairwise products and then combines like terms.
-- 'abs' and 'signum' are not defined.
instance (Eq r, Num r, Ord v, Show v) => Num (NPoly r v) where
    NP ts + NP us = NP (mergeTerms ts us)
    -- Flip the sign of every coefficient. Subtraction uses the class default
    -- @a - b = a + negate b@, so it depends on this.
    negate (NP ts) = NP $ map (\(m,c) -> (m,-c)) ts
    NP ts * NP us = NP $ collect $ L.sortBy cmpTerm $ [(g*h,c*d) | (g,c) <- ts, (h,d) <- us]
    fromInteger 0 = NP []
    fromInteger n = NP [(fromInteger 1, fromInteger n)]
-- | Compare two terms by their monomials in reverse, ignoring coefficients, so
-- that sorting with 'cmpTerm' puts the largest monomial first.
cmpTerm (a,_) (b,_) = compare b a
-- | Merge two term lists, each sorted in 'cmpTerm' order (largest monomial
-- first), into one sorted list. Coefficients of equal monomials are added,
-- and a sum of zero is dropped.
mergeTerms xs@(t@(g,c):ts) ys@(u@(_,d):us) =
    case cmpTerm t u of
        LT -> t : mergeTerms ts ys
        GT -> u : mergeTerms xs us
        EQ | s == 0    -> mergeTerms ts us
           | otherwise -> (g,s) : mergeTerms ts us
    where s = c + d
-- Once either list is exhausted, the rest of the other is already in order.
mergeTerms xs [] = xs
mergeTerms [] ys = ys
-- | Combine adjacent terms with equal monomials by summing their coefficients,
-- and drop terms whose coefficient is zero. Equal monomials must already be
-- adjacent, e.g. after sorting with 'cmpTerm'.
collect (t1@(g,c):t2@(h,d):ts)
    | g == h = collect $ (g,c+d):ts
    | c == 0  = collect $ t2:ts
    | otherwise = t1 : collect (t2:ts)
-- A final term with a zero coefficient must be dropped as well; otherwise it
-- would survive into the polynomial's term list.
collect [(_,0)] = []
collect ts = ts
-- | Division is only supported by non-zero constants: 'recip' inverts a
-- one-term constant polynomial, and '/' uses the default @a * recip b@.
-- 'fromRational' embeds a rational literal as a constant polynomial (0 becomes
-- the zero polynomial), so fractional literals such as @0.5@ can be used.
instance (Eq k, Fractional k, Ord v, Show v) => Fractional (NPoly k v) where
    recip (NP [(1,c)]) = NP [(1, recip c)]
    recip _ = error "NPoly.recip: only supported for (non-zero) constants"
    fromRational q
        | c == 0 = NP []
        | otherwise = NP [(1, c)]
        where c = fromRational q
-- | Three sample variable names, ordered X < Y < Z.
data Var = X | Y | Z deriving (Eq, Ord)

-- | Variables print as lowercase letters.
instance Show Var where
    show v = case v of
        X -> "x"
        Y -> "y"
        Z -> "z"
-- | The polynomial made of the single variable @v@ with coefficient 1.
var :: (Num k) => v -> NPoly k v
var v = NP [(monomial, 1)]
    where monomial = M [v]
-- Ready-made variables over the rationals Q.
x, y, z :: NPoly Q Var
x = var X
y = var Y
z = var Z
-- | Leading (first stored) monomial of a non-zero polynomial.
lm :: NPoly r v -> Monomial v
lm (NP ((m,_):_)) = m
lm (NP []) = error "lm: the zero polynomial has no leading monomial"

-- | Leading (first stored) coefficient of a non-zero polynomial.
lc :: NPoly r v -> r
lc (NP ((_,c):_)) = c
lc (NP []) = error "lc: the zero polynomial has no leading coefficient"

-- | Leading (first stored) term of a non-zero polynomial, as a one-term polynomial.
lt :: NPoly r v -> NPoly r v
lt (NP (t:_)) = NP [t]
lt (NP []) = error "lt: the zero polynomial has no leading term"
-- | Divide @f@ by the divisors @gs@. Returns @(lrs, rem)@, where @lrs@ holds
-- one accumulated @(left, right)@ factor pair per divisor, in the same order
-- as @gs@, and @rem@ collects the terms of @f@ that no divisor could reduce.
-- Raises an error if any divisor is zero. Uses '/' on coefficients.
--
-- NOTE(review): each step adds its factors into the pair as @(l+l'', r+r'')@.
-- Since @(l+l'')*g*(r+r'')@ generally differs from @l*g*r + l''*g*r''@, the
-- pairs may not reconstruct @f - rem@. Confirm the intended meaning of @lrs@.
quotRemNP f gs | all (/=0) gs = quotRemNP' f (replicate n (0,0), 0)
               | otherwise = error "quotRemNP: division by zero"
    where
    n = length gs
    -- Stop once every term has been reduced away or moved to the remainder.
    quotRemNP' 0 (lrs,f') = (lrs,f')
    quotRemNP' h (lrs,f') = divisionStep h (gs,[],lrs,f')
    -- Try each divisor in turn; lrs' holds the pairs already passed over (reversed).
    divisionStep h (g:gs', lrs', (l,r):lrs, f') =
        case lm h `divM` lm g of
        Just (l',r') -> let l'' = NP [(l',lc h / lc g)]
                            r'' = NP [(r',1)]
                            -- intended to cancel the leading term of h
                            h' = h - l'' * g * r''
                        in quotRemNP' h' (reverse lrs' ++ (l+l'',r+r''):lrs, f')
        Nothing -> divisionStep h (gs',(l,r):lrs',lrs,f')
    -- No divisor's leading monomial divides lm h, so move lt h to the remainder.
    divisionStep h ([],lrs',[],f') =
        let lth = lt h
        in quotRemNP' (h-lth) (reverse lrs', f'+lth)
-- | Reduce @f@ modulo the divisors @gs@ and return the remainder. At each step
-- the leading term of the working polynomial is cancelled using the first
-- divisor whose leading monomial occurs in it (see 'divM'). A leading term that
-- no divisor reduces is moved to the remainder. Raises an error if any divisor
-- is zero. Uses '/' on coefficients.
remNP f gs | all (/=0) gs = remNP' f 0
           | otherwise = error "remNP: division by zero"
    where
    remNP' 0 f' = f'
    remNP' h f' = divisionStep h gs f'
    divisionStep h (g:gs') f' =
        case lm h `divM` lm g of
        Just (l',r') -> let l'' = NP [(l',lc h / lc g)]
                            r'' = NP [(r',1)]
                            -- intended to cancel the leading term of h
                            h' = h - l'' * g * r''
                        in remNP' h' f'
        Nothing -> divisionStep h gs' f'
    divisionStep h [] f' =
        let lth = lt h
        in remNP' (h-lth) (f'+lth)
infixl 7 %%
-- Operator form of remNP; it binds like '*' (infixl 7).
(%%) f gs = remNP f gs
-- | Division-free variant of 'remNP'. Instead of dividing by @lc g@, each
-- reduction step multiplies the working polynomial by @lc g@ before
-- subtracting @lc h * l * g * r@. The remainder found so far is multiplied by
-- the same factor. No '/' is used on coefficients. Because of the scaling, the
-- result is presumably the remainder of a scalar multiple of @f@; confirm
-- before relying on exact coefficients. Raises an error if any divisor is zero.
remNP2 f gs | all (/=0) gs = remNP' f 0
            | otherwise = error "remNP2: division by zero"
    where
    remNP' 0 f' = f'
    remNP' h f' = divisionStep h gs f'
    divisionStep h (g:gs') f' =
        case lm h `divM` lm g of
        Just (l',r') -> let l'' = NP [(l',1)]
                            r'' = NP [(r',1)]
                            lcg = inject (lc g)
                            lch = inject (lc h)
                            -- scaling h by lc g lets the leading terms cancel
                            -- without dividing coefficients
                            h' = lcg * h - lch * l'' * g * r''
                        in remNP' h' (lcg * f')
        Nothing -> divisionStep h gs' f'
    divisionStep h [] f' =
        let lth = lt h
        in remNP' (h-lth) (f'+lth)
-- | Divide every coefficient by the leading coefficient so the result has
-- leading coefficient 1. Zero is returned unchanged.
toMonic 0 = 0
toMonic p@(NP ts@((_,c):_))
    | c == 1 = p
    | otherwise = NP [(m, d / c) | (m,d) <- ts]
-- | Embed a scalar as a constant polynomial; 0 becomes the zero polynomial.
inject c
    | c == 0 = NP []
    | otherwise = NP [(1, c)]
-- | Substitute polynomials for variables. @vts@ is an association list whose
-- keys are one-variable polynomials with coefficient 1 (the shape 'var'
-- builds) and whose values are the replacements. Raises an error when a
-- variable of the polynomial has no entry.
subst vts (NP us) = sum (map substTerm us)
    where
    substTerm (m,c) = inject c * substMonomial m
    substMonomial (M xs) = product (map substVar xs)
    substVar v =
        let key = NP [(M [v], 1)]
        in maybe (error ("subst: no substitute supplied for " ++ show key)) id (L.lookup key vts)
-- | Types whose elements can be inverted. This is presumably the
-- multiplicative inverse, since '^-' raises the result with '^'.
class Invertible a where
    -- | The inverse of the argument.
    inv :: a -> a
-- | Negative power: @a ^- k@ raises the inverse of @a@ to the power @k@.
-- No fixity is declared, so the default (infixl 9) applies.
(^-) :: (Invertible a, Num a, Integral b) => a -> b -> a
(^-) a k = inv a ^ k