File: //snap/google-cloud-cli/current/lib/googlecloudsdk/command_lib/compute/iap_tunnel.py
# -*- coding: utf-8 -*- #
# Copyright 2018 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tunnel TCP traffic over Cloud IAP WebSocket connection."""
from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
import ctypes
import errno
import functools
import gc
import io
import os
import select
import socket
import sys
import threading
from googlecloudsdk.api_lib.compute import iap_tunnel_websocket
from googlecloudsdk.api_lib.compute import iap_tunnel_websocket_utils as utils
from googlecloudsdk.api_lib.compute import sg_tunnel
from googlecloudsdk.api_lib.compute import sg_tunnel_utils as sg_utils
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import execution_utils
from googlecloudsdk.core import http_proxy
from googlecloudsdk.core import log
from googlecloudsdk.core import properties
from googlecloudsdk.core import transport
from googlecloudsdk.core.credentials import creds
from googlecloudsdk.core.credentials import store
from googlecloudsdk.core.util import files
from googlecloudsdk.core.util import platforms
import portpicker
import six
from six.moves import queue
if not platforms.OperatingSystem.IsWindows():
import fcntl # pylint: disable=g-import-not-at-top
else:
from ctypes import wintypes # pylint: disable=g-import-not-at-top
# Timeout, in seconds, of each select() poll on stdin in
# _StdinSocket._RecvUnix; the poll loop re-checks its closed flag in between.
READ_FROM_STDIN_TIMEOUT_SECS = 3
class LocalPortUnavailableError(exceptions.Error):
  """Raised when the requested local port is not free."""
class UnableToOpenPortError(exceptions.Error):
  """Raised when no listening socket could be opened on the local port."""
def _AddBaseArgs(parser):
parser.add_argument(
'--iap-tunnel-url-override',
hidden=True,
help=('Allows for overriding the connection endpoint for integration '
'testing.'))
parser.add_argument(
'--iap-tunnel-insecure-disable-websocket-cert-check',
default=False,
action='store_true',
hidden=True,
help='Disables checking certificates on the WebSocket connection.')
def AddSshTunnelArgs(parser, tunnel_through_iap_scope):
  """Adds the IAP tunnel flags used by ssh-style commands.

  Args:
    parser: The parser that receives the hidden base tunnel flags.
    tunnel_through_iap_scope: The parser or argument group that receives the
      --tunnel-through-iap flag.
  """
  _AddBaseArgs(parser)
  tunnel_flag_help = """\
Tunnel the ssh connection through Cloud Identity-Aware Proxy for TCP
forwarding.
To learn more, see the
[IAP for TCP forwarding documentation](https://cloud.google.com/iap/docs/tcp-forwarding-overview).
"""
  tunnel_through_iap_scope.add_argument(
      '--tunnel-through-iap', action='store_true', help=tunnel_flag_help)
def AddHostBasedTunnelArgs(parser, support_security_gateway=False):
  """Registers the flags needed to tunnel to a host given by IP or FQDN.

  Args:
    parser: The parser to add the flags to.
    support_security_gateway: bool, when True the Security Gateway flags are
      added in a hidden group that is mutually exclusive with the on-prem
      flags.
  """
  host_group = parser.add_argument_group()
  region_help = ('Configures the region to use when connecting via IP address or '
                 'FQDN.')
  host_group.add_argument(
      '--region', default=None, required=True, help=region_help)

  on_prem_group = host_group
  if support_security_gateway:
    exclusive_group = host_group.add_argument_group(mutex=True)
    AddSecurityGatewayTunnelArgs(
        exclusive_group.add_argument_group(hidden=True))
    on_prem_group = exclusive_group.add_argument_group()
  AddOnPremTunnelArgs(on_prem_group)
def AddOnPremTunnelArgs(parser):
  """Registers the --network and --dest-group flags for IP/FQDN tunnels.

  Args:
    parser: The parser (or argument group) that receives the flags.
  """
  network_help = (
      'Configures the VPC network to use when connecting via IP address or '
      'FQDN.')
  parser.add_argument(
      '--network', default=None, required=True, help=network_help)

  # TODO(b/196572980): Make dest-group required in beta/GA.
  dest_group_help = ('Configures the destination group to use when connecting via IP '
                     'address or FQDN.')
  parser.add_argument(
      '--dest-group', default=None, required=False, help=dest_group_help)
def AddSecurityGatewayTunnelArgs(parser):
  """Registers the --security-gateway and --use-dest-group flags.

  Args:
    parser: The parser (or argument group) that receives the flags.
  """
  parser.add_argument(
      '--security-gateway',
      default=None,
      required=True,
      help='Configure the security gateway resource for connecting.')

  # TODO(b/196572980): Make dest-group required in beta/GA.
  use_dest_group_help = ('Configures the destination group to use when connecting via IP '
                         'address or FQDN.')
  parser.add_argument(
      '--use-dest-group',
      default=False,
      action='store_true',
      required=False,
      help=use_dest_group_help)
