HEX
Server: Apache/2.4.65 (Ubuntu)
System: Linux ielts-store-v2 6.8.0-1036-gcp #38~22.04.1-Ubuntu SMP Thu Aug 14 01:19:18 UTC 2025 x86_64
User: root (0)
PHP: 7.2.34-54+ubuntu20.04.1+deb.sury.org+1
Disabled: pcntl_alarm,pcntl_fork,pcntl_waitpid,pcntl_wait,pcntl_wifexited,pcntl_wifstopped,pcntl_wifsignaled,pcntl_wifcontinued,pcntl_wexitstatus,pcntl_wtermsig,pcntl_wstopsig,pcntl_signal,pcntl_signal_get_handler,pcntl_signal_dispatch,pcntl_get_last_error,pcntl_strerror,pcntl_sigprocmask,pcntl_sigwaitinfo,pcntl_sigtimedwait,pcntl_exec,pcntl_getpriority,pcntl_setpriority,pcntl_async_signals,
Upload Files
File: //snap/google-cloud-cli/current/lib/googlecloudsdk/command_lib/compute/tpus/queued_resources/ssh.py
# -*- coding: utf-8 -*- #
# Copyright 2023 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SSH/SCP utilities for Cloud TPU Queued Resource commands."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from googlecloudsdk.calliope import exceptions
from googlecloudsdk.core import log
import six


def _NodeRangeFormatError(node_range):
  """Builds the --node error for a malformed "lowerBound-upperBound" range."""
  return exceptions.InvalidArgumentException(
      '--node',
      'Range "{}" does not match expected format'
      ' "lowerBound-upperBound", where lowerBound < upperBound.'.format(
          node_range
      ),
  )


def ParseNodeFlag(node_flag, node_specs):
  """Parses the --node flag into a list of node_specs.

  Args:
    node_flag: The raw --node value. Either 'ALL' (case-insensitive) or a
      comma-separated list of node indexes and inclusive ranges, e.g.
      '0,2-4'. Empty comma-separated items are ignored.
    node_specs: Sequence of node specs, addressed by node index.

  Returns:
    List of the selected node specs in ascending node-index order. A node
    selected more than once (e.g. '1,0-2') appears only once.

  Raises:
    InvalidArgumentException: If an item is not a number or a valid
      'lowerBound-upperBound' range (with lowerBound < upperBound), if no
      node is selected, or if an index is >= len(node_specs).
  """
  num_nodes = len(node_specs)
  if six.text_type(node_flag).upper() == 'ALL':
    indexes = set(range(num_nodes))
  else:
    indexes = set()
    for r in node_flag.split(','):
      if not r:
        continue
      if '-' in r:
        bounds = r.split('-')
        # Rejects 'a-b-c', '-3' and '3-', which all split into a bad shape.
        if len(bounds) != 2 or not bounds[0] or not bounds[1]:
          raise _NodeRangeFormatError(r)
        try:
          start, end = int(bounds[0]), int(bounds[1])
        except ValueError:
          # Non-numeric bounds such as 'a-3' must surface as a flag error,
          # not as a bare ValueError.
          raise _NodeRangeFormatError(r)
        if start >= end:
          raise _NodeRangeFormatError(r)
        indexes.update(range(start, end + 1))
      else:
        try:
          indexes.add(int(r))
        except ValueError:
          raise exceptions.InvalidArgumentException(
              '--node',
              'unable to parse node ID {}. Please only use numbers.'.format(r),
          )

  if not indexes:
    raise exceptions.InvalidArgumentException(
        '--node',
        'Unable to parse node ranges from {}.'.format(node_flag),
    )

  # Indexes can never be negative here: a leading '-' is treated as a range
  # with an empty lower bound and rejected above. Only the upper end needs a
  # bounds check.
  mx = max(indexes)
  if mx >= num_nodes:
    raise exceptions.InvalidArgumentException(
        '--node',
        'node index {} is larger than the valid node indices on this TPU Queued'
        ' Resource. Please only use indexes in the range [0, {}], inclusive.'
        .format(mx, num_nodes - 1),
    )

  # Sort so the result order is deterministic rather than set-iteration order.
  return [node_specs[node] for node in sorted(indexes)]


def WaitForNodeBatchCompletion(ssh_threads, nodes):
  """Joins every thread in a batch, then reports the nodes that are ready.

  Each thread is joined in turn; no thread's outcome is inspected here, so a
  failure in one thread does not stop the others from being waited on.
  Afterwards, one status line is printed for every truthy entry in `nodes`.
  Falsy entries are skipped silently.
  NOTE(review): falsy entries presumably mark nodes whose preparation
  failed; confirm against the caller.

  Args:
    ssh_threads: List of ssh threads.
    nodes: List of SSH prepped nodes.
  """
  for worker in ssh_threads:
    worker.join()

  for ready_node in (entry for entry in nodes if entry):
    log.status.Print(
        'Finished preparing node {}.'.format(ready_node.tpu_name)
    )