HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/394/lib/googlecloudsdk/api_lib/datastream/util.py
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream API utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

import uuid

from apitools.base.py import encoding as api_encoding
from googlecloudsdk.api_lib.dataproc import exceptions
from googlecloudsdk.api_lib.datastream import camel_case_utils
from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.export import util as export_util
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io
import six


_DEFAULT_API_VERSION = 'v1'
_DEFAULT_API_NAME = 'datastream'
CAMEL_CASE_CONVERSION_EVENT = _DEFAULT_API_NAME + '_camel_case_conversion'

# TODO(b/207467120): remove translation after BETA deprecation.
_UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1 = {
    'allowlist': 'include_objects',
    'rejectlist': 'exclude_objects',
    'source_connection_profile_name': 'source_connection_profile',
    'destination_connection_profile_name': 'destination_connection_profile',
    'vpc_name': 'vpc',
}
RDBMS_FIELD_NAME_BY_RELEASE_TRACK = {
    'schema': {
        base.ReleaseTrack.BETA: 'schema_name',
        base.ReleaseTrack.GA: 'schema'
    },
    'database': {
        base.ReleaseTrack.BETA: 'database_name',
        base.ReleaseTrack.GA: 'database'
    },
    'table': {
        base.ReleaseTrack.BETA: 'table_name',
        base.ReleaseTrack.GA: 'table'
    },
    'column': {
        base.ReleaseTrack.BETA: 'column_name',
        base.ReleaseTrack.GA: 'column'
    },
    'include_objects': {
        base.ReleaseTrack.BETA: 'allowlist',
        base.ReleaseTrack.GA: 'include_objects'
    },
    'exclude_objects': {
        base.ReleaseTrack.BETA: 'rejectlist',
        base.ReleaseTrack.GA: 'exclude_objects'
    }
}


def ParseMessageAndValidateSchema(config_file_path, schema_name, message_type):
  """Parses a config message and validates it's schema."""
  schema_path = export_util.GetSchemaPath(
      _DEFAULT_API_NAME, _DEFAULT_API_VERSION, schema_name, for_help=False
  )

  # NOMUTANTS -- not necessary here.
  data = console_io.ReadFromFileOrStdin(config_file_path, binary=False)
  parsed_yaml = yaml.load(data)

  message = CreateMessageWithCamelCaseConversion(
      message_type=message_type,
      parsed_yaml=parsed_yaml,
      schema_path=schema_path,
  )

  return message


def GetClientInstance(api_version=_DEFAULT_API_VERSION, no_http=False):
  return apis.GetClientInstance('datastream', api_version, no_http=no_http)


def GetMessagesModule(api_version=_DEFAULT_API_VERSION):
  return apis.GetMessagesModule('datastream', api_version)


def GetResourceParser(api_version=_DEFAULT_API_VERSION):
  resource_parser = resources.Registry()
  resource_parser.RegisterApiByName('datastream', api_version)
  return resource_parser


def ParentRef(project, location):
  """Get the resource name of the parent collection.

  Args:
    project: the project of the parent collection.
    location: the GCP region of the membership.

  Returns:
    the resource name of the parent collection in the format of
    `projects/{project}/locations/{location}`.
  """

  return 'projects/{}/locations/{}'.format(project, location)


def GenerateRequestId():
  """Generates a UUID to use as the request ID.

  Returns:
    string, the 40-character UUID for the request ID.
  """
  return six.text_type(uuid.uuid4())


