HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/current/lib/third_party/ml_sdk/cloud/ml/prediction/prediction_lib.py
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running predictions.

Includes (from the Cloud ML SDK):
- _predict_lib

Important changes:
- Remove interfaces for TensorFlowModel (they don't change behavior).
- Set from_client(skip_preprocessing=True) and remove the pre-processing code.
"""
from . import custom_code_utils
from . import prediction_utils


# --------------------------
# prediction.prediction_lib
# --------------------------
def create_model(client, model_path, framework=None, **unused_kwargs):
  """Creates and returns the appropriate model.

  Creates and returns a Model if no user specified model is
  provided. Otherwise, the user specified model is imported, created, and
  returned.

  Args:
    client: An instance of PredictionClient for performing prediction.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    framework: The framework used to train the model. Defaults to TensorFlow
      when falsy.

  Returns:
    An instance of the appropriate model class.

  Raises:
    ValueError: If `framework` is not one of the supported framework names.
  """
  # A user-supplied model class, when present, takes precedence over the
  # framework-based selection below.
  custom_model = custom_code_utils.create_user_model(model_path, None)
  if custom_model:
    return custom_model

  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME

  # Framework modules are imported lazily so only the selected framework's
  # dependencies need to be importable.
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = tf_prediction_lib.TensorFlowModel
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.SklearnModel
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    model_cls = sk_xg_prediction_lib.XGBoostModel
  else:
    # Previously an unrecognized framework fell through to line below and
    # raised an opaque UnboundLocalError on `model_cls`; fail clearly instead.
    raise ValueError("Invalid framework: {}".format(framework))

  return model_cls(client)


def create_client(framework, model_path, **kwargs):
  """Creates and returns the appropriate prediction client.

  Creates and returns a PredictionClient based on the provided framework.

  Args:
    framework: The framework used to train the model. Defaults to TensorFlow
      when falsy.
    model_path: The path to the exported model (e.g. session_bundle or
      SavedModel)
    **kwargs: Optional additional params to pass to the client constructor (such
      as TF tags).

  Returns:
    An instance of the appropriate PredictionClient.

  Raises:
    ValueError: If `framework` is not one of the supported framework names.
  """
  framework = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME
  # Framework modules are imported lazily so only the selected framework's
  # dependencies need to be importable.
  if framework == prediction_utils.TENSORFLOW_FRAMEWORK_NAME:
    from .frameworks import tf_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = tf_prediction_lib.create_tf_session_client
  elif framework == prediction_utils.SCIKIT_LEARN_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_sklearn_client
  elif framework == prediction_utils.XGBOOST_FRAMEWORK_NAME:
    from .frameworks import sk_xg_prediction_lib  # pylint: disable=g-import-not-at-top
    create_client_fn = sk_xg_prediction_lib.create_xgboost_client
  else:
    # Previously an unrecognized framework fell through to the call below and
    # raised an opaque UnboundLocalError on `create_client_fn`; fail clearly.
    raise ValueError("Invalid framework: {}".format(framework))

  return create_client_fn(model_path, **kwargs)


def local_predict(model_dir=None, signature_name=None, instances=None,
                  framework=None, **kwargs):
  """Run a prediction locally.

  Builds a client and model for the given framework (TensorFlow when
  `framework` is falsy), base64-decodes the instances when the model
  requires it, and runs prediction.

  Args:
    model_dir: Path to the exported model directory.
    signature_name: Name of the model signature to invoke.
    instances: Iterable of input instances to predict on.
    framework: The framework used to train the model; defaults to TensorFlow.
    **kwargs: Extra params forwarded to the prediction client constructor.

  Returns:
    A dict with a single "predictions" key mapping to the list of results.
  """
  chosen = framework or prediction_utils.TENSORFLOW_FRAMEWORK_NAME
  client = create_client(chosen, model_dir, **kwargs)
  model = create_model(client, model_dir, chosen)
  needs_decode = prediction_utils.should_base64_decode(
      chosen, model, signature_name)
  inputs = prediction_utils.decode_base64(instances) if needs_decode else instances
  results = model.predict(inputs, signature_name=signature_name)
  return {"predictions": list(results)}