def AddProxyServerHelperArgs(parser):
  """Adds the hidden base tunnel flags for the proxy server helper.

  Args:
    parser: The parser that receives the flags.
  """
  _AddBaseArgs(parser)
def CreateSshTunnelArgs(args, track, instance_ref, external_interface):
  """Builds the SshTunnelArgs for an instance, or None when IAP is not used.

  Args:
    args: The parsed commandline arguments. May or may not have had
      AddSshTunnelArgs called.
    track: ReleaseTrack, The currently running release track.
    instance_ref: The target instance reference object.
    external_interface: The external interface of target resource object, if
      available, otherwise None.

  Returns:
    SshTunnelArgs or None if IAP Tunnel is disabled.
  """
  # No IAP flag on this command, or a direct private-IP connection was
  # requested: tunneling does not apply.
  iap_flag_missing = not hasattr(args, 'tunnel_through_iap')
  if iap_flag_missing or getattr(args, 'internal_ip', False):
    return None

  explicitly_set = args.IsSpecified('tunnel_through_iap')
  if explicitly_set and not args.tunnel_through_iap:
    # The user explicitly disabled IAP tunneling.
    return None
  if not explicitly_set:
    # With no explicit choice, tunnel only when there is no external
    # interface to connect to.
    if external_interface:
      return None
    log.status.Print('External IP address was not found; defaulting to using '
                     'IAP tunneling.')

  tunnel_args = SshTunnelArgs()
  tunnel_args.track = track.prefix
  tunnel_args.project = instance_ref.project
  tunnel_args.zone = instance_ref.zone
  tunnel_args.instance = instance_ref.instance
  _AddPassThroughArgs(args, tunnel_args)
  return tunnel_args
def CreateOnPremSshTunnelArgs(args, track, host):
  """Builds the SshTunnelArgs for an on-prem (IP/FQDN) destination.

  Args:
    args: The parsed commandline arguments. May or may not have had
      AddSshTunnelArgs called.
    track: ReleaseTrack, The currently running release track.
    host: The target IP address or FQDN.

  Returns:
    SshTunnelArgs.
  """
  on_prem_args = SshTunnelArgs()
  on_prem_args.track = track.prefix
  # The project comes from the core/project property rather than a resource.
  on_prem_args.project = properties.VALUES.core.project.GetOrFail()
  on_prem_args.region = args.region
  on_prem_args.network = args.network
  # For on-prem targets the "instance" slot carries the host.
  on_prem_args.instance = host
  _AddPassThroughArgs(args, on_prem_args)
  return on_prem_args
def _AddPassThroughArgs(args, ssh_tunnel_args):
"""Adds any passthrough args to the SshTunnelArgs.
Args:
args: The parsed commandline arguments. May or may not have had
AddSshTunnelArgs called.
ssh_tunnel_args: SshTunnelArgs, The SSH tunnel args to update.
"""
if args.IsSpecified('iap_tunnel_url_override'):
ssh_tunnel_args.pass_through_args.append(
'--iap-tunnel-url-override=' + args.iap_tunnel_url_override)
if args.iap_tunnel_insecure_disable_websocket_cert_check:
ssh_tunnel_args.pass_through_args.append(
'--iap-tunnel-insecure-disable-websocket-cert-check')
if args.IsKnownAndSpecified('dest_group'):
ssh_tunnel_args.pass_through_args.append('--dest-group=' + args.dest_group)
class SshTunnelArgs(object):
  """A class to hold some options for IAP Tunnel SSH/SCP.

  Instances compare by value: two instances are equal when all attributes
  listed below are equal. Comparing with an object of another type yields
  "not equal" instead of raising.

  Attributes:
    track: str/None, the prefix of the track for the inner gcloud.
    project: str, the project id (string with dashes).
    zone: str, the zone name.
    instance: str, the instance name (or IP or FQDN for on-prem).
    region: str, the region name (on-prem only).
    network: str, the network name (on-prem only).
    cloud_run_args: dict, The fields required to construct Cloud Run
      SshTunnelArgs. If present, this field should contain fields for
      'deployment_name', 'workload_type', and 'project_number'. Optionally can
      contain 'instance_id' and 'container_id'.
    pass_through_args: [str], additional args to be passed to the inner gcloud.
  """

  # Instances are mutable and define value equality, so they are unhashable.
  # Python 3 already implies this when __eq__ is defined; stated explicitly.
  __hash__ = None

  def __init__(self):
    self.track = None
    self.project = ''
    self.zone = ''
    self.instance = ''
    self.region = ''
    self.network = ''
    self.cloud_run_args = None
    self.pass_through_args = []

  def _Members(self):
    """Returns the tuple of attribute values used for equality and repr."""
    return (
        self.track,
        self.project,
        self.zone,
        self.instance,
        self.region,
        self.network,
        self.pass_through_args,
        self.cloud_run_args,
    )

  def __eq__(self, other):
    # Defer to the other operand for foreign types instead of failing on a
    # missing _Members attribute.
    if not isinstance(other, SshTunnelArgs):
      return NotImplemented
    # pylint: disable=protected-access
    return self._Members() == other._Members()

  def __ne__(self, other):
    return not self == other

  def __repr__(self):
    return 'SshTunnelArgs<%r>' % (self._Members(),)
