HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/current/platform/gsutil/gslib/tests/test_daisy_chain_wrapper.py
# -*- coding: utf-8 -*-
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for daisy chain wrapper class."""

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
from __future__ import unicode_literals

import math
import os
import pkgutil

import six
import gslib.cloud_api
from gslib.daisy_chain_wrapper import DaisyChainWrapper
from gslib.storage_url import StorageUrlFromString
import gslib.tests.testcase as testcase
from gslib.utils.constants import TRANSFER_BUFFER_SIZE

_TEST_FILE = 'test.txt'


class TestDaisyChainWrapper(testcase.GsUtilUnitTestCase):
  """Unit tests for the DaisyChainWrapper class."""

  _temp_test_file = None
  _dummy_url = StorageUrlFromString('gs://bucket/object')

  def setUp(self):
    super(TestDaisyChainWrapper, self).setUp()
    self.test_data_file = self._GetTestFile()
    self.test_data_file_len = os.path.getsize(self.test_data_file)

  def _GetTestFile(self):
    contents = pkgutil.get_data('gslib', 'tests/test_data/%s' % _TEST_FILE)
    if not self._temp_test_file:
      # Write to a temp file because pkgutil doesn't expose a stream interface.
      self._temp_test_file = self.CreateTempFile(file_name=_TEST_FILE,
                                                 contents=contents)
    return self._temp_test_file

  class MockDownloadCloudApi(gslib.cloud_api.CloudApi):
    """Mock CloudApi that implements GetObjectMedia for testing."""

    def __init__(self, write_values):
      """Initialize the mock that will be used by the download thread.

      Args:
        write_values: List of values that will be used for calls to write(),
            in order, by the download thread. An Exception class may be part of
            the list; if so, the Exception will be raised after previous
            values are consumed.
      """
      self._write_values = write_values
      self.get_calls = 0

    def GetObjectMedia(self,
                       unused_bucket_name,
                       unused_object_name,
                       download_stream,
                       start_byte=0,
                       end_byte=None,
                       **kwargs):
      """Writes self._write_values to the download_stream."""
      # Writes from start_byte up to, but not including end_byte (if not None).
      # Does not slice values;
      # self._write_values must line up with start/end_byte.
      self.get_calls += 1
      bytes_read = 0
      for write_value in self._write_values:
        if bytes_read < start_byte:
          bytes_read += len(write_value)
          continue
        if end_byte and bytes_read >= end_byte:
          break
        if isinstance(write_value, Exception):
          raise write_value
        download_stream.write(write_value)
        bytes_read += len(write_value)

  def _WriteFromWrapperToFile(self, daisy_chain_wrapper, file_path):
    """Writes all contents from the DaisyChainWrapper to the named file."""
    with open(file_path, 'wb') as upload_stream:
      while True:
        data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
        if not data:
          break
        upload_stream.write(data)

  def testDownloadSingleChunk(self):
    """Tests a single call to GetObjectMedia."""
    write_values = []
    with open(self.test_data_file, 'rb') as stream:
      while True:
        data = stream.read(TRANSFER_BUFFER_SIZE)
        if not data:
          break
        write_values.append(data)
    upload_file = self.CreateTempFile()
    # Test for a single call even if the chunk size is larger than the data.
    for chunk_size in (self.test_data_file_len, self.test_data_file_len + 1):
      mock_api = self.MockDownloadCloudApi(write_values)
      daisy_chain_wrapper = DaisyChainWrapper(self._dummy_url,
                                              self.test_data_file_len,
                                              mock_api,
                                              download_chunk_size=chunk_size)
      self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
      # Since the chunk size is >= the file size, only a single GetObjectMedia
      # call should be made.
      self.assertEqual(mock_api.get_calls, 1)
      with open(upload_file, 'rb') as upload_stream:
        with open(self.test_data_file, 'rb') as download_stream:
          self.assertEqual(upload_stream.read(), download_stream.read())

  def testDownloadMultiChunk(self):
    """Tests multiple calls to GetObjectMedia."""
    upload_file = self.CreateTempFile()
    write_values = []
    with open(self.test_data_file, 'rb') as stream:
      while True:
        data = stream.read(TRANSFER_BUFFER_SIZE)
        if not data:
          break
        write_values.append(data)
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url,
        self.test_data_file_len,
        mock_api,
        download_chunk_size=TRANSFER_BUFFER_SIZE)
    self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
    num_expected_calls = self.test_data_file_len // TRANSFER_BUFFER_SIZE
    if self.test_data_file_len % TRANSFER_BUFFER_SIZE:
      num_expected_calls += 1
    # Since the chunk size is < the file size, multiple calls to GetObjectMedia
    # should be made.
    self.assertEqual(mock_api.get_calls, num_expected_calls)
    with open(upload_file, 'rb') as upload_stream:
      with open(self.test_data_file, 'rb') as download_stream:
        self.assertEqual(upload_stream.read(), download_stream.read())

  def testDownloadWithDifferentChunkSize(self):
    """Tests multiple calls to GetObjectMedia."""
    upload_file = self.CreateTempFile()
    write_values = []
    with open(self.test_data_file, 'rb') as stream:
      # Use an arbitrary size greater than TRANSFER_BUFFER_SIZE for writing
      # data to the buffer. For reading from the buffer
      # WriteFromWrapperToFile will use TRANSFER_BUFFER_SIZE.
      buffer_write_size = TRANSFER_BUFFER_SIZE * 2 + 10
      while True:
        # Write data with size.
        data = stream.read(buffer_write_size)
        if not data:
          break
        write_values.append(data)
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url,
        self.test_data_file_len,
        mock_api,
        download_chunk_size=TRANSFER_BUFFER_SIZE)
    self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
    num_expected_calls = math.ceil(self.test_data_file_len /
                                   TRANSFER_BUFFER_SIZE)
    # Since the chunk size is < the file size, multiple calls to GetObjectMedia
    # should be made.
    self.assertEqual(mock_api.get_calls, num_expected_calls)
    with open(upload_file, 'rb') as upload_stream:
      with open(self.test_data_file, 'rb') as download_stream:
        self.assertEqual(upload_stream.read(), download_stream.read())

  def testDownloadWithZeroWrites(self):
    """Tests 0-byte writes to the download stream from GetObjectMedia."""
    write_values = []
    with open(self.test_data_file, 'rb') as stream:
      while True:
        write_values.append(b'')
        data = stream.read(TRANSFER_BUFFER_SIZE)
        write_values.append(b'')
        if not data:
          break
        write_values.append(data)
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url,
        self.test_data_file_len,
        mock_api,
        download_chunk_size=self.test_data_file_len)
    self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
    self.assertEqual(mock_api.get_calls, 1)
    with open(upload_file, 'rb') as upload_stream:
      with open(self.test_data_file, 'rb') as download_stream:
        self.assertEqual(upload_stream.read(), download_stream.read())

  def testDownloadWithPartialWrite(self):
    """Tests unaligned writes to the download stream from GetObjectMedia."""
    with open(self.test_data_file, 'rb') as stream:
      chunk = stream.read(TRANSFER_BUFFER_SIZE)
    # Though it may seem equivalent, the `:1` is actually necessary, without
    # it in python 3, `one_byte` would be int(77) and with it, `one_byte` is
    # the expected value of b'M' (using case where start of chunk is b'MJoTM...')
    one_byte = chunk[0:1]
    chunk_minus_one_byte = chunk[1:TRANSFER_BUFFER_SIZE]
    half_chunk = chunk[0:TRANSFER_BUFFER_SIZE // 2]

    write_values_dict = {
        'First byte first chunk unaligned':
            (one_byte, chunk_minus_one_byte, chunk, chunk),
        'Last byte first chunk unaligned': (chunk_minus_one_byte, chunk, chunk),
        'First byte second chunk unaligned':
            (chunk, one_byte, chunk_minus_one_byte, chunk),
        'Last byte second chunk unaligned':
            (chunk, chunk_minus_one_byte, one_byte, chunk),
        'First byte final chunk unaligned':
            (chunk, chunk, one_byte, chunk_minus_one_byte),
        'Last byte final chunk unaligned':
            (chunk, chunk, chunk_minus_one_byte, one_byte),
        'Half chunks': (half_chunk, half_chunk, half_chunk),
        'Many unaligned':
            (one_byte, half_chunk, one_byte, half_chunk, chunk,
             chunk_minus_one_byte, chunk, one_byte, half_chunk, one_byte)
    }
    upload_file = self.CreateTempFile()
    for case_name, write_values in six.iteritems(write_values_dict):
      expected_contents = b''
      for write_value in write_values:
        expected_contents += write_value
      mock_api = self.MockDownloadCloudApi(write_values)
      daisy_chain_wrapper = DaisyChainWrapper(
          self._dummy_url,
          len(expected_contents),
          mock_api,
          download_chunk_size=self.test_data_file_len)
      self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
      with open(upload_file, 'rb') as upload_stream:
        self.assertEqual(
            upload_stream.read(), expected_contents,
            'Uploaded file contents for case %s did not match' % case_name)

  def testSeekAndReturn(self):
    """Tests seeking to the end of the wrapper (simulates getting size)."""
    write_values = []
    with open(self.test_data_file, 'rb') as stream:
      while True:
        data = stream.read(TRANSFER_BUFFER_SIZE)
        if not data:
          break
        write_values.append(data)
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url,
        self.test_data_file_len,
        mock_api,
        download_chunk_size=self.test_data_file_len)
    with open(upload_file, 'wb') as upload_stream:
      current_position = 0
      daisy_chain_wrapper.seek(0, whence=os.SEEK_END)
      daisy_chain_wrapper.seek(current_position)
      while True:
        data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
        current_position += len(data)
        daisy_chain_wrapper.seek(0, whence=os.SEEK_END)
        daisy_chain_wrapper.seek(current_position)
        if not data:
          break
        upload_stream.write(data)
    self.assertEqual(mock_api.get_calls, 1)
    with open(upload_file, 'rb') as upload_stream:
      with open(self.test_data_file, 'rb') as download_stream:
        self.assertEqual(upload_stream.read(), download_stream.read())

  def testRestartDownloadThread(self):
    """Tests seek to non-stored position; this restarts the download thread."""
    write_values = []
    with open(self.test_data_file, 'rb') as stream:
      while True:
        data = stream.read(TRANSFER_BUFFER_SIZE)
        if not data:
          break
        write_values.append(data)
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url,
        self.test_data_file_len,
        mock_api,
        download_chunk_size=self.test_data_file_len)
    daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
    daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
    daisy_chain_wrapper.seek(0)
    self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
    self.assertEqual(mock_api.get_calls, 2)
    with open(upload_file, 'rb') as upload_stream:
      with open(self.test_data_file, 'rb') as download_stream:
        self.assertEqual(upload_stream.read(), download_stream.read())

  def testDownloadThreadException(self):
    """Tests that an exception is propagated via the upload thread."""

    class DownloadException(Exception):
      pass

    write_values = [
        b'a', b'b',
        DownloadException('Download thread forces failure')
    ]
    upload_file = self.CreateTempFile()
    mock_api = self.MockDownloadCloudApi(write_values)
    daisy_chain_wrapper = DaisyChainWrapper(
        self._dummy_url,
        self.test_data_file_len,
        mock_api,
        download_chunk_size=self.test_data_file_len)
    try:
      self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
      self.fail('Expected exception')
    except DownloadException as e:
      self.assertIn('Download thread forces failure', str(e))

  def testInvalidSeek(self):
    """Tests that seeking fails for unsupported seek arguments."""
    daisy_chain_wrapper = DaisyChainWrapper(self._dummy_url,
                                            self.test_data_file_len,
                                            self.MockDownloadCloudApi([]))
    try:
      # SEEK_CUR is invalid.
      daisy_chain_wrapper.seek(0, whence=os.SEEK_CUR)
      self.fail('Expected exception')
    except IOError as e:
      self.assertIn('does not support seek mode', str(e))

    try:
      # Seeking from the end with an offset is invalid.
      daisy_chain_wrapper.seek(1, whence=os.SEEK_END)
      self.fail('Expected exception')
    except IOError as e:
      self.assertIn('Invalid seek during daisy chain', str(e))