{-# LANGUAGE OverlappingInstances, UndecidableInstances, StandaloneDeriving #-}

module Feldspar.Transformation.Framework where

import Feldspar.Compiler.Error
import Feldspar.Compiler.Imperative.Representation

-- | Report an 'InternalError' attributed to the transformation framework,
-- with the given message, via 'handleError'.  Because the result type is
-- fully polymorphic, this can never return a normal value; it is used as
-- the fallback body of the 'Default' and 'Combine' methods.
transformationError :: String -> a
transformationError = handleError "PluginArch/TransformationFramework" InternalError

-- ===========
-- == Utils ==
-- ===========

-- | Types with a designated default value.
--
-- An instance that does not define 'def' inherits a fallback that calls
-- 'transformationError' instead of producing a value.
class Default t where
    def :: t
    def = transformationError message
      where
        message = "Default value requested."

-- | Types whose values can be merged pairwise.
--
-- An instance that does not define 'combine' inherits a fallback that
-- calls 'transformationError' instead of merging.
class Combine t where
    combine :: t -> t -> t
    combine = transformationError message
      where
        message = "Default combination function used."

-- | Conversion from values of type @a@ to values of type @b@.
class Convert a b where
    convert
        :: a  -- ^ value to convert
        -> b  -- ^ converted value

-- | The default unit is the only unit value.
instance Default () where
    def = ()

-- | The default list is empty.
instance Default [a] where
    def = []

-- | The default 'Int' is zero.
instance Default Int where
    def = 0

-- | The default pair holds the default of each component.
instance (Default a, Default b) => Default (a, b) where
    def = (def, def)

-- | Combining units ignores both arguments (they are not forced).
instance Combine () where
    combine _ _ = ()

-- | Strings combine by concatenation, left operand first.
instance Combine String where
    combine = (++)

-- | 'Int's combine by addition.
instance Combine Int where
    combine = (+)

-- | Pairs combine component by component.
instance (Combine a, Combine b) => Combine (a, b) where
    combine (lhsA, lhsB) (rhsA, rhsB) = (combine lhsA rhsA, combine lhsB rhsB)

-- | Fallback conversion: the input is discarded and the target type's
-- 'def' is returned.  The head matches every pair of types, so under the
-- file-level @OverlappingInstances@ any more specific 'Convert' instance
-- defined elsewhere takes precedence over this one.
instance Default b => Convert a b where
    convert = const def

-- =============================
-- == TransformationFramework ==
-- =============================

-- | A transformation, identified by the type @t@ and described by its
-- associated types.  The upward information must have 'Default' and
-- 'Combine' instances.
class (Default (Up t), Combine (Up t)) => Transformation t where
    -- | Element type of the fragments handed to 'transform'.
    type From t
    -- | Element type of the fragments stored in the 'result' field.
    type To t
    -- | State passed into 'transform' and returned in the 'state' field.
    type State t
    -- | Extra input passed into 'transform'.
    type Down t
    -- | Extra output returned in the 'up' field.
    type Up t

-- | Output of 'transform' / 'defaultTransform' for transformation @t@
-- applied to a fragment wrapped in @s@.
--
-- There is deliberately no datatype context.  The previous
-- @(Transformation t) =>@ used the deprecated DatatypeContexts feature,
-- which only demands the constraint at every use of the constructor and
-- never supplies it.  The associated types used below do not need it.
data Result t s
    = Result
    { result    :: s (To t)  -- ^ the transformed fragment
    , state     :: State t   -- ^ state returned by the transformation
    , up        :: Up t      -- ^ extra output of the transformation
    }

deriving instance (Show (s (To t)), Show (State t), Show (Up t)) => Show (Result t s)

-- | Output of 'transform1' / 'defaultTransform1': like 'Result', but the
-- transformed elements sit inside an extra type constructor @a@.
--
-- There is deliberately no datatype context.  The previous
-- @(Transformation t) =>@ used the deprecated DatatypeContexts feature,
-- which only demands the constraint at every use of the constructor and
-- never supplies it.
data Result1 t s a
    = Result1
    { result1   :: s (a (To t))  -- ^ the transformed fragment
    , state1    :: State t       -- ^ state returned by the transformation
    , up1       :: Up t          -- ^ extra output of the transformation
    }

deriving instance (Show (s (a (To t))), Show (State t), Show (Up t)) => Show (Result1 t s a)

-- | Transformations @t@ that can be applied to a fragment of type
-- @s (From t)@.
class Transformation t => Transformable t s where
    transform
        :: t           -- ^ the transformation
        -> State t     -- ^ incoming state
        -> Down t      -- ^ extra input
        -> s (From t)  -- ^ fragment to transform
        -> Result t s

-- | Transformations @t@ that can be applied to a fragment of type
-- @s (a (From t))@, i.e. with an extra type constructor @a@ around the
-- elements.
class Transformation t => Transformable1 t s a where
    transform1
        :: t               -- ^ the transformation
        -> State t         -- ^ incoming state
        -> Down t          -- ^ extra input
        -> s (a (From t))  -- ^ fragment to transform
        -> Result1 t s a

-- | Supplies 'defaultTransform', which the catch-all 'Transformable'
-- instance in this module uses as 'transform'.
class Transformation t => DefaultTransformable t s where
    defaultTransform
        :: t           -- ^ the transformation
        -> State t     -- ^ incoming state
        -> Down t      -- ^ extra input
        -> s (From t)  -- ^ fragment to transform
        -> Result t s

-- | Supplies 'defaultTransform1', which the catch-all 'Transformable1'
-- instance in this module uses as 'transform1'.
class Transformation t => DefaultTransformable1 t s a where
    defaultTransform1
        :: t               -- ^ the transformation
        -> State t         -- ^ incoming state
        -> Down t          -- ^ extra input
        -> s (a (From t))  -- ^ fragment to transform
        -> Result1 t s a

-- | Every 'DefaultTransformable' pair is 'Transformable', using
-- 'defaultTransform'.  The head matches any @t@ and @s@, so this relies on
-- the file-level @OverlappingInstances@ to let more specific
-- 'Transformable' instances take over.
instance DefaultTransformable t s => Transformable t s where
    transform = defaultTransform

-- | Every 'DefaultTransformable1' triple is 'Transformable1', using
-- 'defaultTransform1'.  The head matches any @t@, @s@ and @a@, so this
-- relies on the file-level @OverlappingInstances@ to let more specific
-- 'Transformable1' instances take over.
instance DefaultTransformable1 t s a => Transformable1 t s a where
    transform1 = defaultTransform1