HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/394/lib/googlecloudsdk/api_lib/ai/endpoints/client.py
# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for dealing with AI Platform endpoints API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py import encoding
from apitools.base.py import exceptions as apitools_exceptions
from apitools.base.py import extra_types
from apitools.base.py import list_pager
from googlecloudsdk.api_lib.ai import util as api_util
from googlecloudsdk.api_lib.ai.models import client as model_client
from googlecloudsdk.api_lib.util import apis
from googlecloudsdk.command_lib.ai import constants
from googlecloudsdk.command_lib.ai import errors
from googlecloudsdk.command_lib.ai import flags
from googlecloudsdk.core import exceptions as core_exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core import resources
from googlecloudsdk.core.credentials import requests
from six.moves import http_client


def _ParseModel(model_id, location_id):
  """Builds an `aiplatform.projects.locations.models` resource reference.

  Args:
    model_id: str, the ID of the model to parse.
    location_id: str, the location (region) the model lives in.

  Returns:
    The resource reference produced by resources.REGISTRY.Parse.
  """
  model_params = {
      'locationsId': location_id,
      # Handed over uncalled; presumably the registry invokes it only if the
      # project must be resolved from the core/project property.
      'projectsId': properties.VALUES.core.project.GetOrFail,
  }
  return resources.REGISTRY.Parse(
      model_id,
      params=model_params,
      collection='aiplatform.projects.locations.models',
  )


def _ConvertPyListToMessageList(message_type, values):
  return [encoding.PyValueToMessage(message_type, v) for v in values]


def _GetModelDeploymentResourceType(
    model_ref, client, shared_resources_ref=None
):
  """Gets the deployment resource type of a model.

  Args:
    model_ref: a model resource object.
    client: an apis.GetClientInstance object.
    shared_resources_ref: str, the shared deployment resource pool the model
      should use, formatted as the full URI

  Returns:
    A string which value must be 'DEDICATED_RESOURCES', 'AUTOMATIC_RESOURCES'
    or 'SHARED_RESOURCES'. 'SHARED_RESOURCES' is returned only when
    shared_resources_ref is given; otherwise the first supported type that is
    not 'SHARED_RESOURCES' is returned.

  Raises:
    ArgumentError: if the model resource object is not found, if a shared
      resource pool is requested but the model does not support shared
      resources, or if no shared pool is requested and the model supports no
      dedicated/automatic resource type.
  """
  try:
    model_msg = model_client.ModelsClient(client=client).Get(model_ref)
  except apitools_exceptions.HttpError:
    raise errors.ArgumentError((
        'There is an error while getting the model information. '
        'Please make sure the model %r exists.'
        % model_ref.RelativeName()
    ))
  model_resource = encoding.MessageToPyValue(model_msg)

  # The model may list several supported types. Treat a missing field as an
  # empty list instead of failing with KeyError.
  supported_types = model_resource.get(
      'supportedDeploymentResourcesTypes', []
  )

  if shared_resources_ref is not None:
    if 'SHARED_RESOURCES' not in supported_types:
      raise errors.ArgumentError(
          'Shared resources not supported for model {}.'.format(
              model_ref.RelativeName()
          )
      )
    return 'SHARED_RESOURCES'

  # Pick the first dedicated/automatic type without mutating the converted
  # message. A model that only supports SHARED_RESOURCES (or reports no types
  # at all) cannot be deployed without a shared pool, so report that clearly
  # instead of crashing with IndexError.
  for resource_type in supported_types:
    if resource_type != 'SHARED_RESOURCES':
      return resource_type
  raise errors.ArgumentError(
      'No dedicated or automatic deployment resources are supported for '
      'model {}.'.format(model_ref.RelativeName())
  )


def _DoHttpPost(url, headers, body):
  """Sends `body` to `url` as an HTTP POST and returns the raw response parts.

  The session comes from googlecloudsdk.core.credentials.requests.

  Args:
    url: str, the URL to POST to.
    headers: the HTTP headers to send.
    body: the request payload, passed as the `data` argument.

  Returns:
    A (status_code, headers, content) tuple taken from the response.
  """
  session = requests.GetSession()
  resp = session.request('POST', url, data=body, headers=headers)
  return resp.status_code, resp.headers, resp.content


def _DoStreamHttpPost(url, headers, body):
  """Sends `body` to `url` as a streamed HTTP POST and yields response lines.

  The request is issued with stream=True. Lines come from the response's
  iter_lines(). The response is used as a context manager, so it is closed
  once the generator is exhausted or closed.

  Args:
    url: str, the URL to POST to.
    headers: the HTTP headers to send.
    body: the request payload, passed as the `data` argument.

  Yields:
    Each line of the response body.
  """
  session = requests.GetSession()
  with session.request(
      'POST', url, data=body, headers=headers, stream=True
  ) as streamed:
    for chunk in streamed.iter_lines():
      yield chunk


def _CheckIsGdcGgsModel(self, endpoint_ref):
  """Reports whether an endpoint is a GDC endpoint (has a GDC zone).

  GDC GGS models are only supported for GDC endpoints. Although defined at
  module level, this takes the endpoints client as `self` and calls its Get().

  Args:
    self: an EndpointsClient-like object exposing Get(endpoint_ref).
    endpoint_ref: Resource, the parsed endpoint to inspect.

  Returns:
    False if the endpoint has no gdcConfig or no zone in it. Otherwise the
    zone value itself, which is truthy unless it is empty.
  """
  resource = encoding.MessageToPyValue(self.Get(endpoint_ref))
  if resource is None or 'gdcConfig' not in resource:
    return False
  gdc_config = resource['gdcConfig']
  if 'zone' not in gdc_config:
    return False
  return gdc_config['zone']


