-- | \"Open\" functions, working on functors instead of trees.

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Open 
  ( 
    toList
  -- * Accumulating maps
  , mapAccumL  , mapAccumR
  , mapAccumL_ , mapAccumR_
  -- * Open functions
  , holes , holesList
  , apply , builder
  -- * Enumerations
  , enumerate
  , enumerateWith
  , enumerateWith_
  )
where

--------------------------------------------------------------------------------

import Control.Monad (liftM)
import Data.Foldable
import Data.Traversable ( Traversable(..) , mapAccumL , mapAccumR )
import Prelude hiding (foldl,foldr,mapM,mapM_,concat,concatMap)

import Data.Generics.Fixplate.Base 
import Data.Generics.Fixplate.Misc

--------------------------------------------------------------------------------
-- Accumulating maps

-- | Like 'mapAccumL', but throws away the final accumulator and keeps
-- only the rebuilt structure.
mapAccumL_ :: Traversable f => (a -> b -> (a, c)) -> a -> f b -> f c
mapAccumL_ step acc0 = snd . mapAccumL step acc0

-- | Like 'mapAccumR' (which threads the accumulator from right to left),
-- but throws away the final accumulator and keeps only the rebuilt structure.
mapAccumR_ :: Traversable f => (a -> b -> (a, c)) -> a -> f b -> f c
mapAccumR_ step acc0 = snd . mapAccumR step acc0

--------------------------------------------------------------------------------
-- Open functions

-- | The children together with functions replacing that particular child.
--
-- Each child is paired with a function that takes a new value and returns a
-- copy of the input in which only that child's position (positions are
-- counted in the order 'mapAccumL' visits them) holds the new value.
--
-- >>> map (\(_, plug) -> plug 'x') (holes "abc")
-- ["xbc","axc","abx"]
holes :: Traversable f => f a -> f (a, a -> f a)
holes tree = snd (mapAccumL ithHole (0 :: Int) tree) where
  -- Positions are numbered with 'Int'; without the annotation the literal
  -- is only constrained by (Eq, Num) and defaults to 'Integer'
  -- (a -Wtype-defaults warning under -Wall).
  ithHole i x = (i + 1, (x, replaceAt i))
  -- Rebuild the whole structure, substituting @y@ at position @i@ only.
  replaceAt i y = snd (mapAccumL (\j z -> (j + 1, if i == j then y else z)) 0 tree)

-- | The same as 'holes', but with the (child, replacement) pairs collected
-- into a list, in the order given by 'toList'.
holesList :: Traversable f => f a -> [(a, a -> f a)]
holesList tree = toList (holes tree)

-- | Apply the given function to each child in turn.
--
-- The result has one entry per child: a copy of the input in which only
-- that child has been replaced by its image under the function.
apply :: Traversable f => (a -> a) -> f a -> f (f a)
apply f = fmap (\(child, plug) -> plug (f child)) . holes

-- | Builds up a structure from a list of the children.
--
-- The shape is taken from the first argument. Its elements are ignored and
-- replaced, in the order 'mapAccumL' visits them, by consecutive elements of
-- the list. Surplus list elements are dropped. If the list is too short, the
-- positions that cannot be filled are bottom and raise a descriptive error
-- when forced (the structure itself can still be traversed lazily).
builder :: Traversable f => f a -> [b] -> f b
builder tree xs = snd (mapAccumL fill xs tree) where
  fill (y:ys) _ = (ys, y)
  -- Too few elements: fail lazily, only at the unfilled position, with a
  -- message naming the culprit instead of a bare pattern-match failure.
  fill []     _ = ([], error "Data.Generics.Fixplate.Open.builder: not enough elements in the list")

--------------------------------------------------------------------------------
-- Enumerations

-- | Numbers the children from left to right, starting at zero, pairing each
-- child with its index. The first component of the result is the number of
-- children. This is a direct use of 'mapAccumL' with an 'Int' counter.
enumerate :: Traversable f => f a -> (Int, f (Int, a))
enumerate = mapAccumL step 0
  where
    step n x = (n + 1, (n, x))

-- | Like 'enumerate', but combines each zero-based index with its child
-- using the given function instead of pairing them. The first component of
-- the result is the number of children.
enumerateWith :: Traversable f => (Int -> a -> b) -> f a -> (Int, f b)
enumerateWith h tree = mapAccumL label 0 tree
  where
    label n x = (n + 1, h n x)

-- | Like 'enumerateWith', but discards the number of children and keeps
-- only the relabelled structure.
enumerateWith_ :: Traversable f => (Int -> a -> b) -> f a -> f b
enumerateWith_ h tree = snd (mapAccumL (\n x -> (n + 1, h n x)) 0 tree)
  
--------------------------------------------------------------------------------