File: //snap/google-cloud-cli/396/lib/googlecloudsdk/api_lib/datastream/connection_profiles.py
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Cloud Datastream connection profiles API."""
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.datastream import exceptions as ds_exceptions
from googlecloudsdk.api_lib.datastream import util
from googlecloudsdk.calliope import base
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.calliope.arg_parsers import HostPort
from googlecloudsdk.command_lib.util.args import labels_util
from googlecloudsdk.core import resources
from googlecloudsdk.core import yaml
from googlecloudsdk.core.console import console_io
def GetConnectionProfileURI(resource):
  """Returns the self-link URI of a connection profile resource.

  Args:
    resource: an object whose `name` attribute is the relative resource name
      of a connection profile.

  Returns:
    str, the self link produced by the global resource registry.
  """
  parsed_ref = resources.REGISTRY.ParseRelativeName(
      resource.name,
      collection='datastream.projects.locations.connectionProfiles',
  )
  return parsed_ref.SelfLink()
class ConnectionProfilesClient:
  """Client for the Datastream connection profiles service in the API."""

  def __init__(self, client=None, messages=None):
    """Creates a client, defaulting any dependency that is not supplied.

    Args:
      client: optional API client; util.GetClientInstance() is used when it
        is not provided (or is falsy).
      messages: optional API messages module; util.GetMessagesModule() is
        used when it is not provided (or is falsy).
    """
    self._client = client if client else util.GetClientInstance()
    self._messages = messages if messages else util.GetMessagesModule()
    # Shortcut to the connection profiles collection of the API client.
    self._service = self._client.projects_locations_connectionProfiles
    self._resource_parser = util.GetResourceParser()
def _ValidateArgs(self, args):
self._ValidateSslConfigArgs(args)
def _ValidateSslConfigArgs(self, args):
"""Validates Format of all SSL config args."""
self._ValidateCertificateFormat(args.ca_certificate, 'CA certificate')
self._ValidateCertificateFormat(args.client_certificate,
'client certificate')
self._ValidateCertificateFormat(args.client_key, 'client key')
# Validation for all Postgresql SSL config fields.
self._ValidateCertificateFormat(
args.postgresql_ca_certificate, 'Postgresql CA certificate'
)
self._ValidateCertificateFormat(
args.postgresql_client_certificate, 'Postgresql client certificate'
)
self._ValidateCertificateFormat(
args.postgresql_client_key, 'Postgresql client private key'
)
# Validation for Oracle SSL config fields.
self._ValidateCertificateFormat(
args.oracle_ca_certificate, 'Oracle CA certificate'
)
def _ValidateCertificateFormat(self, certificate, name):
if not certificate:
return True
cert = certificate.strip()
cert_lines = cert.split('\n')
if (not cert_lines[0].startswith('-----') or
not cert_lines[-1].startswith('-----')):
raise exceptions.InvalidArgumentException(
name,
'The certificate does not appear to be in PEM format: \n{0}'.format(
cert))
def _GetSslConfig(self, args):
return self._messages.MysqlSslConfig(
clientKey=args.client_key,
clientCertificate=args.client_certificate,
caCertificate=args.ca_certificate)
def _GetMySqlProfile(self, args):
ssl_config = self._GetSslConfig(args)
return self._messages.MysqlProfile(
hostname=args.mysql_hostname,
port=args.mysql_port,
username=args.mysql_username,
password=args.mysql_password,
secretManagerStoredPassword=args.mysql_secret_manager_stored_password,
sslConfig=ssl_config)
def _GetOracleProfile(self, args):
ssl_config = self._GetOracleSslConfig(args)
return self._messages.OracleProfile(
hostname=args.oracle_hostname,
port=args.oracle_port,
username=args.oracle_username,
password=args.oracle_password,
secretManagerStoredPassword=args.oracle_secret_manager_stored_password,
databaseService=args.database_service,
oracleSslConfig=ssl_config)
def _GetOracleSslConfig(self, args):
"""Returns a OracleSslConfig message based on the given args."""
return self._messages.OracleSslConfig(
caCertificate=args.oracle_ca_certificate,
serverCertificateDistinguishedName=args.oracle_server_certificate_distinguished_name,
)
def _GetPostgresqlSslConfig(self, args):
"""Returns a PostgresqlSslConfig message based on the given args."""
if args.postgresql_client_certificate or args.postgresql_client_key:
return self._messages.PostgresqlSslConfig(
serverAndClientVerification=self._messages.ServerAndClientVerification(
clientCertificate=args.postgresql_client_certificate,
clientKey=args.postgresql_client_key,
caCertificate=args.postgresql_ca_certificate,
serverCertificateHostname=args.postgresql_server_certificate_hostname,
)
)
if args.postgresql_ca_certificate:
return self._messages.PostgresqlSslConfig(
serverVerification=self._messages.ServerVerification(
caCertificate=args.postgresql_ca_certificate,
serverCertificateHostname=args.postgresql_server_certificate_hostname,
)
)
return None
def _GetPostgresqlProfile(self, args):
ssl_config = self._GetPostgresqlSslConfig(args)
return self._messages.PostgresqlProfile(
hostname=args.postgresql_hostname,
port=args.postgresql_port,
username=args.postgresql_username,
password=args.postgresql_password,
secretManagerStoredPassword=args.postgresql_secret_manager_stored_password,
database=args.postgresql_database,
sslConfig=ssl_config)
def _GetSqlServerProfile(self, args):
return self._messages.SqlServerProfile(
hostname=args.sqlserver_hostname,
port=args.sqlserver_port,
username=args.sqlserver_username,
password=args.sqlserver_password,
secretManagerStoredPassword=args.sqlserver_secret_manager_stored_password,
database=args.sqlserver_database,
)
def _GetSalesforceProfile(self, args):
if args.salesforce_oauth2_client_id:
return self._messages.SalesforceProfile(
domain=args.salesforce_domain,
oauth2ClientCredentials=self._messages.Oauth2ClientCredentials(
clientId=args.salesforce_oauth2_client_id,
clientSecret=args.salesforce_oauth2_client_secret,
secretManagerStoredClientSecret=args.salesforce_secret_manager_stored_oauth2_client_secret,
),
)
else:
return self._messages.SalesforceProfile(
domain=args.salesforce_domain,
userCredentials=self._messages.UserCredentials(
username=args.salesforce_username,
password=args.salesforce_password,
secretManagerStoredPassword=args.salesforce_secret_manager_stored_password,
securityToken=args.salesforce_security_token,
secretManagerStoredSecurityToken=args.salesforce_secret_manager_stored_security_token,
),
)
def _GetGCSProfile(self, args, release_track):
  """Builds a GcsProfile message from the Cloud Storage arguments.

  Args:
    args: argparse.Namespace, the parsed command arguments.
    release_track: base.ReleaseTrack of the invoking command. BETA reads the
      bucket from `bucket_name`; all other tracks read `bucket`.

  Returns:
    A GcsProfile message. rootPath defaults to '/' when root_path is unset
    or empty.
  """
  # TODO(b/207467120): remove bucket_name arg check.
  is_beta = release_track == base.ReleaseTrack.BETA
  bucket = args.bucket_name if is_beta else args.bucket
  gcs_profile = self._messages.GcsProfile(bucket=bucket)
  gcs_profile.rootPath = args.root_path or '/'
  return gcs_profile
def _GetMongodbProfile(self, args):
"""Returns the MongoDB profile message based on the given args."""
addresses = []
for host_address in args.mongodb_host_addresses:
if args.mongodb_srv_connection_format:
addresses.append(
self._messages.HostAddress(hostname=host_address)
)
else:
hostport = HostPort.Parse(host_address)
addresses.append(
self._messages.HostAddress(
hostname=hostport.host, port=int(hostport.port)
)
)
profile = self._messages.MongodbProfile(
hostAddresses=addresses,
username=args.mongodb_username,
replicaSet=args.mongodb_replica_set,
password=args.mongodb_password,
secretManagerStoredPassword=args.mongodb_secret_manager_stored_password,
)
if (
args.mongodb_direct_connection
and not args.mongodb_standard_connection_format
):
raise exceptions.InvalidArgumentException(
'mongodb-direct-connection',
'mongodb direct connection can only be used with the standard'
' connection format.',
)
if args.mongodb_srv_connection_format:
profile.srvConnectionFormat = {}
if args.mongodb_standard_connection_format:
profile.standardConnectionFormat = (
self._messages.StandardConnectionFormat(
directConnection=args.mongodb_direct_connection
)
)
if args.mongodb_tls:
profile.sslConfig = {}
if args.mongodb_ca_certificate:
profile.sslConfig.caCertificate = args.mongodb_ca_certificate
return profile
def _ParseSslConfig(self, data):
return self._messages.MysqlSslConfig(
clientKey=data.get('client_key'),
clientCertificate=data.get('client_certificate'),
caCertificate=data.get('ca_certificate'))
def _ParseMySqlProfile(self, data):
if not data:
return {}
ssl_config = self._ParseSslConfig(data)
return self._messages.MysqlProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
sslConfig=ssl_config)
def _ParseOracleProfile(self, data):
if not data:
return {}
return self._messages.OracleProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
databaseService=data.get('database_service'))
def _ParsePostgresqlProfile(self, data):
if not data:
return {}
return self._messages.PostgresqlProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
database=data.get('database'))
def _ParseSqlServerProfile(self, data):
if not data:
return {}
return self._messages.SqlServerProfile(
hostname=data.get('hostname'),
port=data.get('port'),
username=data.get('username'),
password=data.get('password'),
database=data.get('database'),
)
def _ParseGCSProfile(self, data):
if not data:
return {}
return self._messages.GcsProfile(
bucket=data.get('bucket_name'), rootPath=data.get('root_path'))
def _GetForwardSshTunnelConnectivity(self, args):
return self._messages.ForwardSshTunnelConnectivity(
hostname=args.forward_ssh_hostname,
port=args.forward_ssh_port,
username=args.forward_ssh_username,
privateKey=args.forward_ssh_private_key,
password=args.forward_ssh_password)
def _GetConnectionProfile(self, cp_type, connection_profile_id, args,
                          release_track):
  """Builds the ConnectionProfile message sent by Create.

  Args:
    cp_type: str, one of 'MYSQL', 'ORACLE', 'POSTGRESQL', 'SQLSERVER',
      'GOOGLE-CLOUD-STORAGE', 'BIGQUERY', 'SALESFORCE' or 'MONGODB'.
    connection_profile_id: str, stored as the profile's `name`.
    args: argparse.Namespace, the parsed command arguments.
    release_track: base.ReleaseTrack of the invoking command.

  Returns:
    A ConnectionProfile message. It carries labels, the display name, the
    type-specific profile and at most one connectivity method.

  Raises:
    exceptions.InvalidArgumentException: if cp_type is not supported.
  """
  msgs = self._messages
  labels = labels_util.ParseCreateArgs(
      args, msgs.ConnectionProfile.LabelsValue)
  profile = msgs.ConnectionProfile(
      name=connection_profile_id, labels=labels,
      displayName=args.display_name)

  # cp_type -> (ConnectionProfile field, zero-arg builder for its value).
  profile_builders = {
      'MYSQL': ('mysqlProfile', lambda: self._GetMySqlProfile(args)),
      'ORACLE': ('oracleProfile', lambda: self._GetOracleProfile(args)),
      'POSTGRESQL': (
          'postgresqlProfile', lambda: self._GetPostgresqlProfile(args)),
      'SQLSERVER': (
          'sqlServerProfile', lambda: self._GetSqlServerProfile(args)),
      'GOOGLE-CLOUD-STORAGE': (
          'gcsProfile', lambda: self._GetGCSProfile(args, release_track)),
      'BIGQUERY': ('bigqueryProfile', lambda: msgs.BigQueryProfile()),
      'SALESFORCE': (
          'salesforceProfile', lambda: self._GetSalesforceProfile(args)),
      'MONGODB': ('mongodbProfile', lambda: self._GetMongodbProfile(args)),
  }
  if cp_type not in profile_builders:
    raise exceptions.InvalidArgumentException(
        cp_type,
        'The connection profile type {0} is either unknown or not supported'
        ' yet.'.format(cp_type),
    )
  field_name, build_profile = profile_builders[cp_type]
  setattr(profile, field_name, build_profile())

  # TODO(b/207467120): deprecate BETA client.
  if release_track == base.ReleaseTrack.BETA:
    private_connection_concept = args.CONCEPTS.private_connection_name
  else:
    private_connection_concept = args.CONCEPTS.private_connection
  private_connectivity_ref = private_connection_concept.Parse()

  # The first connectivity option given wins: private connection, then a
  # forward SSH tunnel, then static service IPs.
  if private_connectivity_ref:
    profile.privateConnectivity = msgs.PrivateConnectivity(
        privateConnection=private_connectivity_ref.RelativeName()
    )
  elif args.forward_ssh_hostname:
    profile.forwardSshConnectivity = self._GetForwardSshTunnelConnectivity(
        args)
  elif args.static_ip_connectivity:
    profile.staticServiceIpConnectivity = {}
  return profile
def _ParseConnectionProfileObjectFile(
    self, connection_profile_object_file, release_track
):
  """Parses a connection-profile-file into the ConnectionProfile message.

  Outside the BETA track, the file is parsed and schema-validated directly
  into a ConnectionProfile message. On BETA, a snake_case YAML layout is read
  field by field.

  Args:
    connection_profile_object_file: str, the file to read; it is read with
      console_io.ReadFromFileOrStdin.
    release_track: base.ReleaseTrack of the invoking command.

  Returns:
    A ConnectionProfile message.

  Raises:
    ds_exceptions.ParseError: on BETA, if the YAML cannot be parsed, is not a
      mapping, or specifies no connectivity method.
  """
  if release_track != base.ReleaseTrack.BETA:
    return util.ParseMessageAndValidateSchema(
        connection_profile_object_file,
        'ConnectionProfile',
        self._messages.ConnectionProfile,
    )
  data = console_io.ReadFromFileOrStdin(
      connection_profile_object_file, binary=False)
  try:
    connection_profile_data = yaml.load(data)
  except Exception as e:  # pylint: disable=broad-except
    raise ds_exceptions.ParseError(
        'Cannot parse YAML:[{0}]'.format(e)) from e
  # An empty document typically loads as None, and a list or scalar is also
  # possible. Without this check the .get() calls below would fail with an
  # unhelpful AttributeError.
  if not isinstance(connection_profile_data, dict):
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: expected a mapping of connection profile fields.'
    )
  display_name = connection_profile_data.get('display_name')
  labels = connection_profile_data.get('labels')
  connection_profile_msg = self._messages.ConnectionProfile(
      displayName=display_name,
      labels=labels)

  # Only one type-specific profile is used, chosen in this priority order.
  oracle_profile = self._ParseOracleProfile(
      connection_profile_data.get('oracle_profile', {}))
  mysql_profile = self._ParseMySqlProfile(
      connection_profile_data.get('mysql_profile', {}))
  postgresql_profile = self._ParsePostgresqlProfile(
      connection_profile_data.get('postgresql_profile', {}))
  sqlserver_profile = self._ParseSqlServerProfile(
      connection_profile_data.get('sqlserver_profile', {})
  )
  gcs_profile = self._ParseGCSProfile(
      connection_profile_data.get('gcs_profile', {}))
  if oracle_profile:
    connection_profile_msg.oracleProfile = oracle_profile
  elif mysql_profile:
    connection_profile_msg.mysqlProfile = mysql_profile
  elif postgresql_profile:
    connection_profile_msg.postgresqlProfile = postgresql_profile
  elif sqlserver_profile:
    connection_profile_msg.sqlServerProfile = sqlserver_profile
  elif gcs_profile:
    connection_profile_msg.gcsProfile = gcs_profile

  # The first connectivity method present (in this order) is used, and one
  # is required.
  if 'static_service_ip_connectivity' in connection_profile_data:
    connection_profile_msg.staticServiceIpConnectivity = (
        connection_profile_data.get('static_service_ip_connectivity')
    )
  elif 'forward_ssh_connectivity' in connection_profile_data:
    connection_profile_msg.forwardSshConnectivity = (
        connection_profile_data.get('forward_ssh_connectivity')
    )
  elif 'private_connectivity' in connection_profile_data:
    connection_profile_msg.privateConnectivity = connection_profile_data.get(
        'private_connectivity'
    )
  else:
    raise ds_exceptions.ParseError(
        'Cannot parse YAML: missing connectivity method.'
    )
  return connection_profile_msg
def _UpdateForwardSshTunnelConnectivity(
self, connection_profile, args, update_fields
):
"""Updates Forward SSH tunnel connectivity config."""
if args.IsSpecified('forward_ssh_hostname'):
connection_profile.forwardSshConnectivity.hostname = (
args.forward_ssh_hostname
)
update_fields.append('forwardSshConnectivity.hostname')
if args.IsSpecified('forward_ssh_port'):
connection_profile.forwardSshConnectivity.port = args.forward_ssh_port
update_fields.append('forwardSshConnectivity.port')
if args.IsSpecified('forward_ssh_username'):
connection_profile.forwardSshConnectivity.username = (
args.forward_ssh_username
)
update_fields.append('forwardSshConnectivity.username')
if args.IsSpecified('forward_ssh_private_key'):
connection_profile.forwardSshConnectivity.privateKey = (
args.forward_ssh_private_key
)
update_fields.append('forwardSshConnectivity.privateKey')
if args.IsSpecified('forward_ssh_password'):
connection_profile.forwardSshConnectivity.privateKey = (
args.forward_ssh_password
)
update_fields.append('forwardSshConnectivity.password')
def _UpdateGCSProfile(
    self, connection_profile, release_track, args, update_fields
):
  """Updates GOOGLE CLOUD STORAGE connection profile.

  Args:
    connection_profile: the ConnectionProfile message being updated; its
      gcsProfile is modified in place.
    release_track: base.ReleaseTrack of the invoking command.
    args: argparse.Namespace, the parsed command arguments.
    update_fields: list of field-mask paths; appended to in place.
  """
  # TODO(b/207467120): remove bucket_name arg check.
  # BETA takes the bucket from `bucket_name`, and every other track uses
  # `bucket`. This is the same split _GetGCSProfile applies on create,
  # instead of honouring `bucket` on GA only.
  if release_track == base.ReleaseTrack.BETA:
    bucket_dest = 'bucket_name'
  else:
    bucket_dest = 'bucket'
  if args.IsSpecified(bucket_dest):
    connection_profile.gcsProfile.bucket = getattr(args, bucket_dest)
    update_fields.append('gcsProfile.bucket')
  if args.IsSpecified('root_path'):
    connection_profile.gcsProfile.rootPath = args.root_path
    update_fields.append('gcsProfile.rootPath')
def _UpdateOracleProfile(self,
connection_profile,
args,
update_fields):
"""Updates Oracle connection profile."""
if args.IsSpecified('oracle_hostname'):
connection_profile.oracleProfile.hostname = args.oracle_hostname
update_fields.append('oracleProfile.hostname')
if args.IsSpecified('oracle_port'):
connection_profile.oracleProfile.port = args.oracle_port
update_fields.append('oracleProfile.port')
if args.IsSpecified('oracle_username'):
connection_profile.oracleProfile.username = args.oracle_username
update_fields.append('oracleProfile.username')
if args.IsSpecified('oracle_password') or args.IsSpecified(
'oracle_secret_manager_stored_password'
):
connection_profile.oracleProfile.password = args.oracle_password
connection_profile.oracleProfile.secretManagerStoredPassword = (
args.oracle_secret_manager_stored_password
)
update_fields.append('oracleProfile.password')
update_fields.append('oracleProfile.secretManagerStoredPassword')
if args.IsSpecified('database_service'):
connection_profile.oracleProfile.databaseService = args.database_service
update_fields.append('oracleProfile.databaseService')
def _UpdateMysqlSslConfig(self, connection_profile, args, update_fields):
"""Updates Mysql SSL config."""
if args.IsSpecified('client_key'):
connection_profile.mysqlProfile.sslConfig.clientKey = args.client_key
update_fields.append('mysqlProfile.sslConfig.clientKey')
if args.IsSpecified('client_certificate'):
connection_profile.mysqlProfile.sslConfig.clientCertificate = (
args.client_certificate
)
update_fields.append('mysqlProfile.sslConfig.clientCertificate')
if args.IsSpecified('ca_certificate'):
connection_profile.mysqlProfile.sslConfig.caCertificate = (
args.ca_certificate
)
update_fields.append('mysqlProfile.sslConfig.caCertificate')
def _UpdateMySqlProfile(self, connection_profile, args, update_fields):
"""Updates MySQL connection profile."""
if args.IsSpecified('mysql_hostname'):
connection_profile.mysqlProfile.hostname = args.mysql_hostname
update_fields.append('mysqlProfile.hostname')
if args.IsSpecified('mysql_port'):
connection_profile.mysqlProfile.port = args.mysql_port
update_fields.append('mysqlProfile.port')
if args.IsSpecified('mysql_username'):
connection_profile.mysqlProfile.username = args.mysql_username
update_fields.append('mysqlProfile.username')
if args.IsSpecified('mysql_password') or args.IsSpecified(
'mysql_secret_manager_stored_password'
):
connection_profile.mysqlProfile.password = args.mysql_password
connection_profile.mysqlProfile.secretManagerStoredPassword = (
args.mysql_secret_manager_stored_password
)
update_fields.append('mysqlProfile.password')
update_fields.append('mysqlProfile.secretManagerStoredPassword')
self._UpdateMysqlSslConfig(connection_profile, args, update_fields)
def _UpdatePostgresqlSslConfig(self, connection_profile, args, update_fields):
  """Updates Postgresql SSL config.

  Copies each explicitly specified PostgreSQL SSL argument onto the existing
  postgresqlProfile.sslConfig of the profile and appends the matching
  field-mask path to update_fields.

  Where values go:
  - The client certificate and client key always go to
    serverAndClientVerification.
  - The CA certificate and server certificate hostname go to
    serverAndClientVerification if the profile already has one, and to
    serverVerification otherwise.

  NOTE(review): the sub-messages are dereferenced without None checks. When
  one of these arguments is given, this raises AttributeError in three cases:
  - the profile has no sslConfig;
  - a client certificate/key is given for a profile without
    serverAndClientVerification;
  - a CA certificate/hostname is given for a profile with neither
    verification mode.
  Confirm whether enabling or switching SSL modes on update should be
  supported.

  Args:
    connection_profile: the ConnectionProfile message being updated; modified
      in place.
    args: argparse.Namespace, the parsed command arguments.
    update_fields: list of field-mask paths; appended to in place.
  """
  if args.IsSpecified('postgresql_client_certificate'):
    connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate = (
        args.postgresql_client_certificate
    )
    update_fields.append(
        'postgresqlProfile.sslConfig.serverAndClientVerification.clientCertificate'
    )
  if args.IsSpecified('postgresql_client_key'):
    connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.clientKey = (
        args.postgresql_client_key
    )
    update_fields.append(
        'postgresqlProfile.sslConfig.serverAndClientVerification.clientKey'
    )
  if args.IsSpecified('postgresql_ca_certificate'):
    # Attach the CA certificate to whichever verification mode the profile
    # already uses, preferring server-and-client verification.
    if connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification:
      connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate = (
          args.postgresql_ca_certificate
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverAndClientVerification.caCertificate'
      )
    else:
      connection_profile.postgresqlProfile.sslConfig.serverVerification.caCertificate = (
          args.postgresql_ca_certificate
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverVerification.caCertificate'
      )
  if args.IsSpecified('postgresql_server_certificate_hostname'):
    # Same placement rule as the CA certificate above.
    if (
        connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification
    ):
      connection_profile.postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname = (
          args.postgresql_server_certificate_hostname
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverAndClientVerification.serverCertificateHostname'
      )
    else:
      connection_profile.postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname = (
          args.postgresql_server_certificate_hostname
      )
      update_fields.append(
          'postgresqlProfile.sslConfig.serverVerification.serverCertificateHostname'
      )
def _UpdatePostgresqlProfile(self, connection_profile, args, update_fields):
"""Updates Postgresql connection profile."""
if args.IsSpecified('postgresql_hostname'):
connection_profile.postgresqlProfile.hostname = args.postgresql_hostname
update_fields.append('postgresqlProfile.hostname')
if args.IsSpecified('postgresql_port'):
connection_profile.postgresqlProfile.port = args.postgresql_port
update_fields.append('postgresqlProfile.port')
if args.IsSpecified('postgresql_username'):
connection_profile.postgresqlProfile.username = args.postgresql_username
update_fields.append('postgresqlProfile.username')
if args.IsSpecified('postgresql_password') or args.IsSpecified(
'postgresql_secret_manager_stored_password'
):
connection_profile.postgresqlProfile.password = args.postgresql_password
connection_profile.postgresqlProfile.secretManagerStoredPassword = (
args.postgresql_secret_manager_stored_password
)
update_fields.append('postgresqlProfile.password')
update_fields.append('postgresqlProfile.secretManagerStoredPassword')
if args.IsSpecified('postgresql_database'):
connection_profile.postgresqlProfile.database = args.postgresql_database
update_fields.append('postgresqlProfile.database')
self._UpdatePostgresqlSslConfig(connection_profile, args, update_fields)
def _UpdateSqlServerProfile(self, connection_profile, args, update_fields):
"""Updates SqlServer connection profile."""
if args.IsSpecified('sqlserver_hostname'):
connection_profile.sqlServerProfile.hostname = args.sqlserver_hostname
update_fields.append('sqlServerProfile.hostname')
if args.IsSpecified('sqlserver_port'):
connection_profile.sqlServerProfile.port = args.sqlserver_port
update_fields.append('sqlServerProfile.port')
if args.IsSpecified('sqlserver_username'):
connection_profile.sqlServerProfile.username = args.sqlserver_username
update_fields.append('sqlServerProfile.username')
if args.IsSpecified('sqlserver_password') or args.IsSpecified(
'sqlserver_secret_manager_stored_password'
):
connection_profile.sqlServerProfile.password = args.sqlserver_password
connection_profile.sqlServerProfile.secretManagerStoredPassword = (
args.sqlserver_secret_manager_stored_password
)
update_fields.append('sqlServerProfile.password')
update_fields.append('sqlServerProfile.secretManagerStoredPassword')
if args.IsSpecified('sqlserver_database'):
connection_profile.sqlServerProfile.database = args.sqlserver_database
update_fields.append('sqlServerProfile.database')
def _UpdateSalesforceProfile(self, connection_profile, args, update_fields):
"""Updates Salesforce connection profile."""
if args.IsSpecified('salesforce_domain'):
connection_profile.salesforceProfile.domain = args.salesforce_domain
update_fields.append('salesforceProfile.domain')
if args.IsSpecified('salesforce_username'):
connection_profile.salesforceProfile.userCredentials.username = (
args.salesforce_username
)
update_fields.append('salesforceProfile.userCredentials.username')
if args.IsSpecified('salesforce_password') or args.IsSpecified(
'salesforce_secret_manager_stored_password'
):
connection_profile.salesforceProfile.userCredentials.password = (
args.salesforce_password
)
connection_profile.salesforceProfile.userCredentials.secretManagerStoredPassword = (
args.salesforce_secret_manager_stored_password
)
update_fields.append('salesforceProfile.userCredentials.password')
update_fields.append(
'salesforceProfile.userCredentials.secretManagerStoredPassword'
)
if args.IsSpecified('salesforce_security_token') or args.IsSpecified(
'salesforce_secret_manager_stored_security_token'
):
connection_profile.salesforceProfile.userCredentials.securityToken = (
args.salesforce_security_token
)
connection_profile.salesforceProfile.userCredentials.secretManagerStoredSecurityToken = (
args.salesforce_secret_manager_stored_security_token
)
update_fields.append('salesforceProfile.userCredentials.securityToken')
update_fields.append(
'salesforceProfile.userCredentials.secretManagerStoredSecurityToken'
)
if args.IsSpecified('salesforce_oauth2_client_id'):
connection_profile.salesforceProfile.oauth2ClientCredentials.clientId = (
args.salesforce_oauth2_client_id
)
update_fields.append('salesforceProfile.oauth2ClientCredentials.clientId')
if args.IsSpecified('salesforce_oauth2_client_secret') or args.IsSpecified(
'salesforce_secret_manager_stored_oauth2_client_secret'
):
connection_profile.salesforceProfile.oauth2ClientCredentials.clientSecret = (
args.salesforce_oauth2_client_secret
)
connection_profile.salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret = (
args.salesforce_secret_manager_stored_oauth2_client_secret
)
update_fields.append(
'salesforceProfile.oauth2ClientCredentials.clientSecret'
)
update_fields.append(
'salesforceProfile.oauth2ClientCredentials.secretManagerStoredClientSecret'
)
def _UpdateMongodbProfile(self, connection_profile, args, update_fields):
"""Updates MongoDB connection profile."""
if args.IsSpecified('mongodb_host_addresses'):
addresses = []
for host_address in args.mongodb_host_addresses:
if args.mongodb_srv_connection_format:
addresses.append(
self._messages.HostAddress(hostname=host_address)
)
else:
hostname, port = host_address.split(':')
addresses.append(
self._messages.HostAddress(hostname=hostname, port=int(port))
)
connection_profile.mongodbProfile.hostAddresses = addresses
update_fields.append('monogodbProfile.hostAddresses')
if args.IsSpecified('mongodb_replica_set'):
connection_profile.mongodbProfile.replicaSet = args.mongodb_replica_set
update_fields.append('mongodbProfile.replicaSet')
if args.IsSpecified('mongodb_srv_connection_format') or args.IsSpecified(
'mongodb_standard_connection_format'
):
if args.mongodb_srv_connection_format:
connection_profile.mongodbProfile.srvConnectionFormat = {}
if args.mongodb_standard_connection_format:
connection_profile.mongodbProfile.standardConnectionFormat = {}
update_fields.append('mongodbProfile.srvConnectionFormat')
update_fields.append('mongodbProfile.standardConnectionFormat')
if args.IsSpecified('mongodb_username'):
connection_profile.mongodbProfile.username = args.mongodb_username
update_fields.append('mongodbProfile.username')
if args.IsSpecified('mongodb_password') or args.IsSpecified(
'mongodb_secret_manager_stored_password'
):
connection_profile.mongodbProfile.password = args.mongodb_password
connection_profile.mongodbProfile.secretManagerStoredPassword = (
args.mongodb_secret_manager_stored_password
)
update_fields.append('mongodbProfile.password')
update_fields.append('mongodbProfile.secretManagerStoredPassword')
def _GetExistingConnectionProfile(self, name):
get_req = (
self._messages.DatastreamProjectsLocationsConnectionProfilesGetRequest(
name=name
)
)
return self._service.Get(get_req)
def _UpdateLabels(self, connection_profile, args):
  """Applies the label update/remove/clear arguments to the profile.

  The profile's labels are replaced only when the computed diff reports that
  an update is needed. This method does not record anything in a field mask;
  callers must do that themselves.

  Args:
    connection_profile: the ConnectionProfile message being updated; its
      `labels` may be replaced.
    args: argparse.Namespace, the parsed command arguments.
  """
  labels_diff = labels_util.Diff(
      additions=labels_util.GetUpdateLabelsDictFromArgs(args),
      subtractions=labels_util.GetRemoveLabelsListFromArgs(args),
      clear=args.clear_labels,
  )
  result = labels_diff.Apply(
      self._messages.ConnectionProfile.LabelsValue, connection_profile.labels
  )
  if not result.needs_update:
    return
  connection_profile.labels = result.labels
def _GetUpdatedConnectionProfile(self, connection_profile, cp_type,
                                 release_track, args):
  """Returns updated connection profile and list of updated fields.

  Applies the update arguments to a fetched profile.

  Args:
    connection_profile: the ConnectionProfile message fetched from the API;
      modified in place.
    cp_type: str, the connection profile type (e.g. 'MYSQL', 'MONGODB').
    release_track: base.ReleaseTrack of the invoking command.
    args: argparse.Namespace, the parsed command arguments.

  Returns:
    A (connection_profile, update_fields) tuple. update_fields lists the
    field-mask paths that were changed.

  Raises:
    exceptions.InvalidArgumentException: if cp_type is not supported.
  """
  update_fields = []
  if args.IsSpecified('display_name'):
    connection_profile.displayName = args.display_name
    update_fields.append('displayName')

  if cp_type == 'MYSQL':
    self._UpdateMySqlProfile(connection_profile, args, update_fields)
  elif cp_type == 'ORACLE':
    self._UpdateOracleProfile(connection_profile, args, update_fields)
  elif cp_type == 'POSTGRESQL':
    self._UpdatePostgresqlProfile(connection_profile, args, update_fields)
  elif cp_type == 'SQLSERVER':
    self._UpdateSqlServerProfile(connection_profile, args, update_fields)
  elif cp_type == 'SALESFORCE':
    self._UpdateSalesforceProfile(connection_profile, args, update_fields)
  elif cp_type == 'GOOGLE-CLOUD-STORAGE':
    self._UpdateGCSProfile(
        connection_profile, release_track, args, update_fields
    )
  elif cp_type == 'BIGQUERY':
    # There are currently no parameters that can be updated in a bigquery CP.
    pass
  elif cp_type == 'MONGODB':
    self._UpdateMongodbProfile(connection_profile, args, update_fields)
  else:
    raise exceptions.InvalidArgumentException(
        cp_type,
        'The connection profile type {0} is either unknown or not supported'
        ' yet.'.format(cp_type),
    )

  # TODO(b/207467120): deprecate BETA client.
  if release_track == base.ReleaseTrack.BETA:
    private_connectivity_ref = args.CONCEPTS.private_connection_name.Parse()
  else:
    private_connectivity_ref = args.CONCEPTS.private_connection.Parse()
  if private_connectivity_ref:
    # Populate the same PrivateConnectivity field the create path sets
    # (`privateConnection`), not `privateConnectionName`.
    connection_profile.privateConnectivity = (
        self._messages.PrivateConnectivity(
            privateConnection=private_connectivity_ref.RelativeName()
        )
    )
    update_fields.append('privateConnectivity')
  elif args.forward_ssh_hostname:
    self._UpdateForwardSshTunnelConnectivity(
        connection_profile, args, update_fields
    )
  elif args.static_ip_connectivity:
    connection_profile.staticServiceIpConnectivity = {}
    update_fields.append('staticServiceIpConnectivity')

  labels_before = connection_profile.labels
  self._UpdateLabels(connection_profile, args)
  # _UpdateLabels replaces `labels` only when they change. Record the change
  # in the field mask; otherwise a label edit is left out of updateMask
  # whenever other fields are updated in the same call.
  if connection_profile.labels != labels_before:
    update_fields.append('labels')
  return connection_profile, update_fields
def Create(self,
           parent_ref,
           connection_profile_id,
           cp_type,
           release_track,
           args=None):
  """Creates a connection profile.

  Args:
    parent_ref: a Resource reference to a parent datastream.projects.locations
      resource for this connection profile.
    connection_profile_id: str, the name of the resource to create.
    cp_type: str, the type of the connection profile: 'MYSQL', 'ORACLE',
      'POSTGRESQL', 'SQLSERVER', 'GOOGLE-CLOUD-STORAGE', 'BIGQUERY',
      'SALESFORCE' or 'MONGODB'.
    release_track: Some arguments are added based on the command release
      track.
    args: argparse.Namespace, The arguments that this command was invoked
      with.

  Returns:
    Operation: the operation for creating the connection profile.
  """
  self._ValidateArgs(args)
  new_profile = self._GetConnectionProfile(
      cp_type, connection_profile_id, args, release_track)
  # TODO(b/207467120): only use flags from args.
  # BETA always forces; other tracks use the value of args.force.
  force = bool(release_track == base.ReleaseTrack.BETA or args.force)
  request_id = util.GenerateRequestId()
  create_request_cls = (
      self._messages.DatastreamProjectsLocationsConnectionProfilesCreateRequest
  )
  create_request = create_request_cls(
      connectionProfile=new_profile,
      connectionProfileId=new_profile.name,
      parent=parent_ref,
      requestId=request_id,
      force=force,
  )
  return self._service.Create(create_request)
def Update(self, name, cp_type, release_track, args=None):
  """Updates a connection profile.

  The current profile is fetched and the arguments are applied to it. A
  Patch request is then sent whose update mask lists the changed fields.

  Args:
    name: str, the reference of the connection profile to update.
    cp_type: str, the type of the connection profile (e.g. 'MYSQL',
      'ORACLE').
    release_track: Some arguments are added based on the command release
      track.
    args: argparse.Namespace, The arguments that this command was invoked
      with.

  Returns:
    Operation: the operation for updating the connection profile.
  """
  self._ValidateArgs(args)
  existing_profile = self._GetExistingConnectionProfile(name)
  updated_cp, update_fields = self._GetUpdatedConnectionProfile(
      existing_profile, cp_type, release_track, args)
  # TODO(b/207467120): only use flags from args.
  # BETA always forces; other tracks use the value of args.force.
  force = bool(release_track == base.ReleaseTrack.BETA or args.force)
  request_id = util.GenerateRequestId()
  patch_request_cls = (
      self._messages.DatastreamProjectsLocationsConnectionProfilesPatchRequest
  )
  patch_request = patch_request_cls(
      connectionProfile=updated_cp,
      name=updated_cp.name,
      updateMask=','.join(update_fields),
      requestId=request_id,
      force=force,
  )
  return self._service.Patch(patch_request)
def List(self, project_id, args):
  """Get the list of connection profiles in a project.

  Args:
    project_id: The project ID to retrieve.
    args: parsed command line arguments. location, filter, sort_by, limit
      and page_size are read.

  Returns:
    An iterator over all the matching connection profiles.
  """
  parent = self._resource_parser.Create(
      'datastream.projects.locations',
      projectsId=project_id,
      locationsId=args.location,
  ).RelativeName()
  sort_by = args.sort_by
  list_request_cls = (
      self._messages.DatastreamProjectsLocationsConnectionProfilesListRequest
  )
  list_request = list_request_cls(
      parent=parent,
      filter=args.filter,
      # Multiple sort keys are sent as one comma-separated orderBy string.
      orderBy=','.join(sort_by) if sort_by else None,
  )
  return list_pager.YieldFromList(
      service=self._client.projects_locations_connectionProfiles,
      request=list_request,
      limit=args.limit,
      batch_size=args.page_size,
      field='connectionProfiles',
      batch_size_attribute='pageSize',
  )
def Discover(self, parent_ref, release_track, args):
  """Discover a connection profile.

  The profile is named either by reference or by an inline definition read
  from a file. The hierarchy depth and an optional RDBMS file also come from
  the arguments.

  Args:
    parent_ref: a Resource reference to a parent datastream.projects.locations
      resource for this connection profile.
    release_track: Some arguments are added based on the command release
      track.
    args: argparse.Namespace, The arguments that this command was invoked
      with.

  Returns:
    Operation: the operation for discovering the connection profile.
  """
  msgs = self._messages
  request = msgs.DiscoverConnectionProfileRequest()

  # Which profile: an existing one by name, or an inline object from a file.
  if args.connection_profile_name:
    profile_ref = args.CONCEPTS.connection_profile_name.Parse()
    request.connectionProfileName = profile_ref.RelativeName()
  elif args.connection_profile_object_file:
    request.connectionProfile = self._ParseConnectionProfileObjectFile(
        args.connection_profile_object_file, release_track
    )

  # Depth: the full hierarchy, an explicit depth (recursive_depth takes
  # precedence over hierarchy_depth), or fullHierarchy=False.
  if args.recursive or args.full_hierarchy:
    request.fullHierarchy = True
  else:
    depth = args.recursive_depth or args.hierarchy_depth
    if depth:
      request.hierarchyDepth = int(depth)
    else:
      request.fullHierarchy = False

  # At most one RDBMS file is used, checked in this order.
  if args.mysql_rdbms_file:
    request.mysqlRdbms = util.ParseMysqlRdbmsFile(
        msgs, args.mysql_rdbms_file, release_track)
  elif args.oracle_rdbms_file:
    request.oracleRdbms = util.ParseOracleRdbmsFile(
        msgs, args.oracle_rdbms_file, release_track)
  elif args.postgresql_rdbms_file:
    request.postgresqlRdbms = util.ParsePostgresqlRdbmsFile(
        msgs, args.postgresql_rdbms_file)
  elif args.sqlserver_rdbms_file:
    request.sqlServerRdbms = util.ParseSqlServerRdbmsFile(
        msgs, args.sqlserver_rdbms_file)

  discover_request_cls = (
      msgs.DatastreamProjectsLocationsConnectionProfilesDiscoverRequest
  )
  discover_request = discover_request_cls(
      discoverConnectionProfileRequest=request, parent=parent_ref)
  return self._service.Discover(discover_request)
def GetUri(self, name):
  """Get the URL string for a connection profile.

  Args:
    name: connection profile's full (relative) resource name.

  Returns:
    URL of the connection profile resource.
  """
  collection = 'datastream.projects.locations.connectionProfiles'
  parsed_ref = self._resource_parser.ParseRelativeName(
      name, collection=collection)
  return parsed_ref.SelfLink()