class EndpointsClient(object):
  """High-level client for the AI Platform endpoints surface."""

  def __init__(self, client=None, messages=None, version=None):
    self.client = client or apis.GetClientInstance(
        constants.AI_PLATFORM_API_NAME,
        constants.AI_PLATFORM_API_VERSION[version],
    )
    self.messages = messages or self.client.MESSAGES_MODULE

  def Create(
      self,
      location_ref,
      display_name,
      labels,
      description=None,
      network=None,
      endpoint_id=None,
      encryption_kms_key_name=None,
      request_response_logging_table=None,
      request_response_logging_rate=None,
  ):
    """Creates an endpoint through the v1 API.

    The encryption spec and request-response logging config are only attached
    when their inputs are provided.

    Args:
      location_ref: Resource, the parsed location to create an endpoint.
      display_name: str, the display name of the new endpoint.
      labels: the labels to organize the new endpoint.
      description: str or None, the description of the new endpoint.
      network: str, the full name of the Google Compute Engine network.
      endpoint_id: str or None, the id of the new endpoint.
      encryption_kms_key_name: str or None, the Cloud KMS resource identifier of
        the customer managed encryption key used to protect a resource.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging. Logging is only configured when this is set.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging. A missing or zero rate is sent as 0.0.

    Returns:
      A long-running operation for Create.
    """
    if encryption_kms_key_name:
      kms_spec = self.messages.GoogleCloudAiplatformV1EncryptionSpec(
          kmsKeyName=encryption_kms_key_name
      )
    else:
      kms_spec = None

    new_endpoint = api_util.GetMessage('Endpoint', constants.GA_VERSION)(
        displayName=display_name,
        description=description,
        labels=labels,
        network=network,
        encryptionSpec=kms_spec,
    )

    if request_response_logging_table is not None:
      logging_config_cls = api_util.GetMessage(
          'PredictRequestResponseLoggingConfig', constants.GA_VERSION
      )
      destination_cls = api_util.GetMessage(
          'BigQueryDestination', constants.GA_VERSION
      )
      new_endpoint.predictRequestResponseLoggingConfig = logging_config_cls(
          enabled=True,
          samplingRate=request_response_logging_rate or 0.0,
          bigqueryDestination=destination_cls(
              outputUri=request_response_logging_table
          ),
      )

    create_request = (
        self.messages.AiplatformProjectsLocationsEndpointsCreateRequest(
            parent=location_ref.RelativeName(),
            endpointId=endpoint_id,
            googleCloudAiplatformV1Endpoint=new_endpoint,
        )
    )
    return self.client.projects_locations_endpoints.Create(create_request)

  def CreateBeta(
      self,
      location_ref,
      display_name,
      labels,
      description=None,
      network=None,
      endpoint_id=None,
      encryption_kms_key_name=None,
      gdce_zone=None,
      gdc_zone=None,
      request_response_logging_table=None,
      request_response_logging_rate=None,
  ):
    """Creates an endpoint through the v1beta1 API.

    These optional sub-messages are only attached when their inputs are
    provided: the encryption spec, GDCE config, GDC config and
    request-response logging config.

    Args:
      location_ref: Resource, the parsed location to create an endpoint.
      display_name: str, the display name of the new endpoint.
      labels: the labels to organize the new endpoint.
      description: str or None, the description of the new endpoint.
      network: str, the full name of the Google Compute Engine network.
      endpoint_id: str or None, the id of the new endpoint.
      encryption_kms_key_name: str or None, the Cloud KMS resource identifier of
        the customer managed encryption key used to protect a resource.
      gdce_zone: str or None, the name of the GDCE zone.
      gdc_zone: str or None, the name of the GDC zone.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging. Logging is only configured when this is set.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging. A missing or zero rate is sent as 0.0.

    Returns:
      A long-running operation for Create.
    """
    msgs = self.messages
    kms_spec = (
        msgs.GoogleCloudAiplatformV1beta1EncryptionSpec(
            kmsKeyName=encryption_kms_key_name
        )
        if encryption_kms_key_name
        else None
    )
    gdce_cfg = (
        msgs.GoogleCloudAiplatformV1beta1GdceConfig(zone=gdce_zone)
        if gdce_zone
        else None
    )
    gdc_cfg = (
        msgs.GoogleCloudAiplatformV1beta1GdcConfig(zone=gdc_zone)
        if gdc_zone
        else None
    )

    new_endpoint = api_util.GetMessage('Endpoint', constants.BETA_VERSION)(
        displayName=display_name,
        description=description,
        labels=labels,
        network=network,
        encryptionSpec=kms_spec,
        gdceConfig=gdce_cfg,
        gdcConfig=gdc_cfg,
    )

    if request_response_logging_table is not None:
      logging_config_cls = api_util.GetMessage(
          'PredictRequestResponseLoggingConfig', constants.BETA_VERSION
      )
      destination_cls = api_util.GetMessage(
          'BigQueryDestination', constants.BETA_VERSION
      )
      new_endpoint.predictRequestResponseLoggingConfig = logging_config_cls(
          enabled=True,
          samplingRate=request_response_logging_rate or 0.0,
          bigqueryDestination=destination_cls(
              outputUri=request_response_logging_table
          ),
      )

    create_request = msgs.AiplatformProjectsLocationsEndpointsCreateRequest(
        parent=location_ref.RelativeName(),
        endpointId=endpoint_id,
        googleCloudAiplatformV1beta1Endpoint=new_endpoint,
    )
    return self.client.projects_locations_endpoints.Create(create_request)

  def Delete(self, endpoint_ref):
    """Deletes an existing endpoint."""
    req = self.messages.AiplatformProjectsLocationsEndpointsDeleteRequest(
        name=endpoint_ref.RelativeName()
    )
    return self.client.projects_locations_endpoints.Delete(req)

  def Get(self, endpoint_ref):
    """Gets details about an endpoint."""
    req = self.messages.AiplatformProjectsLocationsEndpointsGetRequest(
        name=endpoint_ref.RelativeName()
    )
    return self.client.projects_locations_endpoints.Get(req)

  def List(self, location_ref, filter_str=None, gdc_zone=None):
    """Lists the endpoints under a location.

    Args:
      location_ref: Resource, the parsed location whose endpoints are listed.
      filter_str: str or None, sent as the request's `filter` field.
      gdc_zone: str or None, sent as the request's `gdcZone` field.

    Returns:
      The iterator produced by list_pager.YieldFromList over the `endpoints`
      field, paging with `pageSize`.
    """
    list_request = (
        self.messages.AiplatformProjectsLocationsEndpointsListRequest(
            parent=location_ref.RelativeName(),
            filter=filter_str,
            gdcZone=gdc_zone,
        )
    )
    endpoints_service = self.client.projects_locations_endpoints
    return list_pager.YieldFromList(
        endpoints_service,
        list_request,
        field='endpoints',
        batch_size_attribute='pageSize',
    )

  def Patch(
      self,
      endpoint_ref,
      labels_update,
      display_name=None,
      description=None,
      traffic_split=None,
      clear_traffic_split=False,
      request_response_logging_table=None,
      request_response_logging_rate=None,
      disable_request_response_logging=False,
  ):
    """Updates an endpoint through the v1 API.

    Only the requested fields are set on the outgoing endpoint message. Each
    one is also recorded, in request order, in the update mask.

    Args:
      endpoint_ref: Resource, the parsed endpoint to be updated.
      labels_update: UpdateResult, the result of applying the label diff
        constructed from args.
      display_name: str or None, the new display name of the endpoint.
      description: str or None, the new description of the endpoint.
      traffic_split: dict or None, the new traffic split of the endpoint.
      clear_traffic_split: bool, whether or not clear traffic split of the
        endpoint.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.
      disable_request_response_logging: bool, whether or not disable
        request-response logging of the endpoint.

    Returns:
      The response message of Patch.

    Raises:
      NoFieldsSpecifiedError: An error if no updates requested.
    """
    updated = api_util.GetMessage('Endpoint', constants.GA_VERSION)()
    mask_paths = []

    if labels_update.needs_update:
      updated.labels = labels_update.labels
      mask_paths.append('labels')

    if display_name is not None:
      updated.displayName = display_name
      mask_paths.append('display_name')

    if traffic_split is not None:
      # Sorted so the resulting map entries are deterministic.
      split_entries = [
          updated.TrafficSplitValue().AdditionalProperty(key=k, value=v)
          for k, v in sorted(traffic_split.items())
      ]
      updated.trafficSplit = updated.TrafficSplitValue(
          additionalProperties=split_entries
      )
      mask_paths.append('traffic_split')

    if clear_traffic_split:
      updated.trafficSplit = None
      mask_paths.append('traffic_split')

    if description is not None:
      updated.description = description
      mask_paths.append('description')

    wants_logging_change = (
        request_response_logging_table is not None
        or request_response_logging_rate is not None
    )
    if wants_logging_change:
      # Start from the endpoint's current logging config so that fields the
      # caller did not pass are kept.
      logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if not logging_config:
        logging_config = api_util.GetMessage(
            'PredictRequestResponseLoggingConfig', constants.GA_VERSION
        )()
      logging_config.enabled = True
      if request_response_logging_table is not None:
        destination_cls = api_util.GetMessage(
            'BigQueryDestination', constants.GA_VERSION
        )
        logging_config.bigqueryDestination = destination_cls(
            outputUri=request_response_logging_table
        )
      if request_response_logging_rate is not None:
        logging_config.samplingRate = request_response_logging_rate
      updated.predictRequestResponseLoggingConfig = logging_config
      mask_paths.append('predict_request_response_logging_config')

    if disable_request_response_logging:
      current_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if current_config:
        current_config.enabled = False
      updated.predictRequestResponseLoggingConfig = current_config
      mask_paths.append('predict_request_response_logging_config')

    if not mask_paths:
      raise errors.NoFieldsSpecifiedError('No updates requested.')

    patch_request = self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
        name=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1Endpoint=updated,
        updateMask=','.join(mask_paths),
    )
    return self.client.projects_locations_endpoints.Patch(patch_request)

  def PatchBeta(
      self,
      endpoint_ref,
      labels_update,
      display_name=None,
      description=None,
      traffic_split=None,
      clear_traffic_split=False,
      request_response_logging_table=None,
      request_response_logging_rate=None,
      disable_request_response_logging=False,
  ):
    """Updates an endpoint using v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to be updated.
      labels_update: UpdateResult, the result of applying the label diff
        constructed from args.
      display_name: str or None, the new display name of the endpoint.
      description: str or None, the new description of the endpoint.
      traffic_split: dict or None, the new traffic split of the endpoint.
      clear_traffic_split: bool, whether or not clear traffic split of the
        endpoint.
      request_response_logging_table: str or None, the BigQuery table uri for
        request-response logging.
      request_response_logging_rate: float or None, the sampling rate for
        request-response logging.
      disable_request_response_logging: bool, whether or not disable
        request-response logging of the endpoint.

    Returns:
      The response message of Patch.

    Raises:
      NoFieldsSpecifiedError: An error if no updates requested.
    """
    endpoint = self.messages.GoogleCloudAiplatformV1beta1Endpoint()
    update_mask = []

    if labels_update.needs_update:
      endpoint.labels = labels_update.labels
      update_mask.append('labels')

    if display_name is not None:
      endpoint.displayName = display_name
      update_mask.append('display_name')

    if traffic_split is not None:
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            endpoint.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      endpoint.trafficSplit = endpoint.TrafficSplitValue(
          additionalProperties=additional_properties
      )
      update_mask.append('traffic_split')

    if clear_traffic_split:
      endpoint.trafficSplit = None
      update_mask.append('traffic_split')

    if description is not None:
      endpoint.description = description
      update_mask.append('description')

    if (
        request_response_logging_table is not None
        or request_response_logging_rate is not None
    ):
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if not request_response_logging_config:
        request_response_logging_config = api_util.GetMessage(
            'PredictRequestResponseLoggingConfig', constants.BETA_VERSION
        )()
      request_response_logging_config.enabled = True
      if request_response_logging_table is not None:
        request_response_logging_config.bigqueryDestination = (
            api_util.GetMessage('BigQueryDestination', constants.BETA_VERSION)(
                outputUri=request_response_logging_table
            )
        )
      if request_response_logging_rate is not None:
        request_response_logging_config.samplingRate = (
            request_response_logging_rate
        )
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')

    if disable_request_response_logging:
      request_response_logging_config = self.Get(
          endpoint_ref
      ).predictRequestResponseLoggingConfig
      if request_response_logging_config:
        request_response_logging_config.enabled = False
      endpoint.predictRequestResponseLoggingConfig = (
          request_response_logging_config
      )
      update_mask.append('predict_request_response_logging_config')

    if not update_mask:
      raise errors.NoFieldsSpecifiedError('No updates requested.')

    req = self.messages.AiplatformProjectsLocationsEndpointsPatchRequest(
        name=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1Endpoint=endpoint,
        updateMask=','.join(update_mask),
    )
    return self.client.projects_locations_endpoints.Patch(req)

  def Predict(self, endpoint_ref, instances_json):
    """Sends an online prediction request to an endpoint via the v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      instances_json: dict, the parsed request body. Its 'instances' list is
        required. Its 'parameters' value is sent only when the key is present.

    Returns:
      The response of the Predict call.
    """
    instance_values = _ConvertPyListToMessageList(
        extra_types.JsonValue, instances_json['instances']
    )
    body = self.messages.GoogleCloudAiplatformV1PredictRequest(
        instances=instance_values
    )
    if 'parameters' in instances_json:
      body.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )

    return self.client.projects_locations_endpoints.Predict(
        self.messages.AiplatformProjectsLocationsEndpointsPredictRequest(
            endpoint=endpoint_ref.RelativeName(),
            googleCloudAiplatformV1PredictRequest=body,
        )
    )

  def PredictBeta(self, endpoint_ref, instances_json):
    """Sends an online prediction request to an endpoint via the v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      instances_json: dict, the parsed request body. Its 'instances' list is
        required. Its 'parameters' value is sent only when the key is present.

    Returns:
      The response of the Predict call.
    """
    instance_values = _ConvertPyListToMessageList(
        extra_types.JsonValue, instances_json['instances']
    )
    body = self.messages.GoogleCloudAiplatformV1beta1PredictRequest(
        instances=instance_values
    )
    if 'parameters' in instances_json:
      body.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )

    return self.client.projects_locations_endpoints.Predict(
        self.messages.AiplatformProjectsLocationsEndpointsPredictRequest(
            endpoint=endpoint_ref.RelativeName(),
            googleCloudAiplatformV1beta1PredictRequest=body,
        )
    )

  def RawPredict(self, endpoint_ref, headers, request):
    """Sends an online raw prediction request to an endpoint.

    The request body is POSTed unchanged to the endpoint's `:rawPredict` URL,
    which is built from the client's base URL and API version.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      headers: the HTTP headers to send with the request.
      request: the raw request body.

    Returns:
      A (response_headers, response_content) tuple.

    Raises:
      core_exceptions.Error: if the HTTP status is not 200 OK.
    """
    url = '{}{}/{}:rawPredict'.format(
        self.client.url,
        getattr(self.client, '_VERSION'),
        endpoint_ref.RelativeName(),
    )

    status, response_headers, response = _DoHttpPost(url, headers, request)
    if status != http_client.OK:
      # Decode leniently. With a strict decode, a non-UTF-8 error body would
      # raise UnicodeDecodeError and hide the actual HTTP failure.
      raise core_exceptions.Error(
          'HTTP request failed. Response:\n'
          + response.decode('utf-8', 'replace')
      )

    return response_headers, response

  def StreamRawPredict(self, endpoint_ref, headers, request):
    """Sends a streaming online raw prediction request to an endpoint.

    The request body is POSTed unchanged to the endpoint's
    `:streamRawPredict` URL. This is a generator, so nothing is sent until
    iteration starts.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      headers: the HTTP headers to send with the request.
      request: the raw request body.

    Yields:
      Each line of the streamed response, as produced by _DoStreamHttpPost.
    """
    base_url = self.client.url
    api_version = getattr(self.client, '_VERSION')
    target = '{}{}/{}:streamRawPredict'.format(
        base_url, api_version, endpoint_ref.RelativeName()
    )

    for streamed_line in _DoStreamHttpPost(target, headers, request):
      yield streamed_line

  def DirectPredict(self, endpoint_ref, inputs_json):
    """Sends an online direct prediction request via the v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      inputs_json: dict, the parsed request body. Its 'inputs' list is
        required and each item is converted to a Tensor message. Its
        'parameters' value is sent only when the key is present.

    Returns:
      The response of the DirectPredict call.
    """
    msgs = self.messages
    tensor_cls = msgs.GoogleCloudAiplatformV1Tensor
    body = msgs.GoogleCloudAiplatformV1DirectPredictRequest(
        inputs=_ConvertPyListToMessageList(tensor_cls, inputs_json['inputs'])
    )
    if 'parameters' in inputs_json:
      body.parameters = encoding.PyValueToMessage(
          tensor_cls, inputs_json['parameters']
      )

    outer_request = msgs.AiplatformProjectsLocationsEndpointsDirectPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DirectPredictRequest=body,
    )
    return self.client.projects_locations_endpoints.DirectPredict(
        outer_request
    )

  def DirectPredictBeta(self, endpoint_ref, inputs_json):
    """Sends an online direct prediction request via the v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      inputs_json: dict, the parsed request body. Its 'inputs' list is
        required and each item is converted to a Tensor message. Its
        'parameters' value is sent only when the key is present.

    Returns:
      The response of the DirectPredict call.
    """
    msgs = self.messages
    tensor_cls = msgs.GoogleCloudAiplatformV1beta1Tensor
    body = msgs.GoogleCloudAiplatformV1beta1DirectPredictRequest(
        inputs=_ConvertPyListToMessageList(tensor_cls, inputs_json['inputs'])
    )
    if 'parameters' in inputs_json:
      body.parameters = encoding.PyValueToMessage(
          tensor_cls, inputs_json['parameters']
      )

    outer_request = msgs.AiplatformProjectsLocationsEndpointsDirectPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1DirectPredictRequest=body,
    )
    return self.client.projects_locations_endpoints.DirectPredict(
        outer_request
    )

  def DirectRawPredict(self, endpoint_ref, input_json):
    """Sends online direct raw prediction request to an endpoint using v1 API."""
    direct_raw_predict_request = self.messages.GoogleCloudAiplatformV1DirectRawPredictRequest(
        input=bytes(input_json['input'], 'utf-8'),
        # Method name can be "methodName" or "method_name"
        methodName=input_json.get('methodName', input_json.get('method_name')),
    )

    req = self.messages.AiplatformProjectsLocationsEndpointsDirectRawPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DirectRawPredictRequest=direct_raw_predict_request,
    )
    return self.client.projects_locations_endpoints.DirectRawPredict(req)

  def DirectRawPredictBeta(self, endpoint_ref, input_json):
    """Sends online direct raw prediction request to an endpoint using v1beta1 API."""
    direct_raw_predict_request = self.messages.GoogleCloudAiplatformV1beta1DirectRawPredictRequest(
        input=bytes(input_json['input'], 'utf-8'),
        # Method name can be "methodName" or "method_name"
        methodName=input_json.get('methodName', input_json.get('method_name')),
    )

    req = self.messages.AiplatformProjectsLocationsEndpointsDirectRawPredictRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1beta1DirectRawPredictRequest=direct_raw_predict_request,
    )
    return self.client.projects_locations_endpoints.DirectRawPredict(req)

  def Explain(self, endpoint_ref, instances_json, args):
    """Sends an online explanation request to an endpoint using the v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      instances_json: dict, the parsed request body. Its 'instances' list is
        required. Its 'parameters' value is sent only when the key is present.
      args: an argparse-style namespace (presumably the command's parsed
        arguments). Only `args.deployed_model_id` is read, and it is forwarded
        when not None.

    Returns:
      The response of the Explain call.
    """
    explain_request = self.messages.GoogleCloudAiplatformV1ExplainRequest(
        instances=_ConvertPyListToMessageList(
            extra_types.JsonValue, instances_json['instances']
        )
    )
    if 'parameters' in instances_json:
      explain_request.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )
    if args.deployed_model_id is not None:
      explain_request.deployedModelId = args.deployed_model_id

    req = self.messages.AiplatformProjectsLocationsEndpointsExplainRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1ExplainRequest=explain_request,
    )
    return self.client.projects_locations_endpoints.Explain(req)

  def ExplainBeta(self, endpoint_ref, instances_json, args):
    """Sends an online explanation request to an endpoint using v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint to send the request to.
      instances_json: dict, the parsed request body. Its 'instances' list is
        required. 'parameters' and 'explanation_spec_override' are each sent
        only when their key is present.
      args: an argparse-style namespace (presumably the command's parsed
        arguments). Only `args.deployed_model_id` is read, and it is forwarded
        when not None.

    Returns:
      The response of the Explain call.
    """
    msgs = self.messages
    instance_values = _ConvertPyListToMessageList(
        extra_types.JsonValue, instances_json['instances']
    )
    body = msgs.GoogleCloudAiplatformV1beta1ExplainRequest(
        instances=instance_values
    )

    if 'parameters' in instances_json:
      body.parameters = encoding.PyValueToMessage(
          extra_types.JsonValue, instances_json['parameters']
      )

    if 'explanation_spec_override' in instances_json:
      override_cls = msgs.GoogleCloudAiplatformV1beta1ExplanationSpecOverride
      body.explanationSpecOverride = encoding.PyValueToMessage(
          override_cls, instances_json['explanation_spec_override']
      )

    target_model = args.deployed_model_id
    if target_model is not None:
      body.deployedModelId = target_model

    return self.client.projects_locations_endpoints.Explain(
        msgs.AiplatformProjectsLocationsEndpointsExplainRequest(
            endpoint=endpoint_ref.RelativeName(),
            googleCloudAiplatformV1beta1ExplainRequest=body,
        )
    )

  def DeployModel(
      self,
      endpoint_ref,
      model,
      region,
      display_name,
      machine_type=None,
      tpu_topology=None,
      multihost_gpu_node_count=None,
      accelerator_dict=None,
      min_replica_count=None,
      max_replica_count=None,
      required_replica_count=None,
      reservation_affinity=None,
      autoscaling_metric_specs=None,
      spot=False,
      enable_access_logging=False,
      disable_container_logging=False,
      service_account=None,
      traffic_split=None,
      deployed_model_id=None,
  ):
    """Deploys a model to an existing endpoint using v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint that the model is deployed to.
      model: str, Id of the uploaded model to be deployed.
      region: str, the location of the endpoint and the model.
      display_name: str, the display name of the new deployed model.
      machine_type: str or None, the type of the machine to serve the model.
      tpu_topology: str or None, the topology of the TPU to serve the model.
      multihost_gpu_node_count: int or None, the number of nodes per replica for
        multihost GPU deployments.
      accelerator_dict: dict or None, the accelerator attached to the deployed
        model from args.
      min_replica_count: int or None, the minimum number of replicas the
        deployed model will be always deployed on.
      max_replica_count: int or None, the maximum number of replicas the
        deployed model may be deployed on.
      required_replica_count: int or None, the required number of replicas the
        deployed model will be considered successfully deployed.
      reservation_affinity: dict or None, the reservation affinity of the
        deployed model which specifies which reservations the deployed model can
        use.
      autoscaling_metric_specs: dict or None, the metric specification that
        defines the target resource utilization for calculating the desired
        replica count.
      spot: bool, whether or not deploy the model on spot resources.
      enable_access_logging: bool, whether or not enable access logs.
      disable_container_logging: bool, whether or not disable container logging.
      service_account: str or None, the service account that the deployed model
        runs as.
      traffic_split: dict or None, the new traffic split of the endpoint.
      deployed_model_id: str or None, id of the deployed model.

    Returns:
      A long-running operation for DeployModel.
    """
    model_ref = _ParseModel(model, region)

    resource_type = _GetModelDeploymentResourceType(model_ref, self.client)
    if resource_type == 'DEDICATED_RESOURCES':
      # dedicated resources
      machine_spec = self.messages.GoogleCloudAiplatformV1MachineSpec()
      if machine_type is not None:
        machine_spec.machineType = machine_type
      if tpu_topology is not None:
        machine_spec.tpuTopology = tpu_topology
      if multihost_gpu_node_count is not None:
        machine_spec.multihostGpuNodeCount = multihost_gpu_node_count
      accelerator = flags.ParseAcceleratorFlag(
          accelerator_dict, constants.GA_VERSION
      )
      if accelerator is not None:
        machine_spec.acceleratorType = accelerator.acceleratorType
        machine_spec.acceleratorCount = accelerator.acceleratorCount
      if reservation_affinity is not None:
        machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
            reservation_affinity, constants.GA_VERSION
        )

      dedicated = self.messages.GoogleCloudAiplatformV1DedicatedResources(
          machineSpec=machine_spec, spot=spot
      )
      # min-replica-count is required and must be >= 1 if models use dedicated
      # resources. Default to 1 if not specified.
      dedicated.minReplicaCount = min_replica_count or 1
      if max_replica_count is not None:
        dedicated.maxReplicaCount = max_replica_count
      if required_replica_count is not None:
        dedicated.requiredReplicaCount = required_replica_count

      if autoscaling_metric_specs is not None:
        autoscaling_metric_specs_list = []
        for name, target in sorted(autoscaling_metric_specs.items()):
          autoscaling_metric_specs_list.append(
              self.messages.GoogleCloudAiplatformV1AutoscalingMetricSpec(
                  metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[name],
                  target=target,
              )
          )
        dedicated.autoscalingMetricSpecs = autoscaling_metric_specs_list

      deployed_model = self.messages.GoogleCloudAiplatformV1DeployedModel(
          dedicatedResources=dedicated,
          displayName=display_name,
          model=model_ref.RelativeName(),
      )
    else:
      # automatic resources
      automatic = self.messages.GoogleCloudAiplatformV1AutomaticResources()
      if min_replica_count is not None:
        automatic.minReplicaCount = min_replica_count
      if max_replica_count is not None:
        automatic.maxReplicaCount = max_replica_count

      deployed_model = self.messages.GoogleCloudAiplatformV1DeployedModel(
          automaticResources=automatic,
          displayName=display_name,
          model=model_ref.RelativeName(),
      )

    deployed_model.enableAccessLogging = enable_access_logging
    deployed_model.disableContainerLogging = disable_container_logging

    if service_account is not None:
      deployed_model.serviceAccount = service_account

    if deployed_model_id is not None:
      deployed_model.id = deployed_model_id

    deployed_model_req = (
        self.messages.GoogleCloudAiplatformV1DeployModelRequest(
            deployedModel=deployed_model
        )
    )

    if traffic_split is not None:
      additional_properties = []
      for key, value in sorted(traffic_split.items()):
        additional_properties.append(
            deployed_model_req.TrafficSplitValue().AdditionalProperty(
                key=key, value=value
            )
        )
      deployed_model_req.trafficSplit = deployed_model_req.TrafficSplitValue(
          additionalProperties=additional_properties
      )

    req = self.messages.AiplatformProjectsLocationsEndpointsDeployModelRequest(
        endpoint=endpoint_ref.RelativeName(),
        googleCloudAiplatformV1DeployModelRequest=deployed_model_req,
    )
    return self.client.projects_locations_endpoints.DeployModel(req)

  def DeployModelBeta(
      self,
      endpoint_ref,
      model,
      region,
      display_name,
      machine_type=None,
      tpu_topology=None,
      multihost_gpu_node_count=None,
      accelerator_dict=None,
      min_replica_count=None,
      max_replica_count=None,
      required_replica_count=None,
      reservation_affinity=None,
      autoscaling_metric_specs=None,
      spot=False,
      enable_access_logging=False,
      enable_container_logging=False,
      service_account=None,
      traffic_split=None,
      deployed_model_id=None,
      shared_resources_ref=None,
      min_scaleup_period=None,
      idle_scaledown_period=None,
      initial_replica_count=None,
      gpu_partition_size=None,
  ):
    """Deploys a model to an existing endpoint through the v1beta1 API.

    The DeployedModel that gets sent takes one of these shapes:

    * When _CheckIsGdcGgsModel reports the endpoint as GDC GGS, a fixed
      placeholder dedicated-resources spec is sent (machine type
      n1-standard-2, one NVIDIA_TESLA_T4, exactly one replica), and `model` is
      passed through as gdcConnectedModel. The region, machine, replica,
      autoscaling and scale-to-zero arguments are not used in that case.
    * Otherwise the model is parsed and its deployment resource type picks
      the layout:
      - DEDICATED_RESOURCES builds a machine spec plus replica, autoscaling
        and scale-to-zero settings.
      - AUTOMATIC_RESOURCES carries only the min/max replica counts.
      - Any other type is treated as SHARED_RESOURCES and references
        shared_resources_ref.

    Args:
      endpoint_ref: Resource, the parsed endpoint the model is deployed to.
      model: str, ID of the uploaded model to deploy. For a GDC GGS endpoint
        it is sent as gdcConnectedModel instead.
      region: str, the location of the endpoint and the model. Used to parse
        `model`.
      display_name: str, display name of the new deployed model.
      machine_type: str or None, machine type to serve the model on
        (dedicated resources only).
      tpu_topology: str or None, TPU topology (dedicated resources only).
      multihost_gpu_node_count: int or None, nodes per replica for multihost
        GPU deployments (dedicated resources only).
      accelerator_dict: dict or None, accelerator flag value, parsed with
        flags.ParseAcceleratorFlag (dedicated resources only).
      min_replica_count: int or None, minimum replica count. For dedicated
        resources it falls back to 1 when None, and 0 enrolls the deployment
        in scale-to-zero.
      max_replica_count: int or None, maximum replica count. For dedicated
        resources it falls back to 1 when None and the minimum is 0.
      required_replica_count: int or None, the number of replicas needed for
        the deployment to be considered successful (dedicated resources only).
      reservation_affinity: dict or None, reservation affinity flag value,
        parsed with flags.ParseReservationAffinityFlag (dedicated resources
        only).
      autoscaling_metric_specs: dict or None, mapping of metric flag names
        (keys of constants.OP_AUTOSCALING_METRIC_NAME_MAPPER) to target
        utilization (dedicated resources only).
      spot: bool, whether to deploy on spot resources (dedicated resources
        only).
      enable_access_logging: bool, whether to enable access logs.
      enable_container_logging: bool, whether to enable container logging.
      service_account: str or None, service account the deployed model runs
        as.
      traffic_split: dict or None, the endpoint's new traffic split. Its items
        are copied into the request sorted by key.
      deployed_model_id: str or None, ID to assign to the deployed model.
      shared_resources_ref: Resource or None, the parsed shared deployment
        resource pool. Its RelativeName() is required when the model deploys
        onto shared resources.
      min_scaleup_period: str or None, the minimum duration (in seconds) that
        a deployment stays scaled up before traffic is evaluated for potential
        scale-down. Sent as '<value>s' (dedicated resources only).
      idle_scaledown_period: str or None, the duration (in seconds) without
        traffic after which a scale-to-zero deployment is scaled down. Sent as
        '<value>s' (dedicated resources only).
      initial_replica_count: int or None, replicas a scale-to-zero deployment
        scales up to. Falls back to 1 when None and the minimum is 0
        (dedicated resources only).
      gpu_partition_size: str or None, partition size of the GPU accelerator
        (dedicated resources only).

    Returns:
      A long-running operation for DeployModel.
    """
    messages = self.messages

    if _CheckIsGdcGgsModel(self, endpoint_ref):
      # GDC GGS models get placeholder ("pseudo") dedicated resources that do
      # not depend on any of the caller's resource arguments.
      spec_cls = messages.GoogleCloudAiplatformV1beta1MachineSpec
      placeholder_resources = (
          messages.GoogleCloudAiplatformV1beta1DedicatedResources(
              machineSpec=spec_cls(
                  machineType='n1-standard-2',
                  acceleratorType=spec_cls.AcceleratorTypeValueValuesEnum.NVIDIA_TESLA_T4,
                  acceleratorCount=1,
              ),
              minReplicaCount=1,
              maxReplicaCount=1,
          )
      )
      deployed_model = messages.GoogleCloudAiplatformV1beta1DeployedModel(
          dedicatedResources=placeholder_resources,
          displayName=display_name,
          gdcConnectedModel=model,
      )
    else:
      model_ref = _ParseModel(model, region)
      resource_type = _GetModelDeploymentResourceType(
          model_ref, self.client, shared_resources_ref
      )

      if resource_type == 'DEDICATED_RESOURCES':
        machine_spec = messages.GoogleCloudAiplatformV1beta1MachineSpec()
        for field_name, field_value in (
            ('machineType', machine_type),
            ('tpuTopology', tpu_topology),
            ('multihostGpuNodeCount', multihost_gpu_node_count),
        ):
          if field_value is not None:
            setattr(machine_spec, field_name, field_value)

        parsed_accelerator = flags.ParseAcceleratorFlag(
            accelerator_dict, constants.BETA_VERSION
        )
        if parsed_accelerator is not None:
          machine_spec.acceleratorType = parsed_accelerator.acceleratorType
          machine_spec.acceleratorCount = parsed_accelerator.acceleratorCount
        if reservation_affinity is not None:
          machine_spec.reservationAffinity = flags.ParseReservationAffinityFlag(
              reservation_affinity, constants.BETA_VERSION
          )
        if gpu_partition_size is not None:
          machine_spec.gpuPartitionSize = gpu_partition_size

        # A minimum of 0 is legitimate (it enrolls the deployment in
        # scale-to-zero), so only a missing value falls back to 1.
        min_count = 1 if min_replica_count is None else min_replica_count
        scales_to_zero = min_count == 0

        dedicated = messages.GoogleCloudAiplatformV1beta1DedicatedResources(
            machineSpec=machine_spec, spot=spot
        )
        dedicated.minReplicaCount = min_count
        # When scaling to zero, unset max/initial counts fall back to 1.
        if max_replica_count is not None:
          dedicated.maxReplicaCount = max_replica_count
        elif scales_to_zero:
          dedicated.maxReplicaCount = 1
        if required_replica_count is not None:
          dedicated.requiredReplicaCount = required_replica_count
        if initial_replica_count is not None:
          dedicated.initialReplicaCount = initial_replica_count
        elif scales_to_zero:
          dedicated.initialReplicaCount = 1

        if autoscaling_metric_specs is not None:
          dedicated.autoscalingMetricSpecs = [
              messages.GoogleCloudAiplatformV1beta1AutoscalingMetricSpec(
                  metricName=constants.OP_AUTOSCALING_METRIC_NAME_MAPPER[
                      metric
                  ],
                  target=target,
              )
              for metric, target in sorted(autoscaling_metric_specs.items())
          ]

        # Only attach a scale-to-zero spec if at least one period was given.
        stz_fields = {}
        if min_scaleup_period is not None:
          stz_fields['minScaleupPeriod'] = '{}s'.format(min_scaleup_period)
        if idle_scaledown_period is not None:
          stz_fields['idleScaledownPeriod'] = '{}s'.format(
              idle_scaledown_period
          )
        if stz_fields:
          dedicated.scaleToZeroSpec = (
              messages.GoogleCloudAiplatformV1beta1DedicatedResourcesScaleToZeroSpec(
                  **stz_fields
              )
          )

        placement = {'dedicatedResources': dedicated}
      elif resource_type == 'AUTOMATIC_RESOURCES':
        automatic = messages.GoogleCloudAiplatformV1beta1AutomaticResources()
        if min_replica_count is not None:
          automatic.minReplicaCount = min_replica_count
        if max_replica_count is not None:
          automatic.maxReplicaCount = max_replica_count
        placement = {'automaticResources': automatic}
      else:
        # Any other type means SHARED_RESOURCES: the deployed model carries no
        # resource spec of its own, only the shared pool's resource name.
        placement = {'sharedResources': shared_resources_ref.RelativeName()}

      deployed_model = messages.GoogleCloudAiplatformV1beta1DeployedModel(
          displayName=display_name, model=model_ref.RelativeName(), **placement
      )

    deployed_model.enableAccessLogging = enable_access_logging
    deployed_model.enableContainerLogging = enable_container_logging
    if service_account is not None:
      deployed_model.serviceAccount = service_account
    if deployed_model_id is not None:
      deployed_model.id = deployed_model_id

    deploy_body = messages.GoogleCloudAiplatformV1beta1DeployModelRequest(
        deployedModel=deployed_model
    )
    if traffic_split is not None:
      split_cls = deploy_body.TrafficSplitValue
      deploy_body.trafficSplit = split_cls(
          additionalProperties=[
              split_cls.AdditionalProperty(key=key, value=value)
              for key, value in sorted(traffic_split.items())
          ]
      )

    return self.client.projects_locations_endpoints.DeployModel(
        messages.AiplatformProjectsLocationsEndpointsDeployModelRequest(
            endpoint=endpoint_ref.RelativeName(),
            googleCloudAiplatformV1beta1DeployModelRequest=deploy_body,
        )
    )

  def UndeployModel(self, endpoint_ref, deployed_model_id, traffic_split=None):
    """Undeploys a model from an endpoint through the v1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint the model is removed from.
      deployed_model_id: str, ID of the deployed model to undeploy.
      traffic_split: dict or None, the endpoint's new traffic split. When
        given, its items are copied into the request's trafficSplit sorted by
        key. When None, trafficSplit is left unset.

    Returns:
      A long-running operation for UndeployModel.
    """
    messages = self.messages
    undeploy_body = messages.GoogleCloudAiplatformV1UndeployModelRequest(
        deployedModelId=deployed_model_id
    )
    if traffic_split is not None:
      split_cls = undeploy_body.TrafficSplitValue
      undeploy_body.trafficSplit = split_cls(
          additionalProperties=[
              split_cls.AdditionalProperty(key=key, value=value)
              for key, value in sorted(traffic_split.items())
          ]
      )

    return self.client.projects_locations_endpoints.UndeployModel(
        messages.AiplatformProjectsLocationsEndpointsUndeployModelRequest(
            endpoint=endpoint_ref.RelativeName(),
            googleCloudAiplatformV1UndeployModelRequest=undeploy_body,
        )
    )

  def UndeployModelBeta(
      self, endpoint_ref, deployed_model_id, traffic_split=None
  ):
    """Undeploys a model from an endpoint through the v1beta1 API.

    Args:
      endpoint_ref: Resource, the parsed endpoint the model is removed from.
      deployed_model_id: str, ID of the deployed model to undeploy.
      traffic_split: dict or None, the endpoint's new traffic split. When
        given, its items are copied into the request's trafficSplit sorted by
        key. When None, trafficSplit is left unset.

    Returns:
      A long-running operation for UndeployModel.
    """
    messages = self.messages
    undeploy_body = messages.GoogleCloudAiplatformV1beta1UndeployModelRequest(
        deployedModelId=deployed_model_id
    )
    if traffic_split is not None:
      split_cls = undeploy_body.TrafficSplitValue
      undeploy_body.trafficSplit = split_cls(
          additionalProperties=[
              split_cls.AdditionalProperty(key=key, value=value)
              for key, value in sorted(traffic_split.items())
          ]
      )

    return self.client.projects_locations_endpoints.UndeployModel(
        messages.AiplatformProjectsLocationsEndpointsUndeployModelRequest(
            endpoint=endpoint_ref.RelativeName(),
            googleCloudAiplatformV1beta1UndeployModelRequest=undeploy_body,
        )
    )