def ParseMysqlRdbmsFile(
    messages, mysql_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a mysql_rdbms_file into the MysqlRdbms message."""
  if release_track != base.ReleaseTrack.BETA:
    return ParseMessageAndValidateSchema(
        mysql_rdbms_file, 'MysqlRdbms', messages.MysqlRdbms
    )

  return ParseMysqlRdbmsFileBeta(messages, mysql_rdbms_file, release_track)


def ParseOracleRdbmsFile(
    messages, oracle_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a oracle_rdbms_file into the OracleRdbms message."""
  if release_track != base.ReleaseTrack.BETA:
    return ParseMessageAndValidateSchema(
        oracle_rdbms_file, 'OracleRdbms', messages.OracleRdbms
    )

  return ParseOracleRdbmsFileBeta(messages, oracle_rdbms_file, release_track)


def ParsePostgresqlRdbmsFile(messages, postgresql_rdbms_file):
  """Parses a postgresql_rdbms_file into the PostgresqlRdbms message."""
  return ParseMessageAndValidateSchema(
      postgresql_rdbms_file, 'PostgresqlRdbms', messages.PostgresqlRdbms
  )


def ParseSqlServerRdbmsFile(messages, sqlserver_rdbms_file):
  """Parses a sqlserver_rdbms_file into the SqlServerRdbms message."""
  return ParseMessageAndValidateSchema(
      sqlserver_rdbms_file, 'SqlServerRdbms', messages.SqlServerRdbms
  )


def ParseSalesforceOrgFile(messages, salesforce_org_file):
  """Parses a salesforce_org_file into the SalesforceOrg message."""
  return ParseMessageAndValidateSchema(
      salesforce_org_file, 'SalesforceOrg', messages.SalesforceOrg
  )


def ParseMongodbFile(messages, mongodb_file):
  """Parses a mongodb_file into the MongodbCluster message."""
  return ParseMessageAndValidateSchema(
      mongodb_file, 'MongodbCluster', messages.MongodbCluster
  )


def CreateMessageWithCamelCaseConversion(
    message_type, parsed_yaml, schema_path=None
):
  """Create a message from a yaml dict.

  Similar to export_util.Import (since we convert to camel case before)
  Args:
    message_type: a Datastream message type to create.
    parsed_yaml: dict
    schema_path: str, path to the message schema to validate against.

  Returns:
    a Datastream message.
  """
  converted_yaml = camel_case_utils.ConvertYamlToCamelCase(parsed_yaml)
  if schema_path:
    # If a schema is provided, validate against it.
    export_util.ValidateYAML(converted_yaml, schema_path)
  try:
    message = api_encoding.PyValueToMessage(message_type, converted_yaml)
  except Exception as e:
    raise exceptions.ParseError('Cannot parse YAML: [{0}]'.format(e))
  return message


# TODO(b/207467120): deprecate BETA client.
def GetRDBMSV1alpha1ToV1FieldName(field, release_track):
  return RDBMS_FIELD_NAME_BY_RELEASE_TRACK.get(field, {}).get(
      release_track, field
  )


def _GetRDBMSFieldName(field, release_track):
  return RDBMS_FIELD_NAME_BY_RELEASE_TRACK.get(field, {}).get(
      release_track, field
  )


# Deprecated BETA methods - TODO(b/207467120).
# remove after full BETA deprecation.
def ParseMysqlColumn(messages, mysql_column_object, release_track):
  """Parses a raw mysql column json/yaml into the MysqlColumn message."""
  message = messages.MysqlColumn(
      column=mysql_column_object.get(
          _GetRDBMSFieldName('column', release_track), ''))
  data_type = mysql_column_object.get('data_type')
  if data_type is not None:
    message.dataType = data_type
  collation = mysql_column_object.get('collation')
  if collation is not None:
    message.collation = collation
  length = mysql_column_object.get('length')
  if length is not None:
    message.length = length
  nullable = mysql_column_object.get('nullable')
  if nullable is not None:
    message.nullable = nullable
  ordinal_position = mysql_column_object.get('ordinal_position')
  if ordinal_position is not None:
    message.ordinalPosition = ordinal_position
  primary_key = mysql_column_object.get('primary_key')
  if primary_key is not None:
    message.primaryKey = primary_key
  return message


def ParseMysqlTable(messages, mysql_table_object, release_track):
  """Parses a raw mysql table json/yaml into the MysqlTable message."""
  mysql_column_msg_list = []
  for column in mysql_table_object.get('mysql_columns', []):
    mysql_column_msg_list.append(
        ParseMysqlColumn(messages, column, release_track))
  table_key = _GetRDBMSFieldName('table', release_track)
  table_name = mysql_table_object.get(table_key)
  if not table_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   table_key)
  return messages.MysqlTable(
      table=table_name, mysqlColumns=mysql_column_msg_list)


def ParseMysqlDatabase(messages, mysql_database_object, release_track):
  """Parses a raw mysql database json/yaml into the MysqlDatabase message."""
  mysql_tables_msg_list = []
  for table in mysql_database_object.get('mysql_tables', []):
    mysql_tables_msg_list.append(
        ParseMysqlTable(messages, table, release_track))
  database_key = _GetRDBMSFieldName('database', release_track)
  database_name = mysql_database_object.get(database_key)
  if not database_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   database_key)
  return messages.MysqlDatabase(
      database=database_name, mysqlTables=mysql_tables_msg_list)


def ParseMysqlSchemasListToMysqlRdbmsMessage(messages,
                                             mysql_rdbms_data,
                                             release_track=base.ReleaseTrack
                                             .BETA):
  """Parses an object of type {mysql_databases: [...]} into the MysqlRdbms message."""
  mysql_databases_raw = mysql_rdbms_data.get('mysql_databases', [])
  mysql_database_msg_list = []
  for schema in mysql_databases_raw:
    mysql_database_msg_list.append(
        ParseMysqlDatabase(messages, schema, release_track))

  mysql_rdbms_msg = messages.MysqlRdbms(
      mysqlDatabases=mysql_database_msg_list)
  return mysql_rdbms_msg


