File: //snap/google-cloud-cli/current/platform/bq/frontend/command_truncate.py
#!/usr/bin/env python
"""The BigQuery CLI truncate command."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Optional
from absl import app
from absl import flags
import bq_flags
from clients import client_job
from clients import client_table
from clients import utils as bq_client_utils
from frontend import bigquery_command
from frontend import bq_cached_client
from utils import bq_error
from utils import bq_id_utils
from utils import bq_processor_utils
# These aren't relevant for user-facing docstrings:
# pylint: disable=g-doc-return-or-yield
# pylint: disable=g-doc-args


class Truncate(bigquery_command.BigqueryCmd):
  """Truncates a table, or all tables in a dataset, to a point in time."""

  usage = """bq truncate project_id:dataset[.table] [--timestamp] [--dry_run] [--overwrite] [--skip_fully_replicated_tables]
"""

def __init__(self, name: str, fv: flags.FlagValues):
super(Truncate, self).__init__(name, fv)
flags.DEFINE_integer(
'timestamp',
None,
'Optional timestamp to which table(s) will be truncated. Specified as '
'milliseconds since epoch.',
short_name='t',
flag_values=fv,
)
flags.DEFINE_boolean(
'dry_run',
None,
'No-op that simply prints out information and the recommended '
'timestamp without modifying tables or datasets.',
flag_values=fv,
)
flags.DEFINE_boolean(
'overwrite',
False,
'Overwrite existing tables. Otherwise timestamp will be appended to '
'all output table names.',
flag_values=fv,
)
flags.DEFINE_boolean(
'skip_fully_replicated_tables',
True,
'Skip tables that are fully replicated (synced) and do not need to be '
'truncated back to a point in time. This could result in datasets that '
'have tables synchronized to different points in time, but will '
        'require less data to be reloaded.',
short_name='s',
flag_values=fv,
)
self._ProcessCommandRc(fv)

  def RunWithArgs(self, identifier: str = '') -> Optional[int]:
    # pylint: disable=g-doc-exception
    """Truncates table/dataset/project to a particular timestamp.

    Examples:
      bq truncate project_id:dataset
      bq truncate --overwrite project_id:dataset --timestamp 123456789
      bq truncate --skip_fully_replicated_tables=false project_id:dataset
    """
client = bq_cached_client.Client.Get()
if identifier:
reference = bq_client_utils.GetReference(
id_fallbacks=client, identifier=identifier.strip()
)
else:
raise app.UsageError('Must specify one of project, dataset or table')
self.truncated_table_count = 0
self.skipped_table_count = 0
self.failed_table_count = 0
status = []
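    # If the user supplied --timestamp, truncate immediately to that time;
    # otherwise first compute a recommended recovery timestamp from the
    # replication metadata and truncate to that.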
    if self.timestamp and not self.dry_run:
      print(
          'Truncating to user-specified timestamp %s. (Not skipping fully'
          ' replicated tables.)'
          % self.timestamp
      )
      if isinstance(reference, bq_id_utils.ApiClientHelper.TableReference):
        all_tables = [reference]
      elif isinstance(reference, bq_id_utils.ApiClientHelper.DatasetReference):
        all_tables = [
            bq_client_utils.GetReference(
                id_fallbacks=client, identifier=table['id']
            )
            for table in client_table.list_tables(
                apiclient=client.apiclient,
                reference=reference,
                max_results=1000 * 1000,
            )
        ]
      else:
        raise app.UsageError(
            'Must specify a dataset or table identifier: %s' % reference
        )
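      # Truncate each table to the same user-supplied timestamp; a failure on
      # one table is recorded and the loop continues with the rest.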
for a_table in all_tables:
try:
status.append(
self._TruncateTable(a_table, str(self.timestamp), False)
)
except bq_error.BigqueryError as e:
print(e)
          status.append(self._formatOutputString(a_table, 'Failed'))
self.failed_table_count += 1
    else:
      if isinstance(reference, bq_id_utils.ApiClientHelper.TableReference):
        all_table_infos = self._GetTableInfo(reference)
      elif isinstance(reference, bq_id_utils.ApiClientHelper.DatasetReference):
        all_table_infos = self._GetTableInfosFromDataset(reference)
      else:
        raise app.UsageError(
            'Must specify a dataset or table identifier: %s' % reference
        )
      try:
        # Truncate the whole dataset to the earliest recovery timestamp so
        # that all tables end up consistent with each other.
        recovery_timestamp = min(
            map(self._GetRecoveryTimestamp, all_table_infos)
        )
      except (ValueError, TypeError):
        # ValueError: no table infos were returned. TypeError: at least one
        # table has no recovery timestamp (None).
        recovery_timestamp = None
      # Error out if we can't determine a recovery timestamp. This can happen
      # in the following cases:
      # 1. No multi_site_info is present for a table because no commit has
      #    been made to the table.
      # 2. No secondary site is present.
      if not recovery_timestamp:
        raise app.UsageError(
            'Unable to determine a recovery timestamp for %s. Exiting.'
            % reference
        )
print('Recommended timestamp to truncate to is %s' % recovery_timestamp)
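      # Truncate every table to the common recovery timestamp computed above.
      # Fully replicated tables are skipped unless
      # --skip_fully_replicated_tables=false is passed.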
      if not hasattr(reference, 'datasetId'):
        raise AttributeError('Missing `datasetId` on reference.')
      for a_table in all_table_infos:
try:
table_reference = bq_id_utils.ApiClientHelper.TableReference.Create(
projectId=reference.projectId,
datasetId=reference.datasetId,
tableId=a_table['name'],
)
status.append(
self._TruncateTable(
table_reference,
str(recovery_timestamp),
a_table['fully_replicated'],
)
)
except bq_error.BigqueryError as e:
print(e)
          status.append(self._formatOutputString(table_reference, 'Failed'))
self.failed_table_count += 1
print(
'%s tables truncated, %s tables failed to truncate, %s tables skipped'
% (
self.truncated_table_count,
self.failed_table_count,
self.skipped_table_count,
)
)
print(*status, sep='\n')

  def _GetTableInfosFromDataset(
      self, dataset_reference: bq_id_utils.ApiClientHelper.DatasetReference
  ):
    """Returns name, recovery timestamp and replication state per table.

    For each table in the dataset, the query finds the second-largest
    latest_replicated_time across sites (i.e. the time up to which the data
    has already reached the remote site) and whether the table is fully
    replicated; the caller then takes the minimum over all tables.
    """
    recovery_timestamp_for_dataset_query = ("""SELECT
        TABLE_NAME,
        UNIX_MILLIS(replicated_time_at_remote_site),
        CASE
          WHEN last_update_time <= min_latest_replicated_time THEN TRUE
          ELSE FALSE
        END AS fully_replicated
      FROM (
        SELECT
          TABLE_NAME,
          multi_site_info.last_update_time,
          ARRAY_AGG(site_info.latest_replicated_time
            ORDER BY latest_replicated_time DESC)[SAFE_OFFSET(1)]
            AS replicated_time_at_remote_site,
          ARRAY_AGG(site_info.latest_replicated_time
            ORDER BY latest_replicated_time ASC)[SAFE_OFFSET(0)]
            AS min_latest_replicated_time
        FROM
          %s.INFORMATION_SCHEMA.TABLES t,
          t.multi_site_info.site_info
        GROUP BY 1, 2)""") % dataset_reference.datasetId
    return self._ReadTableInfo(
        recovery_timestamp_for_dataset_query, row_count=1000 * 1000
    )

  def _GetTableInfo(
      self, table_reference: bq_id_utils.ApiClientHelper.TableReference
  ):
    """Returns name, recovery timestamp and replication state for one table.

    Finds the second-largest latest_replicated_time across all sites for the
    table (the replication time at the remote site) and whether the table is
    fully replicated.
    """
    recovery_timestamp_for_table_query = ("""SELECT
        TABLE_NAME,
        UNIX_MILLIS(replicated_time_at_remote_site),
        CASE
          WHEN last_update_time <= min_latest_replicated_time THEN TRUE
          ELSE FALSE
        END AS fully_replicated
      FROM (
        SELECT
          TABLE_NAME,
          multi_site_info.last_update_time,
          ARRAY_AGG(site_info.latest_replicated_time
            ORDER BY latest_replicated_time DESC)[SAFE_OFFSET(1)]
            AS replicated_time_at_remote_site,
          ARRAY_AGG(site_info.latest_replicated_time
            ORDER BY latest_replicated_time ASC)[SAFE_OFFSET(0)]
            AS min_latest_replicated_time
        FROM
          %s.INFORMATION_SCHEMA.TABLES t,
          t.multi_site_info.site_info
        WHERE
          TABLE_NAME = '%s'
        GROUP BY 1, 2)""") % (
        table_reference.datasetId,
        table_reference.tableId,
    )
return self._ReadTableInfo(recovery_timestamp_for_table_query, row_count=1)

  def _GetRecoveryTimestamp(self, table_info) -> Optional[int]:
    """Returns the recovery timestamp in millis, or None if unavailable."""
    return (
        int(table_info['recovery_timestamp'])
        if table_info['recovery_timestamp']
        else None
    )

  def _ReadTableInfo(self, query: str, row_count: int):
    client = bq_cached_client.Client.Get()
    try:
      job = client_job.Query(client, query, use_legacy_sql=False)
    except bq_error.BigqueryError as e:
      # TODO(b/324243535): Correct this typing.
      # pytype: disable=attribute-error
      if 'Name multi_site_info not found' in e.error['message']:
        # pytype: enable=attribute-error
        # multi_site_info only exists in INFORMATION_SCHEMA.TABLES for
        # projects that have this feature enabled.
        raise app.UsageError(
            'This functionality is not enabled for the current project.'
        )
      else:
        raise
all_table_infos = []
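    # If the query job failed, fall through and return an empty list; the
    # caller then reports that no recovery timestamp could be determined.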
if not bq_client_utils.IsFailedJob(job):
_, rows = client_job.ReadSchemaAndJobRows(
client, job['jobReference'], start_row=0, max_rows=row_count
)
      for row in rows:
        table_info = {
            'name': row[0],
            'recovery_timestamp': row[1],
            # Row values arrive as strings, so the BOOL column is compared
            # against the literal 'true'.
            'fully_replicated': row[2] == 'true',
        }
        all_table_infos.append(table_info)
return all_table_infos

  def _formatOutputString(
self,
table_reference: bq_id_utils.ApiClientHelper.TableReference,
status: str,
) -> str:
return '%s %200s' % (table_reference, status)

  def _TruncateTable(
      self,
      table_reference: bq_id_utils.ApiClientHelper.TableReference,
      recovery_timestamp: str,
      is_fully_replicated: bool,
  ) -> str:
    client = bq_cached_client.Client.Get()
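    # Unless --overwrite is set, results go to a new table whose name encodes
    # the truncation point, e.g. mytable_TRUNCATED_AT_1617300000000.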
if not self.overwrite:
dest = bq_id_utils.ApiClientHelper.TableReference.Create(
projectId=table_reference.projectId,
datasetId=table_reference.datasetId,
tableId='_'.join(
[table_reference.tableId, 'TRUNCATED_AT', recovery_timestamp]
),
)
else:
dest = table_reference
if self.skip_fully_replicated_tables and is_fully_replicated:
self.skipped_table_count += 1
return self._formatOutputString(
table_reference, 'Fully replicated...Skipped'
)
if self.dry_run:
return self._formatOutputString(
dest, 'will be Truncated@%s' % recovery_timestamp
)
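    # The actual truncation is a copy job that reads the table as of the
    # recovery timestamp and overwrites the destination in place.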
kwds = {
'write_disposition': 'WRITE_TRUNCATE',
'ignore_already_exists': False,
'operation_type': 'COPY',
}
if bq_flags.LOCATION.value:
kwds['location'] = bq_flags.LOCATION.value
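    # 'table@<millis>' is the bq snapshot (time travel) decorator, addressing
    # the table's contents as of recovery_timestamp.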
source_table = bq_client_utils.GetTableReference(
id_fallbacks=client,
identifier='%s@%s' % (table_reference, recovery_timestamp),
)
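    # Placeholder so the failure message below still formats if CopyTable
    # raises before a job reference exists.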
job_ref = ' '
try:
job = client_job.CopyTable(client, [source_table], dest, **kwds)
if job is None:
self.failed_table_count += 1
return self._formatOutputString(dest, 'Failed')
job_ref = bq_processor_utils.ConstructObjectReference(job)
self.truncated_table_count += 1
return self._formatOutputString(dest, 'Successful %s ' % job_ref)
except bq_error.BigqueryError as e:
print(e)
self.failed_table_count += 1
return self._formatOutputString(dest, 'Failed %s ' % job_ref)