HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/394/lib/googlecloudsdk/api_lib/spanner/database_sessions.py
# -*- coding: utf-8 -*- #
# Copyright 2016 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Spanner database sessions API helper."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import extra_types
from apitools.base.py import http_wrapper
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.spanner.sql import QueryHasDml


def CheckResponse(response):
  """Wrap http_wrapper.CheckResponse to skip retry on 501."""
  if response.status_code == 501:
    raise apitools_exceptions.HttpError.FromResponse(response)
  return http_wrapper.CheckResponse(response)


def Create(database_ref, creator_role=None):
  """Create a database session.

  Args:
    database_ref: String, The database in which the new session is created.
    creator_role: String, The database role which created this session.

  Returns:
    Newly created session.
  """
  client = _GetClientInstance('spanner', 'v1', None)
  msgs = apis.GetMessagesModule('spanner', 'v1')
  if creator_role is None:
    req = msgs.SpannerProjectsInstancesDatabasesSessionsCreateRequest(
        database=database_ref.RelativeName())
  else:
    create_session_request = msgs.CreateSessionRequest(
        session=msgs.Session(creatorRole=creator_role))
    req = msgs.SpannerProjectsInstancesDatabasesSessionsCreateRequest(
        createSessionRequest=create_session_request,
        database=database_ref.RelativeName())
  return client.projects_instances_databases_sessions.Create(req)


def List(database_ref, server_filter=None):
  """Lists all active sessions on the given database."""
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  req = msgs.SpannerProjectsInstancesDatabasesSessionsListRequest(
      database=database_ref.RelativeName(), filter=server_filter)

  return list_pager.YieldFromList(
      client.projects_instances_databases_sessions,
      req,
      # There is a batch_size_attribute ('pageSize') but we want to yield as
      # many results as possible per request.
      batch_size_attribute=None,
      field='sessions')


def Delete(session_ref):
  """Delete a database session."""
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')
  req = msgs.SpannerProjectsInstancesDatabasesSessionsDeleteRequest(
      name=session_ref.RelativeName())
  return client.projects_instances_databases_sessions.Delete(req)


def _GetClientInstance(api_name, api_version, http_timeout_sec=None):
  client = apis.GetClientInstance(
      api_name, api_version, http_timeout_sec=http_timeout_sec)
  client.check_response_func = CheckResponse
  return client


def ExecuteSql(sql, query_mode, session_ref, read_only_options=None,
               request_options=None, enable_partitioned_dml=False,
               http_timeout_sec=None):
  """Execute an SQL command.

  Args:
    sql: String, The SQL to execute.
    query_mode: String, The mode in which to run the query. Must be one of
      'NORMAL', 'PLAN', 'PROFILE', 'WITH_STATS', or 'WITH_PLAN_AND_STATS'.
    session_ref: Session, Indicates that the repo should be created if it does
      not exist.
    read_only_options: The ReadOnly message for a read-only request. It is
      ignored in a DML request.
    request_options: The RequestOptions message that contains the priority.
    enable_partitioned_dml: Boolean, whether partitioned dml is enabled.
    http_timeout_sec: int, Maximum time in seconds to wait for the SQL query to
      complete.

  Returns:
    (Repo) The capture repository.
  """
  client = _GetClientInstance('spanner', 'v1', http_timeout_sec)
  msgs = apis.GetMessagesModule('spanner', 'v1')
  _RegisterCustomMessageCodec(msgs)

  execute_sql_request = _GetQueryRequest(
      sql,
      query_mode,
      session_ref,
      read_only_options,
      request_options,
      enable_partitioned_dml,
  )
  req = msgs.SpannerProjectsInstancesDatabasesSessionsExecuteSqlRequest(
      session=session_ref.RelativeName(), executeSqlRequest=execute_sql_request)
  resp = client.projects_instances_databases_sessions.ExecuteSql(req)
  if QueryHasDml(sql) and enable_partitioned_dml is False:
    result_set = msgs.ResultSet(metadata=resp.metadata)
    Commit(session_ref, [], result_set.metadata.transaction.id)
  return resp


def _RegisterCustomMessageCodec(msgs):
  """Register custom message code.

  Args:
    msgs: Spanner v1 messages.
  """
  # TODO(b/33482229): remove this workaround
  def _ToJson(msg):
    return extra_types.JsonProtoEncoder(
        extra_types.JsonArray(entries=msg.entry))
  def _FromJson(data):
    return msgs.ResultSet.RowsValueListEntry(
        entry=extra_types.JsonProtoDecoder(data).entries)
  encoding.RegisterCustomMessageCodec(
      encoder=_ToJson, decoder=_FromJson)(
          msgs.ResultSet.RowsValueListEntry)