def ParseOracleColumn(messages, oracle_column_object, release_track):
  """Parses a raw oracle column json/yaml into the OracleColumn message."""
  message = messages.OracleColumn(
      column=oracle_column_object.get(
          _GetRDBMSFieldName('column', release_track), ''))
  data_type = oracle_column_object.get('data_type')
  if data_type is not None:
    message.dataType = data_type
  encoding = oracle_column_object.get('encoding')
  if encoding is not None:
    message.encoding = encoding
  length = oracle_column_object.get('length')
  if length is not None:
    message.length = length
  nullable = oracle_column_object.get('nullable')
  if nullable is not None:
    message.nullable = nullable
  ordinal_position = oracle_column_object.get('ordinal_position')
  if ordinal_position is not None:
    message.ordinalPosition = ordinal_position
  precision = oracle_column_object.get('precision')
  if precision is not None:
    message.precision = precision
  primary_key = oracle_column_object.get('primary_key')
  if primary_key is not None:
    message.primaryKey = primary_key
  scale = oracle_column_object.get('scale')
  if scale is not None:
    message.scale = scale
  return message


def ParseOracleTable(messages, oracle_table_object, release_track):
  """Parses a raw oracle table json/yaml into the OracleTable message."""
  oracle_columns_msg_list = []
  for column in oracle_table_object.get('oracle_columns', []):
    oracle_columns_msg_list.append(
        ParseOracleColumn(messages, column, release_track))
  table_key = _GetRDBMSFieldName('table', release_track)
  table_name = oracle_table_object.get(table_key)
  if not table_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   table_key)
  return messages.OracleTable(
      table=table_name, oracleColumns=oracle_columns_msg_list)


def ParseOracleSchema(messages, oracle_schema_object, release_track):
  """Parses a raw oracle schema json/yaml into the OracleSchema message."""
  oracle_tables_msg_list = []
  for table in oracle_schema_object.get('oracle_tables', []):
    oracle_tables_msg_list.append(
        ParseOracleTable(messages, table, release_track))
  schema_key = _GetRDBMSFieldName('schema', release_track)
  schema_name = oracle_schema_object.get(schema_key)
  if not schema_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "%s".' %
                                   schema_key)
  return messages.OracleSchema(
      schema=schema_name, oracleTables=oracle_tables_msg_list)


def ParseOracleSchemasListToOracleRdbmsMessage(messages,
                                               oracle_rdbms_data,
                                               release_track=base.ReleaseTrack
                                               .BETA):
  """Parses an object of type {oracle_schemas: [...]} into the OracleRdbms message."""
  oracle_schemas_raw = oracle_rdbms_data.get('oracle_schemas', [])
  oracle_schema_msg_list = []
  for schema in oracle_schemas_raw:
    oracle_schema_msg_list.append(
        ParseOracleSchema(messages, schema, release_track))

  oracle_rdbms_msg = messages.OracleRdbms(
      oracleSchemas=oracle_schema_msg_list)
  return oracle_rdbms_msg


def ParsePostgresqlColumn(messages, postgresql_column_object):
  """Parses a raw postgresql column json/yaml into the PostgresqlColumn message."""
  message = messages.PostgresqlColumn(
      column=postgresql_column_object.get('column', ''))
  data_type = postgresql_column_object.get('data_type')
  if data_type is not None:
    message.dataType = data_type
  length = postgresql_column_object.get('length')
  if length is not None:
    message.length = length
  precision = postgresql_column_object.get('precision')
  if precision is not None:
    message.precision = precision
  scale = postgresql_column_object.get('scale')
  if scale is not None:
    message.scale = scale
  primary_key = postgresql_column_object.get('primary_key')
  if primary_key is not None:
    message.primaryKey = primary_key
  nullable = postgresql_column_object.get('nullable')
  if nullable is not None:
    message.nullable = nullable
  ordinal_position = postgresql_column_object.get('ordinal_position')
  if ordinal_position is not None:
    message.ordinalPosition = ordinal_position
  return message


def ParsePostgresqlTable(messages, postgresql_table_object):
  """Parses a raw postgresql table json/yaml into the PostgresqlTable message."""
  postgresql_columns_msg_list = []
  for column in postgresql_table_object.get('postgresql_columns', []):
    postgresql_columns_msg_list.append(ParsePostgresqlColumn(messages, column))
  table_name = postgresql_table_object.get('table')
  if not table_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "table".')
  return messages.PostgresqlTable(
      table=table_name, postgresqlColumns=postgresql_columns_msg_list)


