{-# language ImportQualifiedPost #-}
{-# language ViewPatterns #-}
{-# language OverloadedStrings #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SRTree.Datasets
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  FlexibleInstances, DeriveFunctor, ScopedTypeVariables, ConstraintKinds
--
-- Utility library to handle regression datasets.
-- This module exports only the `loadDataset` function.
--
-----------------------------------------------------------------------------
module Data.SRTree.Datasets ( loadDataset )
    where

import Codec.Compression.GZip (decompress)
import Data.ByteString.Char8 qualified as B
import Data.ByteString.Lazy qualified as BS
import Data.List (delete, find, intercalate)
import Data.Massiv.Array
  ( Array,
    Comp (Seq, Par),
    Ix2 ((:.)),
    S (..),
    Sz (Sz1),
    (<!),
  )
import Data.Massiv.Array qualified as M
import Data.Maybe (fromJust)
import Data.SRTree.Eval (PVector, SRMatrix, compMode)
import Data.Vector qualified as V
import System.FilePath (takeExtension)
import Text.Read (readMaybe)

-- | Converts rows of textual cells into a matrix of 'Double'.
--
-- Each cell is parsed with 'readMaybe'; a cell that is not a valid 'Double'
-- literal raises an error that names the offending cell (instead of the
-- uninformative @Prelude.read: no parse@). The parsed rows are handed to
-- 'M.fromLists'' using the 'compMode' computation strategy, which raises an
-- error when the rows do not all have the same length.
loadMtx :: [[B.ByteString]] -> Array S Ix2 Double
loadMtx = M.fromLists' compMode . map (map parseCell)
  where
    parseCell :: B.ByteString -> Double
    parseCell cell = case readMaybe (B.unpack cell) of
      Just v  -> v
      Nothing -> error $ "loadMtx: could not parse " <> show cell <> " as a Double."
{-# INLINE loadMtx #-}

-- | Tells whether a file must be gunzipped before parsing, i.e. whether its
-- extension (as computed by 'takeExtension') is exactly @.gz@.
isGZip :: FilePath -> Bool
isGZip path = takeExtension path == ".gz"
{-# INLINE isGZip #-}

-- | Detects the column separator of a delimited text file.
--
-- The candidates are tried in the order space, tab, @|@, @:@, @;@ and @,@.
-- The first one that splits every non-blank row (after stripping surrounding
-- whitespace) into the same number of fields, with at least two fields per
-- row, is returned. Blank rows are ignored because they carry no information
-- about the separator. Raises an error when no candidate qualifies.
--
--  >>> detectSep ["x1,x2,x3,x4"]
-- ','
detectSep :: [B.ByteString] -> Char
detectSep rows =
  case filter consistent seps of
    (c : _) -> c
    []      -> error $ "CSV parsing error: unsupported separator. Supported separators are "
                         <> intercalate "," (map show seps)
  where
    seps :: [Char]
    seps = [' ', '\t', '|', ':', ';', ',']

    nonBlank :: [B.ByteString]
    nonBlank = filter (not . B.null) (map B.strip rows)

    -- every row must have the same number (>= 2) of fields under this separator
    consistent :: Char -> Bool
    consistent c = case map (length . B.split c) nonBlank of
      []       -> True
      (n : ns) -> n > 1 && all (== n) ns
{-# INLINE detectSep #-}

-- | Reads a (possibly gzip-compressed) delimited text file and splits it into
-- rows of fields. The first row may be a header.
--
-- Files whose name ends in @.gz@ are decompressed first. Every line is
-- stripped of surrounding whitespace (this also removes carriage returns left
-- by CRLF line endings) and blank lines are dropped. The separator is then
-- detected with 'detectSep' from the first 100 remaining lines. Stripping here
-- keeps the split consistent with 'detectSep', which also works on stripped
-- lines; otherwise space- or tab-separated files with leading or trailing
-- blanks would produce empty fields.
readFileToLines :: FilePath -> IO [[B.ByteString]]
readFileToLines filename = do
  raw <- BS.readFile filename
  let rows = filter (not . B.null)
           . map B.strip
           . B.split '\n'
           . BS.toStrict          -- lazy -> strict without a per-byte list
           . decompressIfGZip
           $ raw
      -- use only the first 100 rows to detect the separator
      sep  = detectSep (take 100 rows)
  pure (map (B.split sep) rows)
  where
    decompressIfGZip :: BS.ByteString -> BS.ByteString
    decompressIfGZip
      | isGZip filename = decompress
      | otherwise       = id
{-# INLINE readFileToLines #-}

-- | Splits the dataset specification into the file name and its parameters.
--
-- The expected format is @filename.ext:p1:p2:p3:p4@ where
--
-- * @p1@ and @p2@ are the starting and ending rows of the training data
--   (by default @p1 = 0@ and @p2 = number of rows - 1@);
-- * @p3@ is the target column, given either as a header name or an index;
-- * @p4@ is a comma separated list of columns (names or indices) to be used
--   as input variables. These will be renamed internally as @x0, x1, ...@ in
--   the order of this list.
--
-- Missing parameters are returned as empty strings, so the parameter list
-- always has exactly four elements. Anything after the fourth parameter is
-- ignored. An empty specification yields an empty file name.
splitFileNameParams :: FilePath -> (FilePath, [B.ByteString])
splitFileNameParams filename =
  case B.split ':' (B.pack filename) of
    []             -> (filename, padded [])
    (fname : rest) -> (B.unpack fname, padded rest)
  where
    -- fill up the missing parameters with empty strings
    padded :: [B.ByteString] -> [B.ByteString]
    padded ps = take 4 (ps <> repeat B.empty)
{-# inline splitFileNameParams #-}

-- | Interprets a column specification: @Right n@ when the string reads as an
-- 'Int' index, otherwise @Left@ carrying the unchanged string (a column name).
parseVal :: String -> Either String Int
parseVal xs = maybe (Left xs) Right (readMaybe xs)
{-# inline parseVal #-}

-- | Resolves the input-variable columns and the target column to indices.
--
-- @headerMap@ associates each column name with its index. @target@ and
-- @columns@ (a comma separated list) may contain column names or integer
-- indices; surrounding whitespace around each entry is ignored. An empty
-- @target@ selects the last column, and an empty @columns@ selects every
-- column except the target. Unknown names and out-of-range indices raise an
-- error. Returns @(input indices, target index)@.
getColumns :: [(B.ByteString, Int)] -> B.ByteString -> B.ByteString -> ([Int], Int)
getColumns headerMap target columns = (ixs, iy)
  where
    nCols :: Int
    nCols = length headerMap

    getIx :: B.ByteString -> Int
    getIx spec = case parseVal (B.unpack spec) of
      -- a column name: retrieve its index from the header
      Left name -> case lookup spec headerMap of
        Just v  -> v
        Nothing -> error $ "Column name " <> name <> " does not exist."
      -- an index: check that it is within range
      Right v
        | v >= 0 && v < nCols -> v
        | otherwise           -> error $ "Column index " <> show v <> " out of range."

    -- if the input columns are omitted, use every column except for iy
    ixs :: [Int]
    ixs
      | B.null columns = delete iy [0 .. nCols - 1]
      | otherwise      = map (getIx . B.strip) (B.split ',' columns)

    -- if the target column is omitted, use the last one
    iy :: Int
    iy
      | B.null target = nCols - 1
      | otherwise     = getIx (B.strip target)
{-# inline getColumns #-}

-- | Resolves the training-row range @start:end@ (inclusive, 0-based) for a
-- dataset with @nRows@ data rows and returns it as a pair of indices.
--
-- An empty @start@ defaults to @0@ and an empty @end@ defaults to
-- @nRows - 1@. Each bound must parse as an integer in @[0, nRows)@, and
-- @start@ must be strictly smaller than @end@; otherwise an error is raised.
getRows :: B.ByteString -> B.ByteString -> Int -> (Int, Int)
getRows startBs endBs nRows
  | startIx >= endIx = error $ "Invalid range: " <> show start <> ":" <> show end <> "."
  | otherwise        = (startIx, endIx)
  where
    start, end :: String
    start = B.unpack startBs
    end   = B.unpack endBs

    startIx, endIx :: Int
    startIx = rowIndex "Invalid starting row " 0 start
    endIx   = rowIndex "Invalid end row " (nRows - 1) end

    -- parse one bound, falling back to the default when it is omitted
    rowIndex :: String -> Int -> String -> Int
    rowIndex msg dflt str
      | null str  = dflt
      | otherwise =
          case readMaybe str of
            Nothing -> error $ msg <> str <> "."
            Just ix
              | ix < 0 || ix >= nRows -> error $ msg <> show ix <> "."
              | otherwise             -> ix
{-# inline getRows #-}

-- | `loadDataset` loads a dataset given a specification in the format:
--
-- > filename.ext:start_row:end_row:target:features
--
-- where
--
-- * __start_row:end_row__ is the inclusive range of the training rows
--   (default @0:nrows-1@); every row outside this range is used as
--   validation data;
-- * __target__ is either the name of the target column (if the file has a
--   header) or its index;
-- * __features__ is a comma separated list of column names or indices to be
--   used as input variables of the regression model.
--
-- The 'Bool' argument tells whether the first row of the file is a header.
-- It returns @((X_train, y_train, X_val, y_val), varnames, targetname)@, where
-- @varnames@ is a comma separated list of the names of the input variables
-- and @targetname@ is the name of the target.
loadDataset :: FilePath -> Bool -> IO ((SRMatrix, PVector, SRMatrix, PVector), String, String)
loadDataset filename hasHeader =
  fmap (\rows -> processData rows params hasHeader) (readFileToLines fname)
  where
    (fname, params) = splitFileNameParams filename

-- | Support function that does all the work of 'loadDataset'.
--
-- Arguments: the rows of the file (split into fields), the parameters
-- @[start_row, end_row, target, features]@ as produced by
-- 'splitFileNameParams' (missing entries are treated as empty), and whether
-- the first row is a header.
--
-- Rows @start..end@ (inclusive) of the data form the training set and the
-- remaining rows form the validation set. Without a header, columns are named
-- @x<i>@ after their position @i@ in the file and the target is named @y@.
-- An empty dataset raises an explicit error.
processData :: [[B.ByteString]] -> [B.ByteString] -> Bool -> ((SRMatrix, PVector, SRMatrix, PVector), String, String)
processData csv params hasHeader = ((x_train, y_train, x_val, y_val), varnames, targetname)
  where
    -- the first row (header or data) determines the number of columns
    firstRow :: [B.ByteString]
    firstRow = case csv of
      (row : _) -> row
      []        -> error "loadDataset: the dataset is empty."

    ncols, nrows :: Int
    ncols = length firstRow
    nrows = length csv - fromEnum hasHeader

    -- (column name, column index) pairs, and the data rows
    header :: [(B.ByteString, Int)]
    content :: [[B.ByteString]]
    (header, content)
      | hasHeader = (zip (map B.strip firstRow) [0 ..], drop 1 csv)
      | otherwise = ([(B.pack ('x' : show i), i) | i <- [0 .. ncols - 1]], csv)

    -- total access to the parameters; missing ones default to empty strings
    param :: Int -> B.ByteString
    param k = case drop k params of
      (v : _) -> v
      []      -> B.empty

    -- name of the column with the given (already validated) index
    nameOf :: Int -> String
    nameOf ix = case find ((== ix) . snd) header of
      Just (name, _) -> B.unpack name
      Nothing        -> error $ "loadDataset: no column with index " <> show ix <> "."

    varnames, targetname :: String
    varnames   = intercalate "," (map nameOf ixs)
    targetname = if hasHeader then nameOf iy else "y"

    -- get rows and column indices
    (st, end) = getRows (param 0) (param 1) nrows
    (ixs, iy) = getColumns header (param 2) (param 3)

    -- load data and split sets
    datum :: Array S Ix2 Double
    datum = loadMtx content

    p :: Int
    p = length ixs

    -- feature matrix: the selected columns stacked side by side, in the
    -- order given by ixs
    x :: Array S Ix2 Double
    x = M.computeAs S $ M.throwEither $ M.stackInnerSlicesM $ map (datum <!) ixs
    y = datum <! iy

    x_train :: SRMatrix
    x_train = M.computeAs S $ M.extractFromTo' (st :. 0) (end + 1 :. p) x

    y_train :: PVector
    y_train = M.computeAs S $ M.extractFromTo' st (end + 1) y

    x_val :: SRMatrix
    x_val = M.computeAs S $ M.throwEither $ M.deleteRowsM st (Sz1 $ end - st + 1) x

    y_val :: PVector
    y_val = M.computeAs S $ M.throwEither $ M.deleteColumnsM st (Sz1 $ end - st + 1) y
{-# inline processData #-}