def DetermineLocalPort(port_arg=0):
  """Returns a free local port, picking one when none is requested.

  Args:
    port_arg: int, the requested port; a falsy value means "pick any unused
      port".

  Returns:
    The port number to use.

  Raises:
    LocalPortUnavailableError: If the chosen port is not free.
  """
  local_port = port_arg if port_arg else portpicker.pick_unused_port()
  if portpicker.is_port_free(local_port):
    return local_port
  raise LocalPortUnavailableError('Local port [%d] is not available.' %
                                  local_port)
def _CloseLocalConnectionCallback(local_conn):
"""Callback function to close the local connection, if any."""
# For test WebSocket connections, there is not a local socket connection.
if local_conn:
try:
# Calling shutdown() first is needed to promptly notify the process on
# the other side of the connection that it is closing. This allows that
# other process, whether over TCP or stdin, to promptly terminate rather
# that waiting for the next time that the process tries to send data.
local_conn.shutdown(socket.SHUT_RDWR)
except EnvironmentError:
pass
try:
local_conn.close()
except EnvironmentError:
pass
def _GetAccessTokenCallback(credentials, lock):
"""Callback function to refresh credentials and return access token."""
if not credentials:
return None
log.debug('credentials type for _GetAccessTokenCallback is [%s].',
six.text_type(credentials))
with lock:
store.RefreshIfAlmostExpire(credentials)
if creds.IsGoogleAuthCredentials(credentials):
return credentials.token
else:
return credentials.access_token
def _SendLocalDataCallback(local_conn, data):
# For test WebSocket connections, there is not a local socket connection.
if local_conn:
local_conn.send(data)
def _OpenLocalTcpSockets(local_host, local_port):
  """Opens listening sockets for every usable address of host and port.

  Args:
    local_host: The local host name or address to listen on.
    local_port: int, the local port to listen on.

  Returns:
    A non-empty list of listening sockets.

  Raises:
    UnableToOpenPortError: If no socket could be opened.
  """
  listeners = []
  candidates = socket.getaddrinfo(
      local_host, local_port, socket.AF_UNSPEC, socket.SOCK_STREAM, 0,
      socket.AI_PASSIVE)
  for family, sock_type, protocol, unused_canonname, address in candidates:
    try:
      listener = socket.socket(family, sock_type, protocol)
    except socket.error:
      continue

    try:
      if not platforms.OperatingSystem.IsWindows():
        # This allows us to restart quickly on the same port. See b/213858080.
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      listener.bind(address)
      # Keep it large enough so it can handle burst of tunnel requests.
      listener.listen(128)
      listeners.append(listener)
    except EnvironmentError:
      # This address is unusable; release the socket and try the next one.
      try:
        listener.close()
      except socket.error:
        pass

  if not listeners:
    raise UnableToOpenPortError('Unable to open socket on port [%d].' %
                                local_port)
  return listeners
