{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns -fno-warn-orphans #-}
{-
  Compile with: $ghc --make -threaded -O2 main.hs
  Execute with: $./main +RTS -A128M -s -c -N4 -RTS 5000000 500

  processing x
  main: internal error: removeThreadFromQueue: not found
      (GHC version 6.8.2 for i386_unknown_linux)
      Please report this as a GHC bug:  http://www.haskell.org/ghc/reportabug  

  The value of x at which the crash occurs is not always the same; it
  depends on the program arguments and on the RTS settings.

  Using only 1 thread, it works.
  Increasing the thread stack size (-k1M), it works.

  I suspect the problem is caused by an uncaught stack overflow.
-}
module Main where

import Control.Arrow (first)
import Control.Monad (liftM)
import Control.Parallel (pseq)
import Control.Parallel.Strategies
import Data.Array.Vector ((:*:)(..), UA, UArr, appendU, foldlU, lengthU, singletonU, snocU, toU)
import Data.Foldable (foldl')
import Data.Int (Int8, Int16)
import Data.IntMap (IntMap, alter, empty, unionsWith, toAscList)
import Debug.Trace (trace)
import System.Console.GetOpt
import System.Environment (getArgs)
import System.Random (RandomGen, mkStdGen, randomR, randomRs)


-- | Identifier attached to each input array (the first component of the
-- pairs consumed by 'invert').
type A = Int16
-- | Element type of the input arrays; 'invert' produces one output entry
-- per distinct value of this type found in the input.
type B = Int8


-- |Invert a relation given as @(a, bs)@ pairs: for every value @b@ that
-- occurs in some @bs@, collect the @a@s it occurred with.
--
-- The input is cut into chunks, each chunk is inverted by 'invert'' via
-- 'mapReduce', and the per-chunk maps are merged with 'appendU'.  The
-- result is listed in ascending order of the 'Int' keys used internally.
invert :: [(A, UArr B)] -> [(B, UArr A)]
invert = convert . result . split'
    where
      -- The map is keyed by 'Int'; turn the keys back into 'B'.
      convert :: IntMap (UArr A) -> [(B, UArr A)]
      convert = map (first fromIntegral) . toAscList

      -- Aim for roughly four chunks.  The chunk size is clamped to at
      -- least 1: 'split' with a size of 0 yields no chunks at all, which
      -- would silently discard any input shorter than four elements.
      split' xs = split (max 1 (length xs `div` 4)) xs

      -- XXX unions is not strict
      result = mapReduce rnf invert' rnf (unionsWith appendU)

-- |Build the inverted map for one chunk of input: every @a@ is recorded
-- under each element @b@ of the array it is paired with.
invert' :: [(A, UArr B)] -> IntMap (UArr A)
invert' = foldl' step empty
    where
      -- Fold a single @(a, bs)@ pair into the map built so far.  'trace'
      -- wraps 'foldlU' itself, so the message is emitted when the fold
      -- for this pair is demanded.
      step :: IntMap (UArr A) -> (A, UArr B) -> IntMap (UArr A)
      step acc (a, bs) = trace note foldlU record acc bs
          where
            note :: String
            note = "processing " ++ show a

            -- Register @a@ under the key @b@.
            record :: IntMap (UArr A) -> B -> IntMap (UArr A)
            record acc' b =
                a `seq` b `seq` alter extend (fromIntegral b) acc'

            -- Start a new array for an unseen key, otherwise append to
            -- the existing one, forcing the grown array before it is
            -- wrapped in 'Just'.
            extend old =
                case old of
                  Nothing -> Just $ singletonU a
                  Just u  -> Just $! snocU u a


-- |Interpret the command line.
--
-- With no arguments the defaults @(100, 10)@ are returned; with exactly
-- two arguments both must be integers and are returned as @(n, m)@.
-- Anything else (option errors, a wrong number of arguments, or an
-- argument that is not a number) calls 'error' with the collected
-- messages followed by the usage text.
parse :: [String] -> (Int, Int)
parse argv =
    case getOpt Permute [] argv of
      ([], [n, m], []) ->
          let n' = number n
              m' = number m
          -- Force both numbers so malformed input is reported here,
          -- not at some later point where the value happens to be used.
          in n' `seq` m' `seq` (n', m')
      ([], [], [])     -> (100, 10)
      (_, _, errs)     -> die errs
    where
      header   = "Usage: main [n m]"
      info     = usageInfo header []

      die :: [String] -> a
      die errs = error $ concat errs ++ info

      -- Total replacement for 'read': accepts exactly what 'read' accepts
      -- (including surrounding whitespace) but fails through 'die'.
      number :: String -> Int
      number s =
          case [v | (v, rest) <- reads s, ("", "") <- lex rest] of
            [v] -> v
            _   -> die ["invalid number: " ++ show s ++ "\n"]


-- |Generate pseudo-random input from a fixed seed, then print the number
-- of input pairs followed by the number of entries in the inverted
-- result.
main :: IO ()
main = do
  (n, m) <- parse `liftM` getArgs

  let
      -- Fixed seed; the same generator feeds both the value stream and
      -- the partitioning.
      seed = mkStdGen 777

      values :: [B]
      values = map fromIntegral (randomRs (0, m) seed)

      parts :: [UArr B]
      parts = map toU (randomPartition (take n values) n m seed)

      -- Number the parts starting from 1.
      input :: [(A, UArr B)]
      input = zip [1 ..] parts

  print (length input)
  print (length (invert input))


--
-- Support functions and instances
--

-- |Cut a list into consecutive chunks of @n@ elements; the final chunk
-- holds whatever is left and may be shorter.  An empty list, or a chunk
-- size that is not positive, produces no chunks.
--
-- Code originally written by:
-- Daniel Peebles (aka pumpkin).
split :: Int -> [a] -> [[a]]
split n = go
  where
    go xs =
      case take n xs of
        []    -> []
        chunk -> chunk : go (drop n xs)

-- |Randomly cut a list into @m@ consecutive parts.  The arguments are the
-- list, its length, the number of parts and a random number generator.
--
-- /A subpart may be empty./
--
-- Code adapted from:
-- <http://hpaste.org/fastcgi/hpaste.fcgi/view?id=2485#a249>.
randomPartition :: RandomGen gen => [a] -> Int -> Int -> gen -> [[a]]
randomPartition [] 0 m _ = replicate m []
randomPartition xs _ 1 _ = [xs]
randomPartition items@(x : rest) n m g
    -- Close the current part (leaving it empty so far) ...
    | pick < m  = [] : randomPartition items n (m - 1) g'
    -- ... or put the next element into the current part.
    | otherwise = intoFirst (randomPartition rest (n - 1) m g')
  where
    (pick, g') = randomR (1, n + m - 1) g
    intoFirst (part : parts) = (x : part) : parts

-- |Apply a function to every element of a list with 'parMap', then
-- combine the results with a reduce function.  The mapped list is
-- evaluated (via 'pseq') before the reduction is returned.
mapReduce :: Strategy b    -- strategy used to evaluate each mapped element
          -> (a -> b)      -- function applied to every input element
          -> Strategy c    -- strategy used to evaluate the reduced value
          -> ([b] -> c)    -- function combining all mapped elements
          -> [a]           -- input list
          -> c

mapReduce mapStrat mapFunc reduceStrat reduceFunc input =
    let mapped  = parMap mapStrat mapFunc input
        reduced = reduceFunc mapped `using` reduceStrat
    in mapped `pseq` reduced


-- XXX these are missing from uvector package
instance (NFData a, NFData b) => NFData (a :*: b) where
    -- NOTE: the (:*:) constructor is strict already; only the two
    -- components still need to be fully evaluated, left one first.
    rnf (l :*: r) = rnf l `seq` (rnf r `seq` ())

instance (NFData a, UA a) => NFData (UArr a) where
    -- NOTE: UArr is strict already, so demanding its length is enough.
    rnf arr = len `seq` ()
      where
        len = lengthU arr
