File: //snap/google-cloud-cli/396/platform/bq/clients/client_connection.py
#!/usr/bin/env python
"""The BigQuery CLI connection client library."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from typing import Any, Dict, List, Mapping, Optional
from googleapiclient import discovery
import inflection
from utils import bq_api_utils
from utils import bq_error
from utils import bq_id_utils
from utils import bq_processor_utils
# Local alias for the `Service` name exported by bq_api_utils.
Service = bq_api_utils.Service

# Data Transfer Service authorization info keys.
# NOTE(review): neither constant is referenced in the visible part of this
# module; presumably they are consumed by other modules -- confirm before
# changing or removing them.
AUTHORIZATION_CODE = 'authorization_code'
VERSION_INFO = 'version_info'

# Pattern for a valid proto field name: one or more ASCII letters, digits or
# underscores. _EscapeIfRequired applies it with re.fullmatch to decide whether
# a map key must be wrapped in backticks inside a field mask path.
_VALID_FIELD_NAME_REGEXP = r'[0-9A-Za-z_]+'

# Connection field mask prefixes whose direct children are map keys. Names
# found directly under these prefixes are kept verbatim (not converted to
# snake case) when an update mask is built.
_MAP_KEY_PATHS = [
    'configuration.parameters',
    'configuration.authentication.parameters',
]

# Field mask paths used by UpdateConnection: when the profile id path is part
# of the update mask, the enclosing authentication path is appended as well.
_AUTH_PROFILE_ID_PATH = 'configuration.authentication.profile_id'
_AUTH_PATH = 'configuration.authentication'
def GetConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Fetches the connection identified by the given reference.

  Args:
    client: the client used to make the request.
    reference: reference to the connection to fetch; its `path()` is sent as
      the resource name.

  Returns:
    The response of the connections `get` API call for that connection.
  """
  connections_api = client.projects().locations().connections()
  request = connections_api.get(name=reference.path())
  return request.execute()