def _GetQueryRequest(sql,
                     query_mode,
                     session_ref=None,
                     read_only_options=None,
                     request_options=None,
                     enable_partitioned_dml=False):
  """Formats the request based on whether the statement contains DML.

  Args:
    sql: String, The SQL to execute.
    query_mode: String, The mode in which to run the query. Must be one of
      'NORMAL', 'PLAN', 'PROFILE', 'WITH_STATS', or 'WITH_PLAN_AND_STATS'.
    session_ref: Reference to the session.
    read_only_options: The ReadOnly message for a read-only request. It is
      ignored in a DML request.
    request_options: The RequestOptions message that contains the priority.
    enable_partitioned_dml: Boolean, whether partitioned dml is enabled.

  Returns:
    ExecuteSqlRequest parameters
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')

  if enable_partitioned_dml is True:
    transaction = _GetPartitionedDmlTransaction(session_ref)
  elif QueryHasDml(sql):
    transaction_options = msgs.TransactionOptions(readWrite=msgs.ReadWrite())
    transaction = msgs.TransactionSelector(begin=transaction_options)
  else:
    transaction_options = msgs.TransactionOptions(
        readOnly=read_only_options)
    transaction = msgs.TransactionSelector(singleUse=transaction_options)
  return msgs.ExecuteSqlRequest(
      sql=sql,
      requestOptions=request_options,
      queryMode=msgs.ExecuteSqlRequest.QueryModeValueValuesEnum(query_mode),
      transaction=transaction)


def _GetPartitionedDmlTransaction(session_ref):
  """Creates a transaction for Partitioned DML.

  Args:
    session_ref: Reference to the session.

  Returns:
    TransactionSelector with the id property.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')

  transaction_options = msgs.TransactionOptions(
      partitionedDml=msgs.PartitionedDml())
  begin_transaction_req = msgs.BeginTransactionRequest(
      options=transaction_options)
  req = msgs.SpannerProjectsInstancesDatabasesSessionsBeginTransactionRequest(
      beginTransactionRequest=begin_transaction_req,
      session=session_ref.RelativeName())
  resp = client.projects_instances_databases_sessions.BeginTransaction(req)
  return msgs.TransactionSelector(id=resp.id)


def Commit(session_ref, mutations, transaction_id=None):
  """Commit a transaction through a session.

  In Cloud Spanner, each session can have at most one active transaction at a
  time. In order to avoid retrying aborted transactions by accident, this
  request uses a temporary single use transaction instead of a previously
  started transaction to execute the mutations.
  Note: this commit is non-idempotent.

  Args:
    session_ref: Session, through which the transaction would be committed.
    mutations: A list of mutations, each represents a modification to one or
        more Cloud Spanner rows.
    transaction_id: An optional string for the transaction id.

  Returns:
    The Cloud Spanner timestamp at which the transaction committed.
  """
  client = apis.GetClientInstance('spanner', 'v1')
  msgs = apis.GetMessagesModule('spanner', 'v1')

  if transaction_id is not None:
    req = msgs.SpannerProjectsInstancesDatabasesSessionsCommitRequest(
        session=session_ref.RelativeName(),
        commitRequest=msgs.CommitRequest(
            mutations=mutations, transactionId=transaction_id))
  else:
    req = msgs.SpannerProjectsInstancesDatabasesSessionsCommitRequest(
        session=session_ref.RelativeName(),
        commitRequest=msgs.CommitRequest(
            mutations=mutations,
            singleUseTransaction=msgs.TransactionOptions(
                readWrite=msgs.ReadWrite())))
  return client.projects_instances_databases_sessions.Commit(req)


class MutationFactory(object):
  """Factory that creates and returns a mutation object in Cloud Spanner.

  A Mutation represents a sequence of inserts, updates and deletes that can be
  applied to rows and tables in a Cloud Spanner database.
  """
  msgs = apis.GetMessagesModule('spanner', 'v1')

  @classmethod
  def Insert(cls, table, data):
    """Constructs an INSERT mutation, which inserts a new row in a table.

    Args:
      table: String, the name of the table.
      data: A collections.OrderedDict, the keys of which are the column names
        and values are the column values to be inserted.

    Returns:
      An insert mutation operation.
    """
    return cls.msgs.Mutation(insert=cls._GetWrite(table, data))

  @classmethod
  def Update(cls, table, data):
    """Constructs an UPDATE mutation, which updates a row in a table.

    Args:
      table: String, the name of the table.
      data: An ordered dictionary where the keys are the column names and values
        are the column values to be updated.

    Returns:
      An update mutation operation.
    """
    return cls.msgs.Mutation(update=cls._GetWrite(table, data))

  @classmethod
  def Delete(cls, table, keys):
    """Constructs a DELETE mutation, which deletes a row in a table.

    Args:
      table: String, the name of the table.
      keys: String list, the primary key values of the row to delete.

    Returns:
      A delete mutation operation.
    """
    return cls.msgs.Mutation(delete=cls._GetDelete(table, keys))

  @classmethod
  def _GetWrite(cls, table, data):
    """Constructs Write object, which is needed for insert/update operations."""
    # TODO(b/33482229): a workaround to handle JSON serialization
    def _ToJson(msg):
      return extra_types.JsonProtoEncoder(
          extra_types.JsonArray(entries=msg.entry))

    encoding.RegisterCustomMessageCodec(
        encoder=_ToJson, decoder=None)(
            cls.msgs.Write.ValuesValueListEntry)

    json_columns = table.GetJsonData(data)
    json_column_names = [col.col_name for col in json_columns]
    json_column_values = [col.col_value for col in json_columns]

    return cls.msgs.Write(
        columns=json_column_names,
        table=table.name,
        values=[cls.msgs.Write.ValuesValueListEntry(entry=json_column_values)])

  @classmethod
  def _GetDelete(cls, table, keys):
    """Constructs Delete object, which is needed for delete operation."""

    # TODO(b/33482229): a workaround to handle JSON serialization
    def _ToJson(msg):
      return extra_types.JsonProtoEncoder(
          extra_types.JsonArray(entries=msg.entry))

    encoding.RegisterCustomMessageCodec(
        encoder=_ToJson, decoder=None)(
            cls.msgs.KeySet.KeysValueListEntry)

    key_set = cls.msgs.KeySet(keys=[
        cls.msgs.KeySet.KeysValueListEntry(entry=table.GetJsonKeys(keys))
    ])

    return cls.msgs.Delete(table=table.name, keySet=key_set)