{-# LANGUAGE 
        MultiParamTypeClasses, 
        TypeSynonymInstances, 
        FlexibleContexts #-}

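-- | Arithmetic operators on Csound expressions.
--
-- Operands that are literal numeric values are constant-folded when an
-- expression is built.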

module CsoundExpr.Base.Arithmetic 
	(Opr2, (<+>), neg, (<->), (<*>), (</>))
where

import Prelude hiding (div, (<*>))

import Control.Applicative(liftA2, (<$>))
import Data.Function(on)

import CsoundExpr.Base.Types
import CsoundExpr.Translator.Cs.CsTree
import CsoundExpr.Translator.ExprTree.ExprTree
import CsoundExpr.Base.UserDefined


-- Multiplicative operators bind tighter (7) than additive ones (6), matching
-- the precedences of Prelude's '*' / '/' and '+' / '-'.
-- NOTE(review): all four are right-associative, whereas Prelude's arithmetic
-- operators are left-associative; e.g. @a <-> b <-> c@ parses as
-- @a <-> (b <-> c)@. Changing this would alter the meaning of existing code.
infixr 6 <+>
infixr 6 <->
infixr 7 <*>
infixr 7 </>


-- | Adapt a binary function over a common representation @d@ so it accepts
-- arguments of any types convertible to @d@: each argument is passed
-- through 'to' before @op@ is applied.
subst :: (IM d a, IM d b, IM d c) => (d -> d -> c) -> (a -> b -> c)
subst op x = op (to x) . to

-- | Unary operator rendered with 'infixOperation'.
--
-- If 'optim1' can fold the argument to a constant (by applying @fun@ to
-- its numeric value), that constant is returned via 'from'; otherwise an
-- operation node named @name@ is built over the argument.
opr1 :: IM CsTree a => Name -> (Double -> Double) -> CsTree -> a
opr1 name fun a = case optim1 fun a of
    Just folded -> from folded
    Nothing     -> infixOperation name (return a)

-- | Unary operator rendered with 'prefixOperation' (function-call style,
-- e.g. @abs@).
--
-- Like 'opr1', a constant argument is folded through @fun@ via 'optim1'
-- instead of building a node.
opr1p :: IM CsTree a => Name -> (Double -> Double) -> CsTree -> a
opr1p name fun a = case optim1 fun a of
    Nothing     -> prefixOperation name (return a)
    Just folded -> from folded

-- | Binary infix operator named @name@.
--
-- When both operands fold to constants ('optim2'), @fun@ is applied to
-- them and the resulting constant is returned; otherwise an infix
-- operation node over @[a, b]@ is built.
opr2 :: IM CsTree a => Name -> (Double -> Double -> Double) -> CsTree -> CsTree -> a
opr2 name fun a b =
    case optim2 fun a b of
        Just folded -> from folded
        Nothing     -> infixOperation name [a, b]

-- | Try to constant-fold a unary function.
--
-- Returns @Just@ the folded constant when the argument's expression node
-- satisfies 'isVal' and its 'value' converts with 'toDouble'; 'Nothing'
-- otherwise.
optim1 :: (Double -> Double) -> CsTree -> Maybe CsTree
optim1 fun a
    | isVal node = double . fun <$> toDouble (value node)
    | otherwise  = Nothing
    where
        node = exprOp (exprTreeTag a)

-- | Try to constant-fold a binary function.
--
-- Both operands' expression nodes must satisfy 'isVal' and convert with
-- 'toDouble'; then @fun@ is applied and the result wrapped with 'double'.
-- Any failure yields 'Nothing'.
optim2 :: (Double -> Double -> Double) -> CsTree -> CsTree -> Maybe CsTree
optim2 fun = foldConst `on` (exprOp . exprTreeTag)
    where
        foldConst x y
            | isVal x && isVal y = double <$> liftA2 fun (num x) (num y)
            | otherwise          = Nothing
        num v = toDouble (value v)


----------------------------------------------------------
-- Binary operators whose result type is computed from the argument types
-- (see 'Opr2')

-- | Addition. The result type is computed from both argument types by
-- 'Opr2'; literal operands are constant-folded.
(<+>) :: (X a, X b, X (Opr2 a b)) => a -> b -> Opr2 a b
x <+> y = subst (opr2 "+" (+)) x y

-- | Subtraction. Result type given by 'Opr2'; literal operands are folded.
(<->) :: (X a, X b, X (Opr2 a b)) => a -> b -> Opr2 a b
x <-> y = subst (opr2 "-" (-)) x y

-- | Multiplication. Result type given by 'Opr2'; literal operands are folded.
(<*>) :: (X a, X b, X (Opr2 a b)) => a -> b -> Opr2 a b
x <*> y = subst (opr2 "*" (*)) x y

-- | Division. Result type given by 'Opr2'; literal operands are folded.
(</>) :: (X a, X b, X (Opr2 a b)) => a -> b -> Opr2 a b
x </> y = subst (opr2 "/" (/)) x y

-- | Unary negation. A literal argument is folded to its negated constant;
-- otherwise a unary @-@ node is built via 'opr1'.
neg :: (X a) => a -> a
neg x = opr1 "-" negate (to x)

----------------------------------------------------------
-- Numeric instances for Irate


-- | 'Num' instance for i-rate expressions.
--
-- Arithmetic goes through this module's operators, so literal operands are
-- constant-folded. 'abs' becomes a prefix @abs@ operation, and integer
-- literals become constants via 'double'. 'negate' is left to the class
-- default (@0 - x@).
--
-- NB: @(<*>)@ below is this module's multiplication operator, not
-- Applicative's.
instance Num Irate where
    (+)         = (<+>)
    (-)         = (<->)
    (*)         = (<*>)
    abs         = opr1p "abs" abs . to
    -- No expression is generated for signum; fail with a message that
    -- identifies the offending method and type.
    signum      = error "CsoundExpr.Base.Arithmetic: signum is undefined for Irate"
    fromInteger = double . fromInteger


-- | 'Fractional' instance for i-rate expressions.
--
-- Division goes through '</>', which constant-folds literal operands.
-- Rational literals become numeric constants via 'double'.
instance Fractional Irate where
    fromRational r = double (fromRational r)
    x / y          = x </> y