File: //snap/google-cloud-cli/current/lib/third_party/ml_sdk/cloud/ml/util/_decoders.py
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataflow-related utilities.
"""
import csv
import json
import logging
class DecodeError(Exception):
  """Base error raised when a record cannot be decoded."""
class PassthroughDecoder(object):
  """A decoder that performs no decoding at all."""

  def decode(self, x):
    """Returns the given record object itself, untouched."""
    unchanged = x
    return unchanged
class JsonDecoder(object):
  """A decoder for JSON formatted data."""

  def decode(self, x):
    """Parses the JSON text ``x`` and returns the resulting Python object."""
    parsed = json.loads(x)
    return parsed
class CsvDecoder(object):
  """A decoder for CSV formatted data.

  Each call to ``decode`` parses one CSV line into either a dict keyed by
  column name or a list ordered like ``column_names``. Empty or
  whitespace-only fields become None; fields of numeric columns are converted
  with ``float``.
  """

  # TODO(user) Revisit using cStringIO for design compatibility with
  # coders.CsvCoder.

  class _LineGenerator(object):
    """A csv line generator that allows feeding lines to a csv.DictReader.

    The csv readers pull their input from an iterator; this object is that
    iterator and hands out exactly the one line most recently pushed.
    """

    def __init__(self):
      self._lines = []

    def push_line(self, line):
      # This API currently supports only one line at a time.
      assert not self._lines
      self._lines.append(line)

    def __iter__(self):
      return self

    def next(self):
      """Returns the pushed line; raises DecodeError if none is pending."""
      # This API currently supports only one line at a time.
      # If this ever supports more than one row be aware that DictReader might
      # attempt to read more than one record if one of the records is an empty
      # line.
      line_length = len(self._lines)
      if line_length == 0:
        raise DecodeError(
            'Columns do not match specified csv headers: empty line was found')
      assert line_length == 1, 'Unexpected number of lines %s' % line_length
      # This doesn't maintain insertion order of the list, which is fine
      # because the list has only 1 element. If there were more and we wanted
      # to maintain order and time complexity we would switch to
      # deque.popleft.
      return self._lines.pop()

    # Python 3 iterator protocol. Without it csv.reader / csv.DictReader
    # reject this object ("iter() returned non-iterator"); ``next`` is kept so
    # Python 2 callers continue to work.
    __next__ = next

  class _ReaderWrapper(object):
    """A wrapper for csv.reader / csv.DictReader to make it picklable.

    Only the constructor arguments are pickled; the underlying reader is
    rebuilt from them when the object is unpickled.
    """

    def __init__(self, line_generator, column_names, delimiter, decode_to_dict,
                 skip_initial_space):
      self._state = (line_generator, column_names, delimiter, decode_to_dict,
                     skip_initial_space)
      self._line_generator = line_generator
      if decode_to_dict:
        self._reader = csv.DictReader(
            line_generator, column_names, delimiter=str(delimiter),
            skipinitialspace=skip_initial_space)
      else:
        self._reader = csv.reader(line_generator, delimiter=str(delimiter),
                                  skipinitialspace=skip_initial_space)

    def read_record(self, x):
      """Parses the single CSV line ``x`` and returns the reader's record."""
      self._line_generator.push_line(x)
      # The builtin next() works on Python 2 (calls .next()) and Python 3
      # (calls .__next__()); reader.next() does not exist on Python 3.
      return next(self._reader)

    def __getstate__(self):
      return self._state

    def __setstate__(self, state):
      self.__init__(*state)

  def __init__(
      self, column_names, numeric_column_names, delimiter, decode_to_dict,
      fail_on_error, skip_initial_space):
    """Initializer.

    Args:
      column_names: Tuple of strings. Order must match the order in the file.
      numeric_column_names: Tuple of strings. Contains column names that are
        numeric. Every name in numeric_column_names must also be in
        column_names.
      delimiter: String used to separate fields.
      decode_to_dict: Boolean indicating whether the decoder should generate a
        dictionary instead of a raw sequence.
      fail_on_error: Whether to fail if a corrupt row is found.
      skip_initial_space: When True, whitespace immediately following the
        delimiter is ignored.
    """
    self._column_names = column_names
    # A set gives O(1) membership checks in _get_value for every field.
    self._numeric_column_names = set(numeric_column_names)
    self._reader = self._ReaderWrapper(
        self._LineGenerator(), column_names, delimiter, decode_to_dict,
        skip_initial_space)
    self._decode_to_dict = decode_to_dict
    self._fail_on_error = fail_on_error

  def _handle_corrupt_row(self, message):
    """Handle corrupt rows.

    Depending on whether the decoder is configured to fail on error it will
    raise a DecodeError or return None.

    Args:
      message: String, the error message to raise.

    Returns:
      None, when the decoder is not configured to fail on error.

    Raises:
      DecodeError: when the decoder is configured to fail on error.
    """
    if self._fail_on_error:
      raise DecodeError(message)
    else:
      # TODO(user) Don't log every time but only every N invalid lines.
      logging.warning('Discarding invalid row: %s', message)
      return None

  def _get_value(self, column_name, value):
    """Converts one raw CSV field.

    Args:
      column_name: Name of the column the field belongs to.
      value: Raw field string.

    Returns:
      None for an empty or whitespace-only field, a float for numeric columns,
      otherwise the field unchanged.

    Raises:
      ValueError: If a numeric column holds non-numeric text.
    """
    # TODO(user) remove this logic from the decoders and let it be
    # part of preprocessing. CSV is a schema-less container we shouldn't be
    # performing these conversions here.
    if not value or not value.strip():
      return None
    if column_name in self._numeric_column_names:
      return float(value)
    return value

  # Please run //third_party/py/google/cloud/ml:benchmark_coders_test
  # if you make any changes on these methods.
  def decode(self, record):
    """Decodes the given string.

    Args:
      record: String to be decoded.

    Returns:
      A dict keyed by column name (when decode_to_dict) or a list in column
      order, with converted field values. Or None if the row is malformed and
      the decoder is configured not to fail on error.

    Raises:
      DecodeError: If columns do not match specified csv headers and the
        decoder is configured to fail on error.
      ValueError: If some numeric column has non-numeric data. This is raised
        regardless of the fail_on_error setting.
    """
    try:
      record = self._reader.read_record(record)
    except Exception as e:  # pylint: disable=broad-except
      return self._handle_corrupt_row('%s: %s' % (e, record))

    # Check record length mismatches.
    if len(record) != len(self._column_names):
      return self._handle_corrupt_row(
          'Columns do not match specified csv headers: %s -> %s' % (
              self._column_names, record))

    if self._decode_to_dict:
      # DictReader fills missing columns with None. Thus, if the last value
      # as defined by the schema is None, there was at least one "missing"
      # column.
      if record[self._column_names[-1]] is None:
        return self._handle_corrupt_row(
            'Columns do not match specified csv headers: %s -> %s' % (
                self._column_names, record))
      # dict.iteritems() does not exist on Python 3. list() snapshots the
      # items so reassigning values while looping is safe on both versions.
      for name, value in list(record.items()):
        record[name] = self._get_value(name, value)
    else:
      for index, name in enumerate(self._column_names):
        value = record[index]
        record[index] = self._get_value(name, value)

    return record