class _StdinSocket(object):
  """A wrapper around stdin/out that allows it to be treated like a socket.

  Does not implement all socket functions. And of the ones implemented, not all
  arguments/flags are supported. Once created, stdin should never be accessed by
  anything else.

  On Windows a daemon thread reads stdin and hands the data to recv() through
  a queue. On other platforms stdin is switched to non-blocking mode and
  recv() polls it with select().
  """

  class _StdinSocketMessage(object):
    """A class to wrap messages coming to the stdin socket for windows systems."""

    def __init__(self, message_type, data):
      self._type = message_type
      self._data = data

    def GetData(self):
      return self._data

    def GetType(self):
      return self._type

  class _EOFError(Exception):
    """Raised internally when stdin reaches EOF."""

  # Marker classes used as the "type" of a _StdinSocketMessage; they are
  # compared by identity and never instantiated.
  class _StdinClosedMessageType():
    pass

  class _ExceptionMessageType():
    pass

  class _DataMessageType():
    pass

  def __init__(self):
    self._stdin_closed = False
    # Maximum number of bytes the thread should read.
    self._bufsize = utils.SUBPROTOCOL_MAX_DATA_FRAME_SIZE
    if platforms.OperatingSystem.IsWindows():
      # We will use this thread-safe queue to communicate with the input
      # reading thread.
      self._message_queue = queue.Queue()
      self._reading_thread = threading.Thread(
          target=self._ReadFromStdinAndEnqueueMessageWindows)
      self._reading_thread.daemon = True
      self._reading_thread.start()
    else:
      self._old_flags = fcntl.fcntl(sys.stdin, fcntl.F_GETFL)
      # Set up non-blocking mode to avoid getting stuck on read.
      fcntl.fcntl(sys.stdin, fcntl.F_SETFL, self._old_flags | os.O_NONBLOCK)

  def __del__(self):
    # We need to restore the flags, even when gcloud exits, nonblocking stdin
    # causes weird problems in that terminal, such as cat stops working.
    # This will happen if gcloud is running with stdin mode directly, then
    # killed between the set nonblocking and the restore. The user could fix
    # that by running bash and exiting it, or just closing the terminal.
    # If gcloud is running as an ssh ProxyCommand this problem doesn't happen.
    if platforms.OperatingSystem.IsWindows():
      return
    # __init__ may have failed before the original flags were recorded, in
    # which case there is nothing to restore.
    old_flags = getattr(self, '_old_flags', None)
    if old_flags is not None:
      fcntl.fcntl(sys.stdin, fcntl.F_SETFL, old_flags)

  def send(self, data):  # pylint: disable=invalid-name
    """Writes data to stdout and returns the number of bytes written."""
    files.WriteStreamBytes(sys.stdout, data)
    if not six.PY2:
      # WriteStreamBytes flushes python2 but not python3. Perhaps it should
      # be modified to also flush python3.
      sys.stdout.buffer.flush()
    return len(data)

  def recv(self, bufsize):  # pylint: disable=invalid-name
    """Receives data from stdin.

    Blocks until at least 1 byte is available.
    On Unix (but not Windows) this is unblocked by close() and shutdown(RD).
    On all platforms a signal handler triggering an exception will unblock this.
    This cannot be called by multiple threads at the same time.
    This function performs cleanups before returning, so killing gcloud while
    this is running should be avoided. Specifically RaisesKeyboardInterrupt
    should be in effect so that ctrl-c causes a clean exit with an exception
    instead of triggering gcloud's default os.kill().

    Args:
      bufsize: The maximum number of bytes to receive. Must be positive.

    Returns:
      The bytes received. EOF is indicated by b''.

    Raises:
      IOError: On low level errors.
    """
    if platforms.OperatingSystem.IsWindows():
      return self._RecvWindows(bufsize)
    else:
      return self._RecvUnix(bufsize)

  def close(self):  # pylint: disable=invalid-name
    # Closing stdin doesn't help, because it doesn't unblock read() calls.
    # Also it causes problems, such as segfaulting in python2 and blocking in
    # python3.
    self.shutdown(socket.SHUT_RD)

  def shutdown(self, how):  # pylint: disable=invalid-name
    """Marks the read side as closed; write-only shutdown is a no-op."""
    # Shutting down read only (SHUT_RD)
    if how in (socket.SHUT_RDWR, socket.SHUT_RD):
      self._stdin_closed = True
      # For windows we will unblock the thread early
      if platforms.OperatingSystem.IsWindows():
        # We queue the message so that the recv loop can abort
        msg = self._StdinSocketMessage(self._StdinClosedMessageType, b'')
        self._message_queue.put(msg)

  def _ReadFromStdinAndEnqueueMessageWindows(self):
    """Reads data from stdin on Windows.

    This method will loop until stdin is closed. Should be executed in a
    separate thread to avoid blocking the main thread. Any exception is
    forwarded through the queue so that _RecvWindows can re-raise it.
    """
    try:
      while not self._stdin_closed:
        # STD_INPUT_HANDLE is -10
        h = ctypes.windll.kernel32.GetStdHandle(-10)
        buf = ctypes.create_string_buffer(self._bufsize)
        number_of_bytes_read = wintypes.DWORD()
        ok = ctypes.windll.kernel32.ReadFile(
            h, buf, self._bufsize, ctypes.byref(number_of_bytes_read), None)
        if not ok:
          raise socket.error(errno.EIO, 'stdin ReadFile failed')
        msg = buf.raw[:number_of_bytes_read.value]
        self._message_queue.put(self._StdinSocketMessage(self._DataMessageType,
                                                         msg))
    except Exception:  # pylint: disable=broad-except
      self._message_queue.put(
          self._StdinSocketMessage(self._ExceptionMessageType,
                                   sys.exc_info()))

  def _RecvWindows(self, bufsize):
    """Reads data from stdin on Windows.

    Args:
      bufsize: The maximum number of bytes to receive. Must be positive.

    Returns:
      The bytes received. EOF is indicated by b''.

    Raises:
      socket.error: On low level errors.
    """
    if bufsize != utils.SUBPROTOCOL_MAX_DATA_FRAME_SIZE:
      log.info('bufsize [%s] is not max_data_frame_size', bufsize)
    # We are using 1 second timeout here, which mean it can take up to 1
    # second for gcloud to realize it should exit. Lower timeout means more
    # cpu usage
    while not self._stdin_closed:
      try:
        msg = self._message_queue.get(timeout=1)
      except queue.Empty:
        # Timeout reached
        continue
      msg_type = msg.GetType()
      msg_data = msg.GetData()
      if msg_type is self._ExceptionMessageType:
        # msg_data is the sys.exc_info() triple captured by the reader thread.
        six.reraise(msg_data[0], msg_data[1], msg_data[2])
      if msg_type is self._StdinClosedMessageType:
        self._stdin_closed = True
      return msg_data
    # If stdin was closed we return an empty byte, so we can have a similar
    # behavior as the unix version
    return b''

  def _RecvUnix(self, bufsize):
    """Reads data from stdin on Unix.

    Args:
      bufsize: The maximum number of bytes to receive. Must be positive.

    Returns:
      The bytes received. EOF is indicated by b''. Once EOF has been indicated,
      will always indicate EOF.

    Raises:
      IOError: On low level errors.
    """
    # We don't expect bufsize to be anything other than the max size of our
    # protocol. We will log it only for now to get telemetry if this
    # ever happens.
    if bufsize != utils.SUBPROTOCOL_MAX_DATA_FRAME_SIZE:
      log.info('bufsize [%s] is not max_data_frame_size', bufsize)
    if self._stdin_closed:
      return b''
    try:
      while not self._stdin_closed:
        # We have a timeout here because of b/197960494
        stdin_ready = select.select([sys.stdin], (), (),
                                    READ_FROM_STDIN_TIMEOUT_SECS)
        if not stdin_ready[0]:
          continue
        data = self._ReadUnixNonBlocking(self._bufsize)
        if data:
          return data
        # An empty result from the non-blocking read means "nothing available
        # yet"; real EOF arrives as _EOFError. Returning b'' here would make
        # the caller treat the connection as closed, so poll again instead.
    except _StdinSocket._EOFError:
      self._stdin_closed = True
    return b''

  def _ReadUnixNonBlocking(self, bufsize):
    """Reads from stdin on Unix in a nonblocking manner.

    Args:
      bufsize: The maximum number of bytes to receive. Must be positive.

    Returns:
      The bytes read. b'' means no data is available.

    Raises:
      _StdinSocket._EOFError: to indicate EOF.
      IOError: On low level errors.
    """
    # In python 3, we need to read stdin in a binary way, not a text way to
    # read bytes instead of str. In python 2, binary mode vs text mode only
    # matters on Windows.
    try:
      if six.PY2:
        b = sys.stdin.read(bufsize)
      else:
        b = sys.stdin.buffer.read(bufsize)
    except IOError as e:
      if e.errno == errno.EAGAIN or isinstance(e, io.BlockingIOError):
        # In python2, no nonblocking data available is indicated by raising
        # IOError with EAGAIN.
        # The online python3 documentation says BlockingIOError is raised when
        # no nonblocking data available. We handle that case in case it is ever
        # correct. BlockingIOError is a subclass of OSError which is identical
        # to IOError.
        return b''
      raise
    if b == b'':  # pylint: disable=g-explicit-bool-comparison
      # In python 2 and 3, EOF is indicated by returning b''.
      raise _StdinSocket._EOFError
    if b is None:
      # Regardless of what the online python3 documentation says, it actually
      # returns None to indicate no nonblocking data available.
      b = b''
    return b
