HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/394/lib/googlecloudsdk/api_lib/datastream/connection_profiles.py
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream connection profiles API."""


from apitools.base.py import list_pager
from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions
from googlecloudsdk.api_lib.datastream import util
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.calliope.arg_parsers import HostPort
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io


def GetConnectionProfileURI(resource):
  """Returns the self link of a connection profile resource.

  Args:
    resource: object whose `name` attribute holds the relative name of a
      connection profile.

  Returns:
    The value of `SelfLink()` on the reference parsed from `resource.name`
    in the `datastream.projects.locations.connectionProfiles` collection.
  """
  collection_name = 'datastream.projects.locations.connectionProfiles'
  profile_ref = resources.REGISTRY.ParseRelativeName(
      resource.name, collection=collection_name)
  return profile_ref.SelfLink()


class ConnectionProfilesClient:
  """Client for connection profiles service in the API."""

  def __init__(self, client=None, messages=None):
    """Stores the API handles used by the other methods.

    Args:
      client: API client to use; when falsy, `util.GetClientInstance()` is
        called instead.
      messages: messages module to use; when falsy,
        `util.GetMessagesModule()` is called instead.
    """
    if not client:
      client = util.GetClientInstance()
    if not messages:
      messages = util.GetMessagesModule()
    self._client = client
    self._messages = messages
    self._service = client.projects_locations_connectionProfiles
    self._resource_parser = util.GetResourceParser()

  def _ValidateArgs(self, args):
    """Validates the given args; currently only the SSL config args.

    Args:
      args: parsed arguments, forwarded unchanged to
        `_ValidateSslConfigArgs`.
    """
    self._ValidateSslConfigArgs(args)

  def _ValidateSslConfigArgs(self, args):
    """Validates Format of all SSL config args."""
    self._ValidateCertificateFormat(args.ca_certificate, 'CA certificate')
    self._ValidateCertificateFormat(args.client_certificate,
                                    'client certificate')
    self._ValidateCertificateFormat(args.client_key, 'client key')

    # Validation for all Postgresql SSL config fields.
    self._ValidateCertificateFormat(
        args.postgresql_ca_certificate, 'Postgresql CA certificate'
    )
    self._ValidateCertificateFormat(
        args.postgresql_client_certificate, 'Postgresql client certificate'
    )
    self._ValidateCertificateFormat(
        args.postgresql_client_key, 'Postgresql client private key'
    )

    # Validation for Oracle SSL config fields.
    self._ValidateCertificateFormat(
        args.oracle_ca_certificate, 'Oracle CA certificate'
    )

  def _ValidateCertificateFormat(self, certificate, name):
    if not certificate:
      return True
    cert = certificate.strip()
    cert_lines = cert.split('\n')
    if (not cert_lines[0].startswith('-----') or
        not cert_lines[-1].startswith('-----')):
      raise exceptions.InvalidArgumentException(
          name,
          'The certificate does not appear to be in PEM format: \n{0}'.format(
              cert))

  def _GetSslConfig(self, args):
    return self._messages.MysqlSslConfig(
        clientKey=args.client_key,
        clientCertificate=args.client_certificate,
        caCertificate=args.ca_certificate)

  def _GetMySqlProfile(self, args):
    ssl_config = self._GetSslConfig(args)
    return self._messages.MysqlProfile(
        hostname=args.mysql_hostname,
        port=args.mysql_port,
        username=args.mysql_username,
        password=args.mysql_password,
        secretManagerStoredPassword=args.mysql_secret_manager_stored_password,
        sslConfig=ssl_config)

  def _GetOracleProfile(self, args):
    ssl_config = self._GetOracleSslConfig(args)
    return self._messages.OracleProfile(
        hostname=args.oracle_hostname,
        port=args.oracle_port,
        username=args.oracle_username,
        password=args.oracle_password,
        secretManagerStoredPassword=args.oracle_secret_manager_stored_password,
        databaseService=args.database_service,
        oracleSslConfig=ssl_config)

  def _GetOracleSslConfig(self, args):
    """Returns a OracleSslConfig message based on the given args."""
    return self._messages.OracleSslConfig(
        caCertificate=args.oracle_ca_certificate,
        serverCertificateDistinguishedName=args.oracle_server_certificate_distinguished_name,
    )

  def _GetPostgresqlSslConfig(self, args):
    """Returns a PostgresqlSslConfig message based on the given args."""
    if args.postgresql_client_certificate or args.postgresql_client_key:
      return self._messages.PostgresqlSslConfig(
          serverAndClientVerification=self._messages.ServerAndClientVerification(
              clientCertificate=args.postgresql_client_certificate,
              clientKey=args.postgresql_client_key,
              caCertificate=args.postgresql_ca_certificate,
              serverCertificateHostname=args.postgresql_server_certificate_hostname,
          )
      )

    if args.postgresql_ca_certificate:
      return self._messages.PostgresqlSslConfig(
          serverVerification=self._messages.ServerVerification(
              caCertificate=args.postgresql_ca_certificate,
              serverCertificateHostname=args.postgresql_server_certificate_hostname,
          )
      )

    return None

  def _GetPostgresqlProfile(self, args):
    ssl_config = self._GetPostgresqlSslConfig(args)
    return self._messages.PostgresqlProfile(
        hostname=args.postgresql_hostname,
        port=args.postgresql_port,
        username=args.postgresql_username,
        password=args.postgresql_password,
        secretManagerStoredPassword=args.postgresql_secret_manager_stored_password,
        database=args.postgresql_database,
        sslConfig=ssl_config)

  def _GetSqlServerProfile(self, args):
    return self._messages.SqlServerProfile(
        hostname=args.sqlserver_hostname,
        port=args.sqlserver_port,
        username=args.sqlserver_username,
        password=args.sqlserver_password,
        secretManagerStoredPassword=args.sqlserver_secret_manager_stored_password,
        database=args.sqlserver_database,
    )

  def _GetSalesforceProfile(self, args):
    if args.salesforce_oauth2_client_id:
      return self._messages.SalesforceProfile(
          domain=args.salesforce_domain,
          oauth2ClientCredentials=self._messages.Oauth2ClientCredentials(
              clientId=args.salesforce_oauth2_client_id,
              clientSecret=args.salesforce_oauth2_client_secret,
              secretManagerStoredClientSecret=args.salesforce_secret_manager_stored_oauth2_client_secret,
          ),
      )
    else:
      return self._messages.SalesforceProfile(
          domain=args.salesforce_domain,
          userCredentials=self._messages.UserCredentials(
              username=args.salesforce_username,
              password=args.salesforce_password,
              secretManagerStoredPassword=args.salesforce_secret_manager_stored_password,
              securityToken=args.salesforce_security_token,
              secretManagerStoredSecurityToken=args.salesforce_secret_manager_stored_security_token,
          ),
      )

  def _GetGCSProfile(self, args, release_track):
    """Returns a GcsProfile message built from the bucket and root path args.

    Args:
      args: parsed arguments exposing `root_path` and, depending on the
        track, `bucket_name` (BETA) or `bucket` (any other track).
      release_track: the release track the command runs in.

    Returns:
      A GcsProfile whose rootPath is `args.root_path`, or '/' when that is
      falsy.
    """
    # TODO(b/207467120): remove bucket_name arg check.
    is_beta = release_track == base.ReleaseTrack.BETA
    bucket_name = args.bucket_name if is_beta else args.bucket

    profile = self._messages.GcsProfile(bucket=bucket_name)
    profile.rootPath = args.root_path or '/'
    return profile

  def _GetMongodbProfile(self, args):
    """Returns the MongoDB profile message based on the given args."""
    addresses = []
    for host_address in args.mongodb_host_addresses:
      if args.mongodb_srv_connection_format:
        addresses.append(
            self._messages.HostAddress(hostname=host_address)
        )
      else:
        hostport = HostPort.Parse(host_address)
        addresses.append(
            self._messages.HostAddress(
                hostname=hostport.host, port=int(hostport.port)
            )
        )
    profile = self._messages.MongodbProfile(
        hostAddresses=addresses,
        username=args.mongodb_username,
        replicaSet=args.mongodb_replica_set,
        password=args.mongodb_password,
        secretManagerStoredPassword=args.mongodb_secret_manager_stored_password,
    )
    if (
        args.mongodb_direct_connection
        and not args.mongodb_standard_connection_format
    ):
      raise exceptions.InvalidArgumentException(
          'mongodb-direct-connection',
          'mongodb direct connection can only be used with the standard'
          ' connection format.',
      )
    if args.mongodb_srv_connection_format:
      profile.srvConnectionFormat = {}
    if args.mongodb_standard_connection_format:
      profile.standardConnectionFormat = (
          self._messages.StandardConnectionFormat(
              directConnection=args.mongodb_direct_connection
          )
      )
    if args.mongodb_tls:
      profile.sslConfig = {}
      if args.mongodb_ca_certificate:
        profile.sslConfig.caCertificate = args.mongodb_ca_certificate
    return profile

  def _ParseSslConfig(self, data):
    return self._messages.MysqlSslConfig(
        clientKey=data.get('client_key'),
        clientCertificate=data.get('client_certificate'),
        caCertificate=data.get('ca_certificate'))

  def _ParseMySqlProfile(self, data):
    if not data:
      return {}
    ssl_config = self._ParseSslConfig(data)
    return self._messages.MysqlProfile(
        hostname=data.get('hostname'),
        port=data.get('port'),
        username=data.get('username'),
        password=data.get('password'),
        sslConfig=ssl_config)

  def _ParseOracleProfile(self, data):
    if not data:
      return {}
    return self._messages.OracleProfile(
        hostname=data.get('hostname'),
        port=data.get('port'),
        username=data.get('username'),
        password=data.get('password'),
        databaseService=data.get('database_service'))

  def _ParsePostgresqlProfile(self, data):
    if not data:
      return {}
    return self._messages.PostgresqlProfile(
        hostname=data.get('hostname'),
        port=data.get('port'),
        username=data.get('username'),
        password=data.get('password'),
        database=data.get('database'))

  def _ParseSqlServerProfile(self, data):
    if not data:
      return {}
    return self._messages.SqlServerProfile(
        hostname=data.get('hostname'),
        port=data.get('port'),
        username=data.get('username'),
        password=data.get('password'),
        database=data.get('database'),
    )

  def _ParseGCSProfile(self, data):
    if not data:
      return {}
    return self._messages.GcsProfile(
        bucket=data.get('bucket_name'), rootPath=data.get('root_path'))

  def _GetForwardSshTunnelConnectivity(self, args):
    return self._messages.ForwardSshTunnelConnectivity(
        hostname=args.forward_ssh_hostname,
        port=args.forward_ssh_port,
        username=args.forward_ssh_username,
        privateKey=args.forward_ssh_private_key,
        password=args.forward_ssh_password)

  def _GetConnectionProfile(self, cp_type, connection_profile_id, args,
                            release_track):
    """Builds a ConnectionProfile message for the requested profile type.

    Args:
      cp_type: profile type name, one of the keys of the table below.
      connection_profile_id: value stored in the message's `name` field.
      args: parsed arguments the type-specific builders read from.
      release_track: release track; selects the GCS bucket argument and the
        private connection concept name.

    Returns:
      The ConnectionProfile message with the type-specific profile and, when
      one is given, the connectivity method filled in.

    Raises:
      exceptions.InvalidArgumentException: if `cp_type` is not supported.
    """
    msgs = self._messages
    connection_profile_obj = msgs.ConnectionProfile(
        name=connection_profile_id,
        labels=labels_util.ParseCreateArgs(
            args, msgs.ConnectionProfile.LabelsValue),
        displayName=args.display_name)

    # cp_type -> (message field, builder). The builders are lambdas so that
    # only the selected one runs and only its message class is looked up.
    profile_builders = {
        'MYSQL': ('mysqlProfile', lambda: self._GetMySqlProfile(args)),
        'ORACLE': ('oracleProfile', lambda: self._GetOracleProfile(args)),
        'POSTGRESQL': (
            'postgresqlProfile', lambda: self._GetPostgresqlProfile(args)),
        'SQLSERVER': (
            'sqlServerProfile', lambda: self._GetSqlServerProfile(args)),
        'GOOGLE-CLOUD-STORAGE': (
            'gcsProfile', lambda: self._GetGCSProfile(args, release_track)),
        'BIGQUERY': ('bigqueryProfile', lambda: msgs.BigQueryProfile()),
        'SALESFORCE': (
            'salesforceProfile', lambda: self._GetSalesforceProfile(args)),
        'MONGODB': ('mongodbProfile', lambda: self._GetMongodbProfile(args)),
    }
    if cp_type not in profile_builders:
      raise exceptions.InvalidArgumentException(
          cp_type,
          'The connection profile type {0} is either unknown or not supported'
          ' yet.'.format(cp_type),
      )
    profile_field, build_profile = profile_builders[cp_type]
    setattr(connection_profile_obj, profile_field, build_profile())

    # TODO(b/207467120): deprecate BETA client.
    if release_track == base.ReleaseTrack.BETA:
      private_connection_concept = args.CONCEPTS.private_connection_name
    else:
      private_connection_concept = args.CONCEPTS.private_connection
    private_connectivity_ref = private_connection_concept.Parse()

    # At most one connectivity method is set; private connectivity wins over
    # the SSH tunnel, which wins over static IP.
    if private_connectivity_ref:
      connection_profile_obj.privateConnectivity = msgs.PrivateConnectivity(
          privateConnection=private_connectivity_ref.RelativeName())
    elif args.forward_ssh_hostname:
      connection_profile_obj.forwardSshConnectivity = (
          self._GetForwardSshTunnelConnectivity(args))
    elif args.static_ip_connectivity:
      connection_profile_obj.staticServiceIpConnectivity = {}

    return connection_profile_obj

  def _ParseConnectionProfileObjectFile(
      self, connection_profile_object_file, release_track
  ):
    """Parses a connection-profile-file into the ConnectionProfile message.

    For every track except BETA the work is delegated to
    `util.ParseMessageAndValidateSchema`. For BETA the file is read as YAML
    and mapped by hand onto the message.

    Args:
      connection_profile_object_file: path passed to
        `console_io.ReadFromFileOrStdin` (BETA) or to the schema-validating
        parser (other tracks).
      release_track: release track the command runs in.

    Returns:
      The ConnectionProfile message described by the file.

    Raises:
      ds_exceptions.ParseError: if the YAML cannot be loaded, if its top
        level is not a mapping, or if no connectivity method key is present.
    """
    if release_track != base.ReleaseTrack.BETA:
      return util.ParseMessageAndValidateSchema(
          connection_profile_object_file,
          'ConnectionProfile',
          self._messages.ConnectionProfile,
      )

    data = console_io.ReadFromFileOrStdin(
        connection_profile_object_file, binary=False)
    try:
      connection_profile_data = yaml.load(data)
    except Exception as e:
      # Keep the original failure attached as the cause for debugging.
      raise ds_exceptions.ParseError(
          'Cannot parse YAML:[{0}]'.format(e)) from e

    # An empty file loads as None and a top-level list or scalar is not a
    # mapping either; report that as a parse error instead of failing on
    # `.get` below with an AttributeError.
    if not isinstance(connection_profile_data, dict):
      raise ds_exceptions.ParseError(
          'Cannot parse YAML: expected a mapping at the top level.'
      )

    display_name = connection_profile_data.get('display_name')
    labels = connection_profile_data.get('labels')
    connection_profile_msg = self._messages.ConnectionProfile(
        displayName=display_name,
        labels=labels)

    oracle_profile = self._ParseOracleProfile(
        connection_profile_data.get('oracle_profile', {}))
    mysql_profile = self._ParseMySqlProfile(
        connection_profile_data.get('mysql_profile', {}))
    postgresql_profile = self._ParsePostgresqlProfile(
        connection_profile_data.get('postgresql_profile', {}))
    sqlserver_profile = self._ParseSqlServerProfile(
        connection_profile_data.get('sqlserver_profile', {})
    )
    gcs_profile = self._ParseGCSProfile(
        connection_profile_data.get('gcs_profile', {}))
    # Only the first truthy profile is kept, in this fixed order.
    if oracle_profile:
      connection_profile_msg.oracleProfile = oracle_profile
    elif mysql_profile:
      connection_profile_msg.mysqlProfile = mysql_profile
    elif postgresql_profile:
      connection_profile_msg.postgresqlProfile = postgresql_profile
    elif sqlserver_profile:
      connection_profile_msg.sqlServerProfile = sqlserver_profile
    elif gcs_profile:
      connection_profile_msg.gcsProfile = gcs_profile

    # Exactly one connectivity method is copied; the first key present wins.
    if 'static_service_ip_connectivity' in connection_profile_data:
      connection_profile_msg.staticServiceIpConnectivity = (
          connection_profile_data.get('static_service_ip_connectivity')
      )
    elif 'forward_ssh_connectivity' in connection_profile_data:
      connection_profile_msg.forwardSshConnectivity = (
          connection_profile_data.get('forward_ssh_connectivity')
      )
    elif 'private_connectivity' in connection_profile_data:
      connection_profile_msg.privateConnectivity = connection_profile_data.get(
          'private_connectivity'
      )
    else:
      raise ds_exceptions.ParseError(
          'Cannot parse YAML: missing connectivity method.'
      )

    return connection_profile_msg

  def _UpdateForwardSshTunnelConnectivity(
      self, connection_profile, args, update_fields
  ):
    """Updates Forward SSH tunnel connectivity config."""
    if args.IsSpecified('forward_ssh_hostname'):
      connection_profile.forwardSshConnectivity.hostname = (
          args.forward_ssh_hostname
      )
      update_fields.append('forwardSshConnectivity.hostname')
    if args.IsSpecified('forward_ssh_port'):
      connection_profile.forwardSshConnectivity.port = args.forward_ssh_port
      update_fields.append('forwardSshConnectivity.port')
    if args.IsSpecified('forward_ssh_username'):
      connection_profile.forwardSshConnectivity.username = (
          args.forward_ssh_username
      )
      update_fields.append('forwardSshConnectivity.username')
    if args.IsSpecified('forward_ssh_private_key'):
      connection_profile.forwardSshConnectivity.privateKey = (
          args.forward_ssh_private_key
      )
      update_fields.append('forwardSshConnectivity.privateKey')
    if args.IsSpecified('forward_ssh_password'):
      connection_profile.forwardSshConnectivity.privateKey = (
          args.forward_ssh_password
      )
      update_fields.append('forwardSshConnectivity.password')

  def _UpdateGCSProfile(
      self, connection_profile, release_track, args, update_fields
  ):
    """Applies the specified bucket and root path to a GCS profile.

    Args:
      connection_profile: profile whose `gcsProfile` is mutated in place.
      release_track: BETA reads `bucket_name`, GA reads `bucket`; on any
        other track the bucket is left alone.
      args: parsed arguments exposing `IsSpecified`, the bucket attribute for
        the track, and `root_path`.
      update_fields: list of field-mask paths, extended in place.
    """
    # TODO(b/207467120): remove bucket_name arg check.
    tracks = base.ReleaseTrack
    if release_track == tracks.BETA:
      bucket_arg = 'bucket_name'
    elif release_track == tracks.GA:
      bucket_arg = 'bucket'
    else:
      bucket_arg = None

    if bucket_arg is not None and args.IsSpecified(bucket_arg):
      connection_profile.gcsProfile.bucket = getattr(args, bucket_arg)
      update_fields.append('gcsProfile.bucket')

    if args.IsSpecified('root_path'):
      connection_profile.gcsProfile.rootPath = args.root_path
      update_fields.append('gcsProfile.rootPath')

  def _UpdateOracleProfile(self,
                           connection_profile,
                           args,
                           update_fields):
    """Updates Oracle connection profile."""
    if args.IsSpecified('oracle_hostname'):
      connection_profile.oracleProfile.hostname = args.oracle_hostname
      update_fields.append('oracleProfile.hostname')
    if args.IsSpecified('oracle_port'):
      connection_profile.oracleProfile.port = args.oracle_port
      update_fields.append('oracleProfile.port')
    if args.IsSpecified('oracle_username'):
      connection_profile.oracleProfile.username = args.oracle_username
      update_fields.append('oracleProfile.username')
    if args.IsSpecified('oracle_password') or args.IsSpecified(
        'oracle_secret_manager_stored_password'
    ):
      connection_profile.oracleProfile.password = args.oracle_password
      connection_profile.oracleProfile.secretManagerStoredPassword = (
          args.oracle_secret_manager_stored_password
      )
      update_fields.append('oracleProfile.password')
      update_fields.append('oracleProfile.secretManagerStoredPassword')
    if args.IsSpecified('database_service'):
      connection_profile.oracleProfile.databaseService = args.database_service
      update_fields.append('oracleProfile.databaseService')

  def _UpdateMysqlSslConfig(self, connection_profile, args, update_fields):
    """Updates Mysql SSL config."""
    if args.IsSpecified('client_key'):
      connection_profile.mysqlProfile.sslConfig.clientKey = args.client_key
      update_fields.append('mysqlProfile.sslConfig.clientKey')
    if args.IsSpecified('client_certificate'):
      connection_profile.mysqlProfile.sslConfig.clientCertificate = (
          args.client_certificate
      )
      update_fields.append('mysqlProfile.sslConfig.clientCertificate')
    if args.IsSpecified('ca_certificate'):
      connection_profile.mysqlProfile.sslConfig.caCertificate = (
          args.ca_certificate
      )
      update_fields.append('mysqlProfile.sslConfig.caCertificate')

  def _UpdateMySqlProfile(self, connection_profile, args, update_fields):
    """Updates MySQL connection profile."""
    if args.IsSpecified('mysql_hostname'):
      connection_profile.mysqlProfile.hostname = args.mysql_hostname
      update_fields.append('mysqlProfile.hostname')
    if args.IsSpecified('mysql_port'):
      connection_profile.mysqlProfile.port = args.mysql_port
      update_fields.append('mysqlProfile.port')
    if args.IsSpecified('mysql_username'):
      connection_profile.mysqlProfile.username = args.mysql_username
      update_fields.append('mysqlProfile.username')
    if args.IsSpecified('mysql_password') or args.IsSpecified(
        'mysql_secret_manager_stored_password'
    ):
      connection_profile.mysqlProfile.password = args.mysql_password
      connection_profile.mysqlProfile.secretManagerStoredPassword = (
          args.mysql_secret_manager_stored_password
      )
      update_fields.append('mysqlProfile.password')
      update_fields.append('mysqlProfile.secretManagerStoredPassword')

    self._UpdateMysqlSslConfig(connection_profile, args, update_fields)

  def _UpdatePostgresqlSslConfig(self, connection_profile, args, update_fields):
    """Updates Postgresql SSL config."""
    if args.IsSpecified('postgresql_client_certificate'):
      connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate = (
          args.postgresql_client_certificate
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate'
      )

    if args.IsSpecified('postgresql_client_key'):
      connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientKey = (
          args.postgresql_client_key
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverAndClientVerification.clientKey'
      )

    if args.IsSpecified('postgresql_ca_certificate'):
      if connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification:
        connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate = (
            args.postgresql_ca_certificate
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate'
        )
      else:
        connection_profile.postgresqlProfile.sslConfig.serverVerification.caCertificate = (
            args.postgresql_ca_certificate
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverVerification.caCertificate'
        )
    if args.IsSpecified('postgresql_server_certificate_hostname'):
      if (
          connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification
      ):
        connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname = (
            args.postgresql_server_certificate_hostname
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname'
        )
      else:
        connection_profile.postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname = (
            args.postgresql_server_certificate_hostname
        )
        update_fields.append(
            'postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname'
        )

  def _UpdatePostgresqlProfile(self, connection_profile, args, update_fields):
    """Updates Postgresql connection profile."""
    if args.IsSpecified('postgresql_hostname'):
      connection_profile.postgresqlProfile.hostname = args.postgresql_hostname
      update_fields.append('postgresqlProfile.hostname')
    if args.IsSpecified('postgresql_port'):
      connection_profile.postgresqlProfile.port = args.postgresql_port
      update_fields.append('postgresqlProfile.port')
    if args.IsSpecified('postgresql_username'):
      connection_profile.postgresqlProfile.username = args.postgresql_username
      update_fields.append('postgresqlProfile.username')
    if args.IsSpecified('postgresql_password') or args.IsSpecified(
        'postgresql_secret_manager_stored_password'
    ):
      connection_profile.postgresqlProfile.password = args.postgresql_password
      connection_profile.postgresqlProfile.secretManagerStoredPassword = (
          args.postgresql_secret_manager_stored_password
      )
      update_fields.append('postgresqlProfile.password')
      update_fields.append('postgresqlProfile.secretManagerStoredPassword')
    if args.IsSpecified('postgresql_database'):
      connection_profile.postgresqlProfile.database = args.postgresql_database
      update_fields.append('postgresqlProfile.database')

    self._UpdatePostgresqlSslConfig(connection_profile, args, update_fields)

  def _UpdateSqlServerProfile(self, connection_profile, args, update_fields):
    """Updates SqlServer connection profile."""
    if args.IsSpecified('sqlserver_hostname'):
      connection_profile.sqlServerProfile.hostname = args.sqlserver_hostname
      update_fields.append('sqlServerProfile.hostname')
    if args.IsSpecified('sqlserver_port'):
      connection_profile.sqlServerProfile.port = args.sqlserver_port
      update_fields.append('sqlServerProfile.port')
    if args.IsSpecified('sqlserver_username'):
      connection_profile.sqlServerProfile.username = args.sqlserver_username
      update_fields.append('sqlServerProfile.username')
    if args.IsSpecified('sqlserver_password') or args.IsSpecified(
        'sqlserver_secret_manager_stored_password'
    ):
      connection_profile.sqlServerProfile.password = args.sqlserver_password
      connection_profile.sqlServerProfile.secretManagerStoredPassword = (
          args.sqlserver_secret_manager_stored_password
      )
      update_fields.append('sqlServerProfile.password')
      update_fields.append('sqlServerProfile.secretManagerStoredPassword')
    if args.IsSpecified('sqlserver_database'):
      connection_profile.sqlServerProfile.database = args.sqlserver_database
      update_fields.append('sqlServerProfile.database')

  def _UpdateSalesforceProfile(self, connection_profile, args, update_fields):
    """Updates Salesforce connection profile."""
    if args.IsSpecified('salesforce_domain'):
      connection_profile.salesforceProfile.domain = args.salesforce_domain
      update_fields.append('salesforceProfile.domain')
    if args.IsSpecified('salesforce_username'):
      connection_profile.salesforceProfile.userCredentials.username = (
          args.salesforce_username
      )
      update_fields.append('salesforceProfile.userCredentials.username')
    if args.IsSpecified('salesforce_password') or args.IsSpecified(
        'salesforce_secret_manager_stored_password'
    ):
      connection_profile.salesforceProfile.userCredentials.password = (
          args.salesforce_password
      )
      connection_profile.salesforceProfile.userCredentials.secretManagerStoredPassword = (
          args.salesforce_secret_manager_stored_password
      )
      update_fields.append('salesforceProfile.userCredentials.password')
      update_fields.append(
          'salesforceProfile.userCredentials.secretManagerStoredPassword'
      )

    if args.IsSpecified('salesforce_security_token') or args.IsSpecified(
        'salesforce_secret_manager_stored_security_token'
    ):
      connection_profile.salesforceProfile.userCredentials.securityToken = (
          args.salesforce_security_token
      )
      connection_profile.salesforceProfile.userCredentials.secretManagerStoredSecurityToken = (
          args.salesforce_secret_manager_stored_security_token
      )
      update_fields.append('salesforceProfile.userCredentials.securityToken')
      update_fields.append(
          'salesforceProfile.userCredentials.secretManagerStoredSecurityToken'
      )

    if args.IsSpecified('salesforce_oauth2_client_id'):
      connection_profile.salesforceProfile.oauth2ClientCredentials.clientId = (
          args.salesforce_oauth2_client_id
      )
      update_fields.append('salesforceProfile.oauth2ClientCredentials.clientId')
    if args.IsSpecified('salesforce_oauth2_client_secret') or args.IsSpecified(
        'salesforce_secret_manager_stored_oauth2_client_secret'
    ):
      connection_profile.salesforceProfile.oauth2ClientCredentials.clientSecret = (
          args.salesforce_oauth2_client_secret
      )
      connection_profile.salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret = (
          args.salesforce_secret_manager_stored_oauth2_client_secret
      )
      update_fields.append(
          'salesforceProfile.oauth2ClientCredentials.clientSecret'
      )
      update_fields.append(
          'salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret'
      )

  def _UpdateMongodbProfile(self, connection_profile, args, update_fields):
    """Updates MongoDB connection profile."""
    if args.IsSpecified('mongodb_host_addresses'):
      addresses = []
      for host_address in args.mongodb_host_addresses:
        if args.mongodb_srv_connection_format:
          addresses.append(
              self._messages.HostAddress(hostname=host_address)
          )
        else:
          hostname, port = host_address.split(':')
          addresses.append(
              self._messages.HostAddress(hostname=hostname, port=int(port))
          )
      connection_profile.mongodbProfile.hostAddresses = addresses
      update_fields.append('monogodbProfile.hostAddresses')
    if args.IsSpecified('mongodb_replica_set'):
      connection_profile.mongodbProfile.replicaSet = args.mongodb_replica_set
      update_fields.append('mongodbProfile.replicaSet')
    if args.IsSpecified('mongodb_srv_connection_format') or args.IsSpecified(
        'mongodb_standard_connection_format'
    ):
      if args.mongodb_srv_connection_format:
        connection_profile.mongodbProfile.srvConnectionFormat = {}
      if args.mongodb_standard_connection_format:
        connection_profile.mongodbProfile.standardConnectionFormat = {}
      update_fields.append('mongodbProfile.srvConnectionFormat')
      update_fields.append('mongodbProfile.standardConnectionFormat')
    if args.IsSpecified('mongodb_username'):
      connection_profile.mongodbProfile.username = args.mongodb_username
      update_fields.append('mongodbProfile.username')
    if args.IsSpecified('mongodb_password') or args.IsSpecified(
        'mongodb_secret_manager_stored_password'
    ):
      connection_profile.mongodbProfile.password = args.mongodb_password
      connection_profile.mongodbProfile.secretManagerStoredPassword = (
          args.mongodb_secret_manager_stored_password
      )
      update_fields.append('mongodbProfile.password')
      update_fields.append('mongodbProfile.secretManagerStoredPassword')

  def _GetExistingConnectionProfile(self, name):
    get_req = (
        self._messages.DatastreamProjectsLocationsConnectionProfilesGetRequest(
            name=name
        )
    )
    return self._service.Get(get_req)

  def _UpdateLabels(self, connection_profile, args):
    """Applies the label additions, removals and clearing from args.

    The diff is built with `labels_util` and applied to the profile's current
    labels; the profile is only touched when the result reports
    `needs_update`.

    Args:
      connection_profile: profile whose `labels` may be replaced in place.
      args: parsed arguments read by the `labels_util` helpers, also exposing
        `clear_labels`.
    """
    labels_diff = labels_util.Diff(
        additions=labels_util.GetUpdateLabelsDictFromArgs(args),
        subtractions=labels_util.GetRemoveLabelsListFromArgs(args),
        clear=args.clear_labels,
    )
    outcome = labels_diff.Apply(
        self._messages.ConnectionProfile.LabelsValue,
        connection_profile.labels,
    )
    if not outcome.needs_update:
      return
    connection_profile.labels = outcome.labels

  def _GetUpdatedConnectionProfile(self, connection_profile, cp_type,
                                   release_track, args):
    """Applies the update flags to a profile and records the touched fields.

    Args:
      connection_profile: the existing ConnectionProfile message; it is
        modified in place.
      cp_type: str, the connection profile type, e.g. 'MYSQL' or 'MONGODB'.
      release_track: the command release track; it selects which private
        connection concept is parsed and is forwarded to the GCS updater.
      args: the parsed arguments the command was invoked with.

    Returns:
      A (connection_profile, update_fields) tuple, where update_fields is the
      list of field paths collected while applying the flags.

    Raises:
      exceptions.InvalidArgumentException: if cp_type is not one of the types
        handled here.
    """
    touched = []
    if args.IsSpecified('display_name'):
      connection_profile.displayName = args.display_name
      touched.append('displayName')

    # Updaters that share the (profile, args, fields) call shape.
    uniform_updaters = {
        'MYSQL': self._UpdateMySqlProfile,
        'ORACLE': self._UpdateOracleProfile,
        'POSTGRESQL': self._UpdatePostgresqlProfile,
        'SQLSERVER': self._UpdateSqlServerProfile,
        'SALESFORCE': self._UpdateSalesforceProfile,
        'MONGODB': self._UpdateMongodbProfile,
    }
    updater = uniform_updaters.get(cp_type)
    if updater is not None:
      updater(connection_profile, args, touched)
    elif cp_type == 'GOOGLE-CLOUD-STORAGE':
      # The GCS updater additionally takes the release track.
      self._UpdateGCSProfile(connection_profile, release_track, args, touched)
    elif cp_type != 'BIGQUERY':
      # BIGQUERY is accepted, but there are currently no parameters that can
      # be updated in a bigquery CP; anything else is rejected.
      raise exceptions.InvalidArgumentException(
          cp_type,
          'The connection profile type {0} is either unknown or not supported'
          ' yet.'.format(cp_type),
      )

    # TODO(b/207467120): deprecate BETA client.
    if release_track == base.ReleaseTrack.BETA:
      private_concept = args.CONCEPTS.private_connection_name
    else:
      private_concept = args.CONCEPTS.private_connection
    private_ref = private_concept.Parse()

    # Connectivity options are checked in priority order; at most one applies.
    if private_ref:
      connection_profile.privateConnectivity = (
          self._messages.PrivateConnectivity(
              privateConnectionName=private_ref.RelativeName()
          )
      )
      touched.append('privateConnectivity')
    elif args.forward_ssh_hostname:
      self._UpdateForwardSshTunnelConnectivity(
          connection_profile, args, touched
      )
    elif args.static_ip_connectivity:
      connection_profile.staticServiceIpConnectivity = {}
      touched.append('staticServiceIpConnectivity')

    self._UpdateLabels(connection_profile, args)
    return connection_profile, touched

  def Create(self,
             parent_ref,
             connection_profile_id,
             cp_type,
             release_track,
             args=None):
    """Creates a connection profile.

    Args:
      parent_ref: a Resource reference to a parent datastream.projects.locations
        resource for this connection profile.
      connection_profile_id: str, the name of the resource to create.
      cp_type: str, the type of the connection profile ('MYSQL', 'ORACLE')
      release_track: Some arguments are added based on the command release
        track. The request is always forced on the BETA track.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      Operation: the operation for creating the connection profile.
    """
    self._ValidateArgs(args)

    new_profile = self._GetConnectionProfile(
        cp_type, connection_profile_id, args, release_track
    )

    # TODO(b/207467120): only use flags from args.
    force = bool(release_track == base.ReleaseTrack.BETA or args.force)

    request_cls = (
        self._messages.DatastreamProjectsLocationsConnectionProfilesCreateRequest
    )
    request = request_cls(
        connectionProfile=new_profile,
        connectionProfileId=new_profile.name,
        parent=parent_ref,
        requestId=util.GenerateRequestId(),
        force=force,
    )
    return self._service.Create(request)

  def Update(self, name, cp_type, release_track, args=None):
    """Updates a connection profile.

    Fetches the stored profile, applies the flags from `args` to it and sends
    a patch request whose update mask is the comma-joined list of the fields
    that were touched.

    Args:
      name: str, the reference of the connection profile to update.
      cp_type: str, the type of the connection profile ('MYSQL', 'ORACLE')
      release_track: Some arguments are added based on the command release
        track. The request is always forced on the BETA track.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      Operation: the operation for updating the connection profile.
    """
    self._ValidateArgs(args)

    patched_cp, touched_fields = self._GetUpdatedConnectionProfile(
        self._GetExistingConnectionProfile(name), cp_type, release_track, args
    )

    # TODO(b/207467120): only use flags from args.
    force = bool(release_track == base.ReleaseTrack.BETA or args.force)

    patch_request_cls = (
        self._messages.DatastreamProjectsLocationsConnectionProfilesPatchRequest
    )
    patch_request = patch_request_cls(
        connectionProfile=patched_cp,
        name=patched_cp.name,
        updateMask=','.join(touched_fields),
        requestId=util.GenerateRequestId(),
        force=force,
    )
    return self._service.Patch(patch_request)

  def List(self, project_id, args):
    """Get the list of connection profiles in a project.

    Args:
      project_id: The project ID to retrieve
      args: parsed command line arguments; `location`, `filter`, `sort_by`,
        `limit` and `page_size` are read.

    Returns:
      An iterator over all the matching connection profiles.
    """
    parent_name = self._resource_parser.Create(
        'datastream.projects.locations',
        projectsId=project_id,
        locationsId=args.location,
    ).RelativeName()

    # Multiple sort keys are sent as one comma-separated orderBy value.
    order_by = ','.join(args.sort_by) if args.sort_by else None

    request_cls = (
        self._messages.DatastreamProjectsLocationsConnectionProfilesListRequest
    )
    request = request_cls(
        parent=parent_name,
        filter=args.filter,
        orderBy=order_by,
    )

    return list_pager.YieldFromList(
        service=self._client.projects_locations_connectionProfiles,
        request=request,
        limit=args.limit,
        batch_size=args.page_size,
        field='connectionProfiles',
        batch_size_attribute='pageSize',
    )

  def Discover(self, parent_ref, release_track, args):
    """Discover a connection profile.

    Args:
      parent_ref: a Resource reference to a parent datastream.projects.locations
        resource for this connection profile.
      release_track: Some arguments are added based on the command release
        track.
      args: argparse.Namespace, The arguments that this command was invoked
        with.

    Returns:
      The result of the service's Discover call for the built request.
    """
    payload = self._messages.DiscoverConnectionProfileRequest()

    # The profile is given either by resource name or as an object file.
    if args.connection_profile_name:
      payload.connectionProfileName = (
          args.CONCEPTS.connection_profile_name.Parse().RelativeName()
      )
    elif args.connection_profile_object_file:
      payload.connectionProfile = self._ParseConnectionProfileObjectFile(
          args.connection_profile_object_file, release_track
      )

    # A full-hierarchy flag wins over any explicit depth; recursive_depth is
    # preferred over hierarchy_depth when both are set.
    if args.recursive or args.full_hierarchy:
      payload.fullHierarchy = True
    else:
      depth = args.recursive_depth or args.hierarchy_depth
      if depth:
        payload.hierarchyDepth = int(depth)
      else:
        payload.fullHierarchy = False

    # At most one rdbms file is honoured, in this order of precedence.
    if args.mysql_rdbms_file:
      payload.mysqlRdbms = util.ParseMysqlRdbmsFile(
          self._messages, args.mysql_rdbms_file, release_track
      )
    elif args.oracle_rdbms_file:
      payload.oracleRdbms = util.ParseOracleRdbmsFile(
          self._messages, args.oracle_rdbms_file, release_track
      )
    elif args.postgresql_rdbms_file:
      payload.postgresqlRdbms = util.ParsePostgresqlRdbmsFile(
          self._messages, args.postgresql_rdbms_file
      )
    elif args.sqlserver_rdbms_file:
      payload.sqlServerRdbms = util.ParseSqlServerRdbmsFile(
          self._messages, args.sqlserver_rdbms_file
      )

    request_cls = (
        self._messages.DatastreamProjectsLocationsConnectionProfilesDiscoverRequest
    )
    return self._service.Discover(
        request_cls(discoverConnectionProfileRequest=payload, parent=parent_ref)
    )

  def GetUri(self, name):
    """Get the URL string for a connection profile.

    Args:
      name: connection profile's full name.

    Returns:
      The self link of the resource parsed from `name` in the
      connectionProfiles collection.
    """
    collection = 'datastream.projects.locations.connectionProfiles'
    profile_ref = self._resource_parser.ParseRelativeName(
        name, collection=collection
    )
    return profile_ref.SelfLink()