def ParsePostgresqlSchema(messages, postgresql_schema_object):
  """Parses a raw postgresql schema json/yaml into the PostgresqlSchema message."""
  postgresql_tables_msg_list = []
  for table in postgresql_schema_object.get('postgresql_tables', []):
    postgresql_tables_msg_list.append(ParsePostgresqlTable(messages, table))
  schema_name = postgresql_schema_object.get('schema')
  if not schema_name:
    raise ds_exceptions.ParseError('Cannot parse YAML: missing key "schema".')
  return messages.PostgresqlSchema(
      schema=schema_name, postgresqlTables=postgresql_tables_msg_list)


def ParsePostgresqlSchemasListToPostgresqlRdbmsMessage(messages,
                                                       postgresql_rdbms_data):
  """Parses an object of type {postgresql_schemas: [...]} into the PostgresqlRdbms message."""
  postgresql_schemas_raw = postgresql_rdbms_data.get('postgresql_schemas', [])
  postgresql_schema_msg_list = []
  for schema in postgresql_schemas_raw:
    postgresql_schema_msg_list.append(ParsePostgresqlSchema(messages, schema))

  postgresql_rdbms_msg = messages.PostgresqlRdbms(
      postgresqlSchemas=postgresql_schema_msg_list)
  return postgresql_rdbms_msg


def ParseMongodbField(messages, mongodb_field_object):
  """Parses a raw mongodb field json/yaml into the MongodbField message."""
  return messages.MongodbField(field=mongodb_field_object.get('field', ''))


def ParseMongodbCollection(messages, mongodb_collection_object):
  """Parses a raw mongodb database json/yaml into the MongodbCollection message."""
  fields_raw = mongodb_collection_object.get('fields', [])
  fields_msg_list = []
  for field in fields_raw:
    fields_msg_list.append(ParseMongodbField(messages, field))
  return messages.MongodbCollection(
      collection=mongodb_collection_object.get('collection', ''),
      fields=fields_msg_list,
  )


def ParseMongodbDatabase(messages, mongodb_database_object):
  """Parses a raw mongodb database json/yaml into the MongodbDatabase message."""
  collections_raw = mongodb_database_object.get('collections', [])
  collections_msg_list = []
  for collection in collections_raw:
    collections_msg_list.append(ParseMongodbCollection(messages, collection))
  return messages.MongodbDatabase(collections=collections_msg_list)


def ParseMongodbCluster(messages, mongodb_cluster_object):
  """Parses a raw mongodb cluster json/yaml into the MongodbCluster message."""
  databases_raw = mongodb_cluster_object.get('databases', [])
  databases_msg_list = []
  for database in databases_raw:
    databases_msg_list.append(ParseMongodbDatabase(messages, database))
  return messages.MongodbCluster(databases=databases_msg_list)


def UpdateV1alpha1ToV1MaskFields(field_mask):
  """Updates field mask paths according to the v1alpha1 > v1 Datastream API change.

  This allows for backwards compatibility with the current client field
  mask.

  Args:
    field_mask: List[str], list of stream fields to update

  Returns:
    updated_field_mask: List[str] field mask with fields translated
      from v1alpha1 API to v1.
  """
  updated_field_mask = []
  for path in field_mask:
    field_to_translate = None
    for field in _UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1:
      if field in path:
        field_to_translate = field
        break
    if field_to_translate:
      updated_field_mask.append(
          path.replace(
              field_to_translate,
              _UPDATE_MASK_FIELD_TRANSLATION_V1ALPHA1_TO_V1[field_to_translate])
      )
    else:
      updated_field_mask.append(path)
  return updated_field_mask


def ParseMysqlRdbmsFileBeta(
    messages, mysql_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a mysql_rdbms_file into the MysqlRdbms message. deprecated."""

  data = console_io.ReadFromFileOrStdin(mysql_rdbms_file, binary=False)
  try:
    mysql_rdbms_head_data = yaml.load(data)
  except Exception as e:
    raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))

  mysql_rdbms_data = mysql_rdbms_head_data.get(
      'mysql_rdbms', mysql_rdbms_head_data
  )
  return ParseMysqlSchemasListToMysqlRdbmsMessage(
      messages, mysql_rdbms_data, release_track
  )


def ParseOracleRdbmsFileBeta(
    messages, oracle_rdbms_file, release_track=base.ReleaseTrack.BETA
):
  """Parses a oracle_rdbms_file into the OracleRdbms message. deprecated."""
  data = console_io.ReadFromFileOrStdin(oracle_rdbms_file, binary=False)
  try:
    oracle_rdbms_head_data = yaml.load(data)
  except Exception as e:
    raise ds_exceptions.ParseError('Cannot parse YAML:[{0}]'.format(e))

  oracle_rdbms_data = oracle_rdbms_head_data.get(
      'oracle_rdbms', oracle_rdbms_head_data
  )
  return ParseOracleSchemasListToOracleRdbmsMessage(
      messages, oracle_rdbms_data, release_track
  )