class SecurityGatewayTunnelHelper(object):
  """Helper class for starting a Security Gateway tunnel."""

  def __init__(self, args, project, region, security_gateway, host, port,
               use_dest_group=False):
    """Stores the tunnel target and loads the credentials.

    Args:
      args: The parsed commandline arguments; the IAP url-override and
        cert-check flags are read from it.
      project: The project of the security gateway.
      region: The region of the security gateway.
      security_gateway: The security gateway resource to connect through.
      host: The destination host.
      port: The destination port.
      use_dest_group: bool, forwarded to the tunnel target info.
    """
    # Re-use the same args as IAP to prevent adding more flags than necessary.
    self._tunnel_url_override = args.iap_tunnel_url_override
    self._ignore_certs = args.iap_tunnel_insecure_disable_websocket_cert_check
    self._project = project
    self._region = region
    self._security_gateway = security_gateway
    self._host = host
    self._port = port
    self._use_dest_group = use_dest_group
    self._shutdown = False
    self._credential = store.LoadIfEnabled(use_google_auth=True)
    # Passed to the access token callback of every connection.
    self._credential_lock = threading.Lock()

  def _InitiateConnection(self, local_conn,
                          get_access_token_callback, user_agent):
    """Creates a SecurityGatewayTunnel for local_conn and connects it."""
    del user_agent  # Unused.
    sg_tunnel_target = self._GetTargetInfo()
    new_sg_tunnel = sg_tunnel.SecurityGatewayTunnel(
        sg_tunnel_target,
        get_access_token_callback,
        functools.partial(_SendLocalDataCallback, local_conn),
        functools.partial(_CloseLocalConnectionCallback, local_conn),
        self._ignore_certs)
    new_sg_tunnel.InitiateConnection()
    return new_sg_tunnel

  def _GetTargetInfo(self):
    """Builds the SecurityGatewayTargetInfo, including HTTP proxy settings."""
    proxy_info = http_proxy.GetHttpProxyInfo()
    # GetHttpProxyInfo may return a callable that resolves the proxy for a
    # given method.
    if callable(proxy_info):
      proxy_info = proxy_info(method='https')
    return sg_utils.SecurityGatewayTargetInfo(
        project=self._project,
        region=self._region,
        security_gateway=self._security_gateway,
        host=self._host,
        port=self._port,
        url_override=self._tunnel_url_override,
        proxy_info=proxy_info,
        use_dest_group=self._use_dest_group,
    )

  def RunReceiveLocalData(self, local_conn, socket_address, user_agent,
                          conn_id=-1):
    """Receive data from provided local connection and send over HTTP CONNECT.

    Runs until Close() is called, the tunnel reports it should stop, or the
    local connection reaches EOF; then closes both ends.

    Args:
      local_conn: A socket or _StdinSocket representing the local connection.
      socket_address: A verbose loggable string describing where conn is
        connected to.
      user_agent: The user_agent of this connection
      conn_id: The id of the connection.
    """
    del conn_id  # Unused.
    sg_conn = None
    try:
      sg_conn = self._InitiateConnection(
          local_conn,
          functools.partial(
              _GetAccessTokenCallback, self._credential, self._credential_lock
          ),
          user_agent,
      )
      while not (self._shutdown or sg_conn.ShouldStop()):
        data = local_conn.recv(utils.SUBPROTOCOL_MAX_DATA_FRAME_SIZE)
        if not data:
          log.warning('Local connection [%s] has closed.', socket_address)
          break
        sg_conn.Send(data)
    except socket.error as e:
      log.error('Error while transmitting local connection [%s]: %s ',
                socket_address, e)
    finally:
      log.info('Terminating connection from local connection: [%s]',
               socket_address)
      if local_conn:
        # Best effort: the tunnel is handed a callback that shuts down and
        # closes local_conn, so it may already be closed here. A failure
        # must not prevent the tunnel itself from being closed below.
        try:
          local_conn.shutdown(socket.SHUT_RD)
        except EnvironmentError:
          pass
        try:
          local_conn.close()
        except EnvironmentError:
          pass
      if sg_conn:
        sg_conn.Close()
      log.debug('Connection [%s] closed.', socket_address)

  def Close(self):
    """Asks RunReceiveLocalData loops to stop at their next iteration."""
    # This is expected to be called from a separate thread than the one running
    # RunReceiveLocalData.
    self._shutdown = True
class IAPWebsocketTunnelHelper(object):
  """Helper class for starting an IAP WebSocket tunnel."""

  def __init__(self, args, project,
               zone=None, instance=None, interface=None, port=None,
               region=None, network=None, host=None, dest_group=None):
    """Stores the tunnel target description and loads the credentials.

    Args:
      args: The parsed commandline arguments; the IAP url-override and
        cert-check flags are read from it.
      project: The project of the tunnel target.
      zone: The zone, forwarded to the tunnel target info.
      instance: The instance, forwarded to the tunnel target info.
      interface: The interface, forwarded to the tunnel target info.
      port: The port, forwarded to the tunnel target info.
      region: The region, forwarded to the tunnel target info.
      network: The network, forwarded to the tunnel target info.
      host: The host, forwarded to the tunnel target info.
      dest_group: The destination group, forwarded to the tunnel target info.
    """
    self._iap_tunnel_url_override = args.iap_tunnel_url_override
    self._ignore_certs = args.iap_tunnel_insecure_disable_websocket_cert_check

    self._project = project
    self._zone = zone
    self._instance = instance
    self._interface = interface
    self._port = port
    self._region = region
    self._network = network
    self._host = host
    self._dest_group = dest_group

    self._shutdown = False
    self._credential = store.LoadIfEnabled(use_google_auth=True)
    self._credential_lock = threading.Lock()

  def Close(self):
    """Asks RunReceiveLocalData loops to stop at their next iteration."""
    self._shutdown = True

  def _InitiateConnection(self, local_conn, get_access_token_callback,
                          user_agent, conn_id=-1):
    """Creates an IapTunnelWebSocket for local_conn and connects it."""
    send_local = functools.partial(_SendLocalDataCallback, local_conn)
    close_local = functools.partial(_CloseLocalConnectionCallback, local_conn)
    websocket = iap_tunnel_websocket.IapTunnelWebSocket(
        self._GetTunnelTargetInfo(), get_access_token_callback,
        send_local, close_local,
        user_agent, ignore_certs=self._ignore_certs,
        conn_id=conn_id)
    websocket.InitiateConnection()
    return websocket

  def _GetTunnelTargetInfo(self):
    """Builds the IapTunnelTargetInfo, including HTTP proxy settings."""
    proxy = http_proxy.GetHttpProxyInfo()
    # GetHttpProxyInfo may return a callable that resolves the proxy for a
    # given method.
    if callable(proxy):
      proxy = proxy(method='https')
    return utils.IapTunnelTargetInfo(
        project=self._project,
        zone=self._zone,
        instance=self._instance,
        interface=self._interface,
        port=self._port,
        url_override=self._iap_tunnel_url_override,
        proxy_info=proxy,
        region=self._region,
        network=self._network,
        host=self._host,
        dest_group=self._dest_group)

  def RunReceiveLocalData(self, conn, socket_address, user_agent, conn_id=0):
    """Receive data from provided local connection and send over WebSocket.

    Args:
      conn: A socket or _StdinSocket representing the local connection.
      socket_address: A verbose loggable string describing where conn is
        connected to.
      user_agent: The user_agent of this connection
      conn_id: Id of the connection.
    """
    websocket_conn = None
    try:
      token_callback = functools.partial(
          _GetAccessTokenCallback, self._credential, self._credential_lock)
      websocket_conn = self._InitiateConnection(
          conn, token_callback, user_agent, conn_id=conn_id)
      while not self._shutdown:
        data = conn.recv(utils.SUBPROTOCOL_MAX_DATA_FRAME_SIZE)
        if data:
          websocket_conn.Send(data)
          continue
        # When we recv an EOF, we notify the websocket_conn of it, then we
        # wait for all data to send before returning.
        websocket_conn.LocalEOF()
        log.debug('[%d] Received local EOF, closing connection', conn_id)
        if not websocket_conn.WaitForAllSent():
          log.warning('[%d] Failed to send all data from [%s].',
                      conn_id, socket_address)
        break
    except (Exception, exceptions.Error) as e:  # pylint: disable=broad-exception-caught
      log.exception('[%d] Error during local connection to [%s]: %s', conn_id,
                    socket_address, e)
    finally:
      # The log line records whether the helper was shut down or the loop
      # ended for another reason.
      if self._shutdown:
        closing_message = '[%d] Terminating connection to [%s].'
      else:
        closing_message = '[%d] Client closed connection from [%s].'
      log.info(closing_message, conn_id, socket_address)

      # Closing either end is best effort.
      try:
        conn.close()
      except EnvironmentError:
        pass
      try:
        if websocket_conn:
          websocket_conn.Close()
      except (EnvironmentError, exceptions.Error):
        pass
class IapTunnelProxyServerHelper():
  """Proxy server helper listens on a port for new local connections."""

  def __init__(self, local_host, local_port,
               should_test_connection, tunneler):
    """Stores the listen address and the tunneler that serves connections.

    Args:
      local_host: The local host name or address to listen on.
      local_port: int, the local port to listen on.
      should_test_connection: bool, whether Run() first checks that a tunnel
        connection can be made.
      tunneler: The tunnel helper; its RunReceiveLocalData, Close,
        _InitiateConnection and _credential members are used.
    """
    self._tunneler = tunneler
    self._local_host = local_host
    self._local_port = local_port
    self._should_test_connection = should_test_connection
    self._server_sockets = []
    # Entries of (thread, connection) for the accepted client connections.
    self._connections = []
    self._total_connections = 0

  def __del__(self):
    self._CloseServerSockets()

  def Run(self):
    """Start accepting connections."""
    if self._should_test_connection:
      try:
        self._TestConnection()
      except iap_tunnel_websocket.ConnectionCreationError as e:
        raise iap_tunnel_websocket.ConnectionCreationError(
            'While checking if a connection can be made: %s' % six.text_type(e))
    self._server_sockets = _OpenLocalTcpSockets(self._local_host,
                                                self._local_port)
    log.out.Print('Listening on port [%d].' % self._local_port)

    try:
      with execution_utils.RaisesKeyboardInterrupt():
        while True:
          self._connections.append(self._AcceptNewConnection())
          # To fix b/189195317, we will need to erase the reference of dead
          # tasks.
          self._CleanDeadClientConnections()
    except KeyboardInterrupt:
      log.info('Keyboard interrupt received.')
    finally:
      self._CloseServerSockets()

    self._tunneler.Close()
    self._CloseClientConnections()
    log.status.Print('Server shutdown complete.')

  def _TestConnection(self):
    """Test if a connection can be made to the requested endpoint."""
    log.status.Print('Testing if tunnel connection works.')
    user_agent = transport.MakeUserAgentString()
    # pylint: disable=protected-access
    conn = self._tunneler._InitiateConnection(
        None,
        functools.partial(
            _GetAccessTokenCallback,
            self._tunneler._credential,
            threading.Lock(),
        ),
        user_agent,
    )
    # pylint: enable=protected-access
    conn.Close()

  def _AcceptNewConnection(self):
    """Accept a new socket connection and start a new WebSocket tunnel."""
    # Python socket accept() on Windows does not get interrupted by ctrl-c
    # To work around that, use select() with a timeout before the accept()
    # which allows for the ctrl-c to be noticed and abort the process as
    # expected.
    ready_sockets = [()]
    while not ready_sockets[0]:
      # 0.2 second timeout
      ready_sockets = select.select(self._server_sockets, (), (), 0.2)

    ready_read_sockets = ready_sockets[0]
    conn, socket_address = ready_read_sockets[0].accept()
    new_thread = threading.Thread(target=self._HandleNewConnection,
                                  args=(conn, socket_address,
                                        self._total_connections))
    new_thread.daemon = True
    new_thread.start()
    self._total_connections += 1
    return new_thread, conn

  def _CloseServerSockets(self):
    """Closes every listening socket, ignoring OS-level errors."""
    log.debug('Stopping server.')
    for server_socket in self._server_sockets:
      # Guard each close separately so that one failure does not leave the
      # remaining listening sockets open.
      try:
        server_socket.close()
      except EnvironmentError:
        pass

  def _CloseClientConnections(self):
    """Close client connections that seem to still be open."""
    if self._connections:
      close_count = 0
      for client_thread, conn in self._connections:
        if client_thread.is_alive():
          close_count += 1
          try:
            conn.close()
          except EnvironmentError:
            pass
      if close_count:
        log.status.Print('Closed [%d] local connection(s).' % close_count)

  def _CleanDeadClientConnections(self):
    """Erase reference to dead connections so they can be garbage collected."""
    conn_still_alive = []
    if self._connections:
      dead_connections = 0
      for client_thread, conn in self._connections:
        if not client_thread.is_alive():
          dead_connections += 1
          try:
            conn.close()
          except EnvironmentError:
            pass
          del conn
          del client_thread
        else:
          conn_still_alive.append([client_thread, conn])
      if dead_connections:
        log.debug('Cleaned [%d] dead connection(s).', dead_connections)
        self._connections = conn_still_alive
        # We run GC mostly for windows platforms, where it seems GC is not
        # collecting memory quick enough. For linux platforms, this is needed
        # only to immediately clean the memory we freed above.
        gc.collect(2)
      log.debug('connections alive: [%d]', len(self._connections))

  def _HandleNewConnection(self, conn, socket_address, conn_id):
    """Thread entry point: serves one client connection through the tunneler.

    Args:
      conn: The accepted client socket.
      socket_address: The address returned by accept() for conn.
      conn_id: int, the id assigned to this connection.
    """
    try:
      user_agent = transport.MakeUserAgentString()
      self._tunneler.RunReceiveLocalData(conn, repr(socket_address),
                                         user_agent, conn_id=conn_id)
    except EnvironmentError as e:
      log.info('Socket error [%s] while receiving from client.',
               six.text_type(e))
    except:  # pylint: disable=bare-except
      # Top of a worker thread: log anything else and let the thread end.
      log.exception('Error while receiving from client.')
class IapTunnelStdinHelper():
  """Facilitates a connection that gets local data from stdin."""

  def __init__(self, tunneler):
    """Stores the tunneler that carries the stdin data.

    Args:
      tunneler: The tunnel helper whose RunReceiveLocalData is invoked.
    """
    self._tunneler = tunneler

  def Run(self):
    """Executes the tunneling of data."""
    try:
      with execution_utils.RaisesKeyboardInterrupt():
        # Fetching user agent before we start the read loop, because the agent
        # fetch will call sys.stdin.isatty, which is blocking if there is a read
        # waiting for data in the stdin. This only affects MacOs + python 2.7.
        user_agent = transport.MakeUserAgentString()
        stdin_conn = _StdinSocket()
        self._tunneler.RunReceiveLocalData(stdin_conn, 'stdin', user_agent)
    except KeyboardInterrupt:
      log.info('Keyboard interrupt received.')