{-# LANGUAGE  MultiParamTypeClasses
            , FlexibleContexts
            , FlexibleInstances
            , UndecidableInstances
            , ScopedTypeVariables
            #-}

module MWE2 () where

-- | One non-recursive layer of an expression tree: an integer literal or a
-- binary 'Add' / 'Sub' node.  The parameter @t@ is the type of the child
-- positions, so the recursion is supplied from outside (see 'Ast1', 'Ast2').
data SumsPart t
    = IntLit Int
    | Add t t
    | Sub t t
  deriving (Eq, Ord, Show)
-- | One non-recursive layer holding a binary 'Mul' node whose two children
-- have type @t@.
data MultPart t
    = Mul t t
  deriving (Eq, Ord, Show)

-- | Trees whose every node is a 'SumsPart' layer.
type SumsAst = Ast1 SumsPart
-- | Trees whose nodes are either 'SumsPart' or 'MultPart' layers.
type MultAst = Ast2 SumsPart MultPart

-- | Recursive tree over a single layer type: each node is a @p1@ layer whose
-- children are again @Ast1 p1@.
data Ast1 p1 = Ast1p1 (p1 (Ast1 p1))
-- | Recursive tree over two layer types: each node is either a @p1@ layer
-- ('Ast2p1') or a @p2@ layer ('Ast2p2'), with children of type @Ast2 p1 p2@.
data Ast2 p1 p2 = Ast2p1 (p1 (Ast2 p1 p2)) | Ast2p2 (p2 (Ast2 p1 p2))

-- | An operation @op@ applied to a whole tree of type @ast@, producing a
-- @result@.  The class has no functional dependencies, so all three
-- parameters must be determined at the use site.
class AstOp op ast result where
  astop :: op -> ast -> result
-- | An operation @op@ applied to a single layer @part ast@ (a @part@ whose
-- child positions hold values of type @ast@), producing a @result@.
class AstStep op part ast result where
  aststep :: op -> part ast -> result
-- | Inject a single layer @part ast@ into the tree type @ast@.
class AstWrap part ast where
  astwrap :: part ast -> ast

-- | Running an operation on an 'Ast1' unwraps its only constructor and hands
-- the exposed @p1@ layer to 'aststep'.
instance AstStep op p1 (Ast1 p1) result =>
         AstOp op (Ast1 p1) result where
    astop op (Ast1p1 layer) = aststep op layer
-- | Running an operation on an 'Ast2' unwraps whichever constructor is
-- present and hands the exposed layer to the 'aststep' for that layer type.
instance (AstStep op p1 (Ast2 p1 p2) result,
          AstStep op p2 (Ast2 p1 p2) result) =>
         AstOp op (Ast2 p1 p2) result where
    astop op (Ast2p1 layer) = aststep op layer
    astop op (Ast2p2 layer) = aststep op layer
-- | A @p1@ layer becomes an 'Ast1' node via 'Ast1p1'.
instance AstWrap p1 (Ast1 p1) where
    astwrap = Ast1p1
-- | A layer of the first kind becomes an 'Ast2' node via 'Ast2p1'.
instance AstWrap p1 (Ast2 p1 p2) where
    astwrap = Ast2p1
-- | A layer of the second kind becomes an 'Ast2' node via 'Ast2p2'.
instance AstWrap p2 (Ast2 p1 p2) where
    astwrap = Ast2p2

-- | Operation tag used as the @op@ parameter of 'AstOp' / 'AstStep'.  Its
-- 'AstStep' instances take a child-conversion function @ast1 -> ast2@ and
-- rebuild the layer inside @ast2@.
data HomOp = HomOp
-- | 'HomOp' on a 'SumsPart' layer: given a conversion for the children,
-- rebuild the same constructor with converted children and wrap it into the
-- target tree type @ast2@.
instance (AstOp HomOp ast1 ((ast1 -> ast2) -> ast2)
         ,AstWrap SumsPart ast2)
      => AstStep HomOp SumsPart ast1 ((ast1 -> ast2) -> ast2) where
  aststep HomOp layer = \convert -> astwrap (rebuild convert layer)
    where
      -- Literals carry no children; binary nodes convert both operands.
      rebuild _  (IntLit n)  = IntLit n
      rebuild go (Add l r)   = Add (go l) (go r)
      rebuild go (Sub l r)   = Sub (go l) (go r)
-- | 'HomOp' on a 'MultPart' layer: given a conversion for the children,
-- rebuild the 'Mul' node with converted children and wrap it into the target
-- tree type @ast2@.
instance (AstOp HomOp ast1 ((ast1 -> ast2) -> ast2)
         ,AstWrap MultPart ast2)
      => AstStep HomOp MultPart ast1 ((ast1 -> ast2) -> ast2) where
  aststep HomOp layer = \convert -> astwrap (rebuild convert layer)
    where
      rebuild go (Mul l r) = Mul (go l) (go r)

-- | Convert a tree of type @ast1@ into a tree of type @ast2@ by running
-- 'HomOp' over it, passing 'upcast' itself as the conversion to apply to
-- child nodes.
--
-- The explicit @forall@ is essential.  @ScopedTypeVariables@ only brings
-- @ast1@ and @ast2@ into scope in the body when the signature quantifies them
-- explicitly.  Without it, the annotation on the recursive call would mention
-- fresh, unrelated type variables, and the required 'AstOp' constraint could
-- not be deduced from the one in the signature.
upcast :: forall ast1 ast2.
          AstOp HomOp ast1 ((ast1 -> ast2) -> ast2)
       => ast1 -> ast2
-- The annotation pins the otherwise ambiguous @result@ parameter of 'astop'
-- to @(ast1 -> ast2) -> ast2@, matching the constraint above.
upcast ast = astop HomOp ast (upcast :: ast1 -> ast2)

-- | The literal @3@ built as a sums-only tree and then converted with
-- 'upcast' into the richer 'MultAst' tree type.
example :: MultAst
example = upcast (Ast1p1 (IntLit 3) :: SumsAst)