def CreateConnection(
    client: discovery.Resource,
    project_id: str,
    location: str,
    connection_type: str,  # Actually a CONNECTION_TYPE_TO_PROPERTY_MAP key.
    properties: str,
    connection_credential: Optional[str] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    connection_id: Optional[str] = None,
    kms_key_name: Optional[str] = None,
    connector_configuration: Optional[str] = None,
):
  """Creates a connection under the given project and location.

  The request body is assembled from the optional metadata fields plus either
  the type-specific properties (when `connection_type` is a key of
  bq_processor_utils.CONNECTION_TYPE_TO_PROPERTY_MAP) or, failing that, the
  generic connector configuration.

  Args:
    client: the client used to make the request.
    project_id: Project ID.
    location: Location of connection.
    connection_type: Type of connection; looked up in
      bq_processor_utils.CONNECTION_TYPE_TO_PROPERTY_MAP.
    properties: Connection properties in JSON format.
    connection_credential: Connection credentials in JSON format.
    display_name: Friendly name for the connection.
    description: Description of the connection.
    connection_id: Optional connection ID.
    kms_key_name: Optional KMS key name.
    connector_configuration: Optional configuration for connector, in JSON
      format. Only consulted when `connection_type` has no mapped property.

  Returns:
    Connection object that was created.

  Raises:
    ValueError: if credentials are supplied but the parsed `properties` are
      not a mapping, or if `connection_type` has no mapped property and no
      `connector_configuration` was given.
  """
  # Only metadata fields that were actually supplied end up in the body.
  body = {}
  optional_fields = (
      ('friendlyName', display_name),
      ('description', description),
      ('kmsKeyName', kms_key_name),
  )
  for api_field, value in optional_fields:
    if value:
      body[api_field] = value

  type_property = bq_processor_utils.CONNECTION_TYPE_TO_PROPERTY_MAP.get(
      connection_type
  )
  if type_property:
    parsed_properties = bq_processor_utils.ParseJson(properties)
    body[type_property] = parsed_properties
    if connection_credential:
      # The credential is nested inside the properties, so they must be a
      # mapping.
      if not isinstance(parsed_properties, Mapping):
        raise ValueError('The `properties` were not a dictionary.')
      parsed_properties['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
  elif connector_configuration:
    body['configuration'] = bq_processor_utils.ParseJson(
        connector_configuration
    )
  else:
    raise ValueError(
        'connection_type %s is unsupported or connector_configuration is not'
        ' specified' % connection_type
    )

  parent_path = 'projects/%s/locations/%s' % (project_id, location)
  connections_api = client.projects().locations().connections()
  request = connections_api.create(
      parent=parent_path, connectionId=connection_id, body=body
  )
  return request.execute()
def UpdateConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
    connection_type: Optional[
        str
    ] = None,  # Actually a CONNECTION_TYPE_TO_PROPERTY_MAP key.
    properties: Optional[str] = None,
    connection_credential: Optional[str] = None,
    display_name: Optional[str] = None,
    description: Optional[str] = None,
    kms_key_name: Optional[str] = None,
    connector_configuration: Optional[str] = None,
):
  """Patches the connection identified by the given reference.

  Builds both the patch body and the matching update mask from the supplied
  arguments, then issues a single `patch` request.

  Args:
    client: the client used to make the request.
    reference: Connection to update.
    connection_type: Type of connection. The types handled here are
      'CLOUD_SQL', 'AWS', 'Azure', 'SQL_DATA_SOURCE', 'CLOUD_SPANNER' and
      'SPARK'; any other value falls through to `connector_configuration`.
    properties: Connection properties in JSON format.
    connection_credential: Connection credentials in JSON format.
    display_name: Friendly name for the connection.
    description: Description of the connection.
    kms_key_name: Optional KMS key name. Any non-None value (including an
      empty string) adds the key name to the update mask; only a non-empty
      value is written to the body.
    connector_configuration: Optional configuration for connector, in JSON
      format.

  Raises:
    bq_error.BigqueryClientError: The connection type is not defined when
      updating connection_credential or properties.

  Returns:
    Connection object that was updated.
  """
  wants_typed_update = bool(connection_credential or properties)
  if wants_typed_update and not connection_type:
    raise bq_error.BigqueryClientError(
        'connection_type is required when updating connection_credential or'
        ' properties'
    )

  body = {}
  mask_paths = []

  # Plain metadata fields: body key and mask path are the same string.
  for api_field, value in (
      ('friendlyName', display_name),
      ('description', description),
  ):
    if value:
      body[api_field] = value
      mask_paths.append(api_field)

  if kms_key_name is not None:
    # An empty string still lands in the mask while the body omits the field.
    mask_paths.append('kms_key_name')
    if kms_key_name:
      body['kmsKeyName'] = kms_key_name

  if connection_type == 'CLOUD_SQL':
    section = bq_processor_utils.ParseJson(properties) if properties else {}
    body['cloudSql'] = section
    if properties:
      mask_paths.extend(_GetUpdateMask(connection_type.lower(), section))
    if connection_credential:
      section['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      mask_paths.append('cloudSql.credential')

  elif connection_type == 'AWS':
    section = bq_processor_utils.ParseJson(properties) if properties else {}
    body['aws'] = section
    if properties:
      cross_account_role = section.get('crossAccountRole')
      if cross_account_role and cross_account_role.get('iamRoleId'):
        mask_paths.append('aws.crossAccountRole.iamRoleId')
      access_role = section.get('accessRole')
      if access_role and access_role.get('iamRoleId'):
        mask_paths.append('aws.access_role.iam_role_id')
    if connection_credential:
      section['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      mask_paths.append('aws.credential')

  elif connection_type == 'Azure':
    # Unlike the other types, nothing is written to the body without
    # properties.
    if properties:
      section = bq_processor_utils.ParseJson(properties)
      body['azure'] = section
      for json_key, mask_path in (
          ('customerTenantId', 'azure.customer_tenant_id'),
          (
              'federatedApplicationClientId',
              'azure.federated_application_client_id',
          ),
      ):
        if section.get(json_key):
          mask_paths.append(mask_path)

  elif connection_type == 'SQL_DATA_SOURCE':
    section = bq_processor_utils.ParseJson(properties) if properties else {}
    body['sqlDataSource'] = section
    if properties:
      mask_paths.extend(_GetUpdateMask(connection_type.lower(), section))
    if connection_credential:
      section['credential'] = bq_processor_utils.ParseJson(
          connection_credential
      )
      mask_paths.append('sqlDataSource.credential')

  elif connection_type == 'CLOUD_SPANNER':
    section = bq_processor_utils.ParseJson(properties) if properties else {}
    body['cloudSpanner'] = section
    if properties:
      mask_paths.extend(_GetUpdateMask(connection_type.lower(), section))

  elif connection_type == 'SPARK':
    section = bq_processor_utils.ParseJson(properties) if properties else {}
    body['spark'] = section
    if properties:
      for json_key, mask_path in (
          ('sparkHistoryServerConfig', 'spark.spark_history_server_config'),
          ('metastoreServiceConfig', 'spark.metastore_service_config'),
      ):
        if json_key in section:
          mask_paths.append(mask_path)

  elif connector_configuration:
    configuration = bq_processor_utils.ParseJson(connector_configuration)
    body['configuration'] = configuration
    mask_paths.extend(_GetUpdateMaskRecursively('configuration', configuration))

  # A profile id update also puts the enclosing authentication path in the
  # mask (added once).
  if _AUTH_PROFILE_ID_PATH in mask_paths and _AUTH_PATH not in mask_paths:
    mask_paths.append(_AUTH_PATH)

  connections_api = client.projects().locations().connections()
  request = connections_api.patch(
      name=reference.path(),
      updateMask=','.join(mask_paths),
      body=body,
  )
  return request.execute()
def _GetUpdateMask(
base_path: str, json_properties: Dict[str, Any]
) -> List[str]:
"""Creates an update mask from json_properties.
Arguments:
base_path: 'cloud_sql'
json_properties: { 'host': ... , 'instanceId': ... }
Returns:
list of paths in snake case:
mask = ['cloud_sql.host', 'cloud_sql.instance_id']
"""
return [
base_path + '.' + inflection.underscore(json_property)
for json_property in json_properties
]
def _EscapeIfRequired(prefix: str, name: str) -> str:
  """Returns the field mask segment to use for `name` under `prefix`.

  When `prefix` is one of _MAP_KEY_PATHS, `name` is a map key: it is returned
  unchanged if it fully matches _VALID_FIELD_NAME_REGEXP and wrapped in
  backticks otherwise. For every other prefix, `name` is a regular field name
  and is converted to snake case.

  Args:
    prefix: field mask prefix to check if name points to a map key.
    name: name of the field.

  Returns:
    The escaped or converted name.
  """
  if prefix not in _MAP_KEY_PATHS:
    # Regular field: convert name to snake case.
    return inflection.underscore(name)
  if re.fullmatch(_VALID_FIELD_NAME_REGEXP, name):
    return name
  return '`{}`'.format(name)
def _GetUpdateMaskRecursively(
prefix: str, json_value: Dict[str, Any]
) -> List[str]:
"""Recursively traverses json_value and returns a list of update mask paths.
Args:
prefix: current prefix of the json value.
json_value: value to traverse.
Returns:
a field mask containing all the set paths in the json value.
"""
if not isinstance(json_value, dict) or not json_value:
return [prefix]
result = []
for name in json_value:
new_prefix = prefix + '.' + _EscapeIfRequired(prefix, name)
new_json_value = json_value.get(name)
result.extend(_GetUpdateMaskRecursively(new_prefix, new_json_value))
return result
def DeleteConnection(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Deletes the connection identified by the given reference.

  The API response is discarded; nothing is returned.

  Args:
    client: the client used to make the request.
    reference: Connection to delete.
  """
  connections_api = client.projects().locations().connections()
  request = connections_api.delete(name=reference.path())
  request.execute()
def ListConnections(
    client: discovery.Resource,
    project_id: str,
    location: str,
    max_results: int,
    page_token: Optional[str],
):
  """Lists one page of connections in the given project and location.

  Args:
    client: the client used to make the request.
    project_id: Project ID.
    location: Location.
    max_results: Number of results to show; sent as the page size.
    page_token: Token to retrieve the next page of results.

  Returns:
    The response of the connections `list` API call.
  """
  parent_path = 'projects/%s/locations/%s' % (project_id, location)
  connections_api = client.projects().locations().connections()
  request = connections_api.list(
      parent=parent_path, pageToken=page_token, pageSize=max_results
  )
  return request.execute()
def SetConnectionIAMPolicy(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
    policy: str,
):
  """Replaces the IAM policy attached to the given connection resource.

  Args:
    client: the client used to make the request.
    reference: the ConnectionReference for the connection resource.
    policy: The policy string in JSON format; sent as the `policy` field of
      the request body.

  Returns:
    The updated IAM policy attached to the given connection resource.

  Raises:
    BigqueryTypeError: if reference is not a ConnectionReference.
  """
  # Validate the reference type before building the request.
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ConnectionReference,
      method='SetConnectionIAMPolicy',
  )
  connections_api = client.projects().locations().connections()
  request = connections_api.setIamPolicy(
      resource=reference.path(), body={'policy': policy}
  )
  return request.execute()
def GetConnectionIAMPolicy(
    client: discovery.Resource,
    reference: bq_id_utils.ApiClientHelper.ConnectionReference,
):
  """Fetches the IAM policy attached to the given connection resource.

  Args:
    client: the client used to make the request.
    reference: the ConnectionReference for the connection resource.

  Returns:
    The IAM policy attached to the given connection resource.

  Raises:
    BigqueryTypeError: if reference is not a ConnectionReference.
  """
  # Validate the reference type before building the request.
  bq_id_utils.typecheck(
      reference,
      bq_id_utils.ApiClientHelper.ConnectionReference,
      method='GetConnectionIAMPolicy',
  )
  connections_api = client.projects().locations().connections()
  request = connections_api.getIamPolicy(resource=reference.path())
  return request.execute()