{- |
    This module provides the 'Str' data type, which is used by the
    underlying 'uniplate' and 'biplate' methods. It should not
    be used directly under normal circumstances.
-}

module Data.Generics.Str where

import Data.Generics.PlateInternal

import Data.Traversable
import Data.Foldable
import Control.Applicative
import Data.Monoid

-- * The Data Type

-- | A binary tree of elements: 'Zero' holds no element, 'One' holds
--   exactly one, and 'Two' joins two sub-structures side by side.
data Str a
    = Zero
    | One a
    | Two (Str a) (Str a)
    deriving (Show)

-- | Structural equality: both sides must use the same constructors in the
--   same shape, and corresponding elements must compare equal.
instance Eq a => Eq (Str a) where
    lhs == rhs = case (lhs, rhs) of
        (Zero, Zero) -> True
        (One p, One q) -> p == q
        (Two pl pr, Two ql qr) -> pl == ql && pr == qr
        _ -> False

-- | Apply a function to every element, keeping the tree shape unchanged.
instance Functor Str where
    fmap g s = case s of
        Zero -> Zero
        One v -> One (g v)
        Two l r -> Two (fmap g l) (fmap g r)

-- | Fold elements left to right: in a 'Two', the left subtree is combined
--   before the right one. 'Zero' contributes 'mempty'.
instance Foldable Str where
    foldMap toM = go
      where
        go Zero = mempty
        go (One v) = toM v
        go (Two l r) = go l `mappend` go r

-- | Traverse elements left to right (left subtree of a 'Two' first),
--   rebuilding the same shape inside the applicative.
instance Traversable Str where
    traverse act = walk
      where
        walk Zero = pure Zero
        walk (One v) = fmap One (act v)
        walk (Two l r) = Two <$> walk l <*> walk r

-- | A type-level proxy: @strType s@ has the element type of @s@, so it can
--   be used to refer to that type without having an actual element.
--   Presumably only its type is ever used (TODO confirm against callers).
--
--   The value is bottom: 'strType' is defined directly as a call to
--   'error' (it takes no pattern arguments), so forcing it, or anything it
--   is applied to, throws.
strType :: Str a -> a
strType = error "Data.Generics.Str.strType: Cannot be called"

-- | Convert a 'Str' to the list of its elements, in left-to-right order.
--
--   Values built by 'listStr' (a right-nested chain of @Two (One x) rest@
--   ending in 'Zero') convert back to the original list. Any other shape
--   is flattened left to right as well, so the function is total.
strList :: Str a -> [a]
strList s = builder (go s)
    where
        -- Produces the elements via abstract @cons@/@nil@, as 'builder'
        -- expects (presumably a 'GHC.Exts.build'-style fusion helper).
        go :: Str a -> (a -> b -> b) -> b -> b
        go Zero _ nil = nil
        go (One y) cons nil = y `cons` nil
        -- Left subtree first, then the right one in front of the tail.
        go (Two l r) cons nil = go l cons (go r cons nil)

-- | Convert a list to a 'Str' by wrapping each element in 'One' and
--   chaining them with 'Two' into a right-nested spine ending in 'Zero';
--   e.g. @listStr [1,2]@ is @Two (One 1) (Two (One 2) Zero)@.
--   The conversion is lazy in the tail of the list.
listStr :: [a] -> Str a
listStr = foldr (Two . One) Zero

-- | Split a 'Str' into its elements and a function that rebuilds a 'Str'
--   with the same shape from a replacement list.
--
--   The first component lists the elements in left-to-right order. The
--   second component refills the original shape, consuming replacement
--   elements left to right. The replacement list should have the same
--   length as the first component. Surplus elements are ignored, and a
--   list that is too short raises an 'error' when the missing slot is
--   demanded.
strStructure :: Str a -> ([a], [a] -> Str a)
strStructure s = (flatten s [], fst . refill s)
    where
        -- Accumulator-passing flatten: prepends onto the rest of the list
        -- instead of using repeated (++).
        flatten :: Str a -> [a] -> [a]
        flatten Zero acc = acc
        flatten (One y) acc = y : acc
        flatten (Two a b) acc = flatten a (flatten b acc)

        -- Rebuild the shape, returning the replacements not yet consumed.
        refill :: Str a -> [a] -> (Str a, [a])
        refill Zero rs = (Zero, rs)
        refill (One _) (r:rs) = (One r, rs)
        refill (One _) [] =
            error "Data.Generics.Str.strStructure: replacement list is too short"
        refill (Two a b) rs1 = (Two a2 b2, rs3)
            where
                -- Pattern bindings in a where clause are lazy, so the
                -- subtrees are only rebuilt when they are demanded.
                (a2, rs2) = refill a rs1
                (b2, rs3) = refill b rs2