diff --git a/packages/google-auth/google/auth/compute_engine/_metadata.py b/packages/google-auth/google/auth/compute_engine/_metadata.py index aae724ab18ee..f7fe18a68660 100644 --- a/packages/google-auth/google/auth/compute_engine/_metadata.py +++ b/packages/google-auth/google/auth/compute_engine/_metadata.py @@ -22,7 +22,7 @@ import json import logging import os -from urllib.parse import urljoin +from urllib.parse import urljoin, urlparse import requests @@ -52,39 +52,70 @@ ) -def _validate_gce_mds_configured_environment(): - """Validates the GCE metadata server environment configuration for mTLS. +def _validate_gce_mds_configured_environment(mode: _mtls.MdsMtlsMode, mds_url: str): + """Validates that the environment is properly configured for GCE MDS if mTLS is enabled. mTLS is only supported when connecting to the default metadata server hosts. If we are in strict mode (which requires mTLS), ensure that the metadata host has not been overridden to a custom value (which means mTLS will fail). + Args: + mode (_mtls.MdsMtlsMode): The mTLS mode configured for the metadata server, parsed from the GCE_METADATA_MTLS_MODE environment variable. + mds_url (str): The metadata server URL to which the request will be made. Raises: google.auth.exceptions.MutualTLSChannelError: if the environment configuration is invalid for mTLS. """ - mode = _mtls._parse_mds_mode() if mode == _mtls.MdsMtlsMode.STRICT: # mTLS is only supported when connecting to the default metadata host. # Raise an exception if we are in strict mode (which requires mTLS) # but the metadata host has been overridden to a custom MDS. (which means mTLS will fail) - if _GCE_METADATA_HOST not in _GCE_DEFAULT_MDS_HOSTS: + parsed = urlparse(mds_url) + if parsed.hostname not in _GCE_DEFAULT_MDS_HOSTS: raise exceptions.MutualTLSChannelError( "Mutual TLS is required, but the metadata host has been overridden. " "mTLS is only supported when connecting to the default metadata host." ) + if parsed.scheme != "https": + raise exceptions.MutualTLSChannelError( + "Mutual TLS is required, but the metadata URL scheme is not HTTPS. " + "mTLS requires HTTPS." + ) -def _get_metadata_root(use_mtls: bool): - """Returns the metadata server root URL.""" +def _get_metadata_root( + mds_mtls_mode: _mtls.MdsMtlsMode, mds_mtls_adapter_mounted: bool +) -> str: + """Returns the metadata server root URL, with the appropriate scheme based on mTLS configuration. + + Args: + mds_mtls_mode (_mtls.MdsMtlsMode): The mTLS mode configured for the metadata server, parsed from the GCE_METADATA_MTLS_MODE environment variable. + mds_mtls_adapter_mounted (bool): Whether the mTLS adapter was successfully mounted to the request's session. + Returns: + str: The metadata server root URL. The URL will use HTTPS if mTLS is enabled or required, and HTTP otherwise. + """ - scheme = "https" if use_mtls else "http" + scheme = "http" + if mds_mtls_adapter_mounted or mds_mtls_mode == _mtls.MdsMtlsMode.STRICT: + scheme = "https" return "{}://{}/computeMetadata/v1/".format(scheme, _GCE_METADATA_HOST) -def _get_metadata_ip_root(use_mtls: bool): - """Returns the metadata server IP root URL.""" - scheme = "https" if use_mtls else "http" +def _get_metadata_ip_root( + mds_mtls_mode: _mtls.MdsMtlsMode, mds_mtls_adapter_mounted: bool +) -> str: + """Returns the metadata server IP root URL, with the appropriate scheme based on mTLS configuration. + + Args: + mds_mtls_mode (_mtls.MdsMtlsMode): The mTLS mode configured for the metadata server, parsed from the GCE_METADATA_MTLS_MODE environment variable. + mds_mtls_adapter_mounted (bool): Whether the mTLS adapter was successfully mounted to the request's session. + Returns: + str: The metadata server IP root URL. The URL will use HTTPS if mTLS is enabled or required, and HTTP otherwise. + """ + + scheme = "http" + if mds_mtls_adapter_mounted or mds_mtls_mode == _mtls.MdsMtlsMode.STRICT: + scheme = "https" return "{}://{}".format( scheme, os.getenv(environment_vars.GCE_METADATA_IP, _GCE_DEFAULT_MDS_IP) ) @@ -159,30 +190,38 @@ def detect_gce_residency_linux(): return content.startswith(_GOOGLE) -def _prepare_request_for_mds(request, use_mtls=False) -> None: - """Prepares a request for the metadata server. - - This will check if mTLS should be used and mount the mTLS adapter if needed. +def _try_mount_mds_mtls_adapter(request, mode: _mtls.MdsMtlsMode) -> bool: + """Tries to mount the mTLS adapter to the request's session if mTLS is enabled and certificates are present. Args: request (google.auth.transport.Request): A callable used to make HTTP requests. If mTLS is enabled, and the request supports sessions, the request will have the mTLS adapter mounted. Otherwise, there will be no change. - use_mtls (bool): Whether to use mTLS for the request. + mode (_mtls.MdsMtlsMode): The mTLS mode configured for the metadata server, parsed from the GCE_METADATA_MTLS_MODE environment variable. + Returns: + bool: True if the mTLS adapter was mounted, False otherwise. + """ + mds_mtls_config = _mtls.MdsMtlsConfig() + should_mount_adapter = _mtls.should_use_mds_mtls( + mode, mds_mtls_config=mds_mtls_config + ) - """ # Only modify the request if mTLS is enabled, and request supports sessions. - if use_mtls and hasattr(request, "session"): + mds_mtls_adapter_mounted = False + if should_mount_adapter and hasattr(request, "session"): # Ensure the request has a session to mount the adapter to. if not request.session: request.session = requests.Session() - adapter = _mtls.MdsMtlsAdapter() + adapter = _mtls.MdsMtlsAdapter(mds_mtls_config=mds_mtls_config) # Mount the adapter for all default GCE metadata hosts. for host in _GCE_DEFAULT_MDS_HOSTS: request.session.mount(f"https://{host}/", adapter) + mds_mtls_adapter_mounted = True + + return mds_mtls_adapter_mounted def ping( @@ -200,8 +239,12 @@ def ping( Returns: bool: True if the metadata server is reachable, False otherwise. """ - use_mtls = _mtls.should_use_mds_mtls() - _prepare_request_for_mds(request, use_mtls=use_mtls) + mds_mtls_mode = _mtls._parse_mds_mode() + mds_mtls_adapter_mounted = _try_mount_mds_mtls_adapter(request, mds_mtls_mode) + + metadata_ip_root = _get_metadata_ip_root(mds_mtls_mode, mds_mtls_adapter_mounted) + _validate_gce_mds_configured_environment(mds_mtls_mode, metadata_ip_root) + # NOTE: The explicit ``timeout`` is a workaround. The underlying # issue is that resolving an unknown host on some networks will take # 20-30 seconds; making this timeout short fixes the issue, but @@ -216,7 +259,7 @@ def ping( for attempt in backoff: try: response = request( - url=_get_metadata_ip_root(use_mtls), + url=metadata_ip_root, method="GET", headers=headers, timeout=timeout, @@ -285,18 +328,16 @@ def get( has been overridden in strict mTLS mode). """ - use_mtls = _mtls.should_use_mds_mtls() - # Prepare the request object for mTLS if needed. - # This will create a new request object with the mTLS session. - _prepare_request_for_mds(request, use_mtls=use_mtls) + mds_mtls_mode = _mtls._parse_mds_mode() + mds_mtls_adapter_mounted = _try_mount_mds_mtls_adapter(request, mds_mtls_mode) if root is None: - root = _get_metadata_root(use_mtls) + root = _get_metadata_root(mds_mtls_mode, mds_mtls_adapter_mounted) # mTLS is only supported when connecting to the default metadata host. # If we are in strict mode (which requires mTLS), ensure that the metadata host # has not been overridden to a non-default host value (which means mTLS will fail). - _validate_gce_mds_configured_environment() + _validate_gce_mds_configured_environment(mds_mtls_mode, root) base_url = urljoin(root, path) query_params = {} if params is None else params diff --git a/packages/google-auth/google/auth/compute_engine/_mtls.py b/packages/google-auth/google/auth/compute_engine/_mtls.py index 6525dd03e1bd..21e740f2b163 100644 --- a/packages/google-auth/google/auth/compute_engine/_mtls.py +++ b/packages/google-auth/google/auth/compute_engine/_mtls.py @@ -65,8 +65,16 @@ class MdsMtlsConfig: ) # path to file containing client certificate and key -def _certs_exist(mds_mtls_config: MdsMtlsConfig): - """Checks if the mTLS certificates exist.""" +def mds_mtls_certificates_exist(mds_mtls_config: MdsMtlsConfig): + """Checks if the mTLS certificates exist. + + Args: + mds_mtls_config (MdsMtlsConfig): The mTLS configuration containing the + paths to the CA and client certificates. + + Returns: + bool: True if both certificates exist, False otherwise. + """ return os.path.exists(mds_mtls_config.ca_cert_path) and os.path.exists( mds_mtls_config.client_combined_cert_path ) @@ -98,19 +106,33 @@ def _parse_mds_mode(): ) -def should_use_mds_mtls(mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig()): - """Determines if mTLS should be used for the metadata server.""" - mode = _parse_mds_mode() +def should_use_mds_mtls( + mode: MdsMtlsMode, mds_mtls_config: MdsMtlsConfig = MdsMtlsConfig() +) -> bool: + """Determines if mTLS should be used for the metadata server. + + Args: + mode (MdsMtlsMode): The mTLS mode configured for the metadata server, + parsed from the GCE_METADATA_MTLS_MODE environment variable. + mds_mtls_config (MdsMtlsConfig): The mTLS configuration containing the + paths to the CA and client certificates. + + Returns: + bool: True if mTLS should be used, False otherwise. + + Raises: + google.auth.exceptions.MutualTLSChannelError: if mTLS is required (STRICT mode) + but certificates are missing. + """ if mode == MdsMtlsMode.STRICT: - if not _certs_exist(mds_mtls_config): + if not mds_mtls_certificates_exist(mds_mtls_config): raise exceptions.MutualTLSChannelError( "mTLS certificates not found in strict mode." ) return True - elif mode == MdsMtlsMode.NONE: + if mode == MdsMtlsMode.NONE: return False - else: # Default mode - return _certs_exist(mds_mtls_config) + return mds_mtls_certificates_exist(mds_mtls_config) class MdsMtlsAdapter(HTTPAdapter): diff --git a/packages/google-auth/tests/compute_engine/test__metadata.py b/packages/google-auth/tests/compute_engine/test__metadata.py index 35996ab24b92..eec7d433504c 100644 --- a/packages/google-auth/tests/compute_engine/test__metadata.py +++ b/packages/google-auth/tests/compute_engine/test__metadata.py @@ -638,12 +638,18 @@ def test_get_universe_domain_other_error(): ) +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate", + return_value=None, +) @mock.patch( "google.auth.metrics.token_request_access_token_mds", return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, ) @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) -def test_get_service_account_token(utcnow, mock_metrics_header_value): +def test_get_service_account_token( + utcnow, mock_metrics_header_value, mock_get_agent_cert +): ttl = 500 request = make_request( json.dumps({"access_token": "token", "expires_in": ttl}), @@ -665,12 +671,18 @@ def test_get_service_account_token(utcnow, mock_metrics_header_value): assert expiry == utcnow() + datetime.timedelta(seconds=ttl) +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate", + return_value=None, +) @mock.patch( "google.auth.metrics.token_request_access_token_mds", return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, ) @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) -def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_value): +def test_get_service_account_token_with_scopes_list( + utcnow, mock_metrics_header_value, mock_get_agent_cert +): ttl = 500 request = make_request( json.dumps({"access_token": "token", "expires_in": ttl}), @@ -695,13 +707,17 @@ def test_get_service_account_token_with_scopes_list(utcnow, mock_metrics_header_ assert expiry == utcnow() + datetime.timedelta(seconds=ttl) +@mock.patch( + "google.auth._agent_identity_utils.get_and_parse_agent_identity_certificate", + return_value=None, +) @mock.patch( "google.auth.metrics.token_request_access_token_mds", return_value=ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE, ) @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test_get_service_account_token_with_scopes_string( - utcnow, mock_metrics_header_value + utcnow, mock_metrics_header_value, mock_get_agent_cert ): ttl = 500 request = make_request( @@ -826,46 +842,70 @@ def test_get_service_account_info(): def test__get_metadata_root_mtls(): assert ( - _metadata._get_metadata_root(use_mtls=True) + _metadata._get_metadata_root( + _metadata._mtls.MdsMtlsMode.STRICT, mds_mtls_adapter_mounted=False + ) == "https://metadata.google.internal/computeMetadata/v1/" ) def test__get_metadata_root_no_mtls(): assert ( - _metadata._get_metadata_root(use_mtls=False) + _metadata._get_metadata_root( + _metadata._mtls.MdsMtlsMode.NONE, mds_mtls_adapter_mounted=False + ) == "http://metadata.google.internal/computeMetadata/v1/" ) def test__get_metadata_ip_root_mtls(): - assert _metadata._get_metadata_ip_root(use_mtls=True) == "https://169.254.169.254" + assert ( + _metadata._get_metadata_ip_root( + _metadata._mtls.MdsMtlsMode.STRICT, mds_mtls_adapter_mounted=False + ) + == "https://169.254.169.254" + ) def test__get_metadata_ip_root_no_mtls(): - assert _metadata._get_metadata_ip_root(use_mtls=False) == "http://169.254.169.254" + assert ( + _metadata._get_metadata_ip_root( + _metadata._mtls.MdsMtlsMode.NONE, mds_mtls_adapter_mounted=False + ) + == "http://169.254.169.254" + ) +@mock.patch("google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls(mock_mds_mtls_adapter, mock_certs_exist): request = google_auth_requests.Request(mock.create_autospec(requests.Session)) - _metadata._prepare_request_for_mds(request, use_mtls=True) + assert _metadata._try_mount_mds_mtls_adapter( + request, mode=_metadata._mtls.MdsMtlsMode.STRICT + ) mock_mds_mtls_adapter.assert_called_once() assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) -def test__prepare_request_for_mds_no_mtls(): +def test__try_mount_mds_mtls_adapter_no_mtls(): request = mock.Mock() - _metadata._prepare_request_for_mds(request, use_mtls=False) + assert not _metadata._try_mount_mds_mtls_adapter( + request, mode=_metadata._mtls.MdsMtlsMode.NONE + ) request.session.mount.assert_not_called() @mock.patch("google.auth.metrics.mds_ping", return_value=MDS_PING_METRICS_HEADER_VALUE) -@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch( + "google.auth.compute_engine._metadata._try_mount_mds_mtls_adapter", return_value=True +) +@mock.patch( + "google.auth.compute_engine._mtls._parse_mds_mode", + return_value=_metadata._mtls.MdsMtlsMode.STRICT, +) @mock.patch("google.auth.transport.requests.Request") def test_ping_mtls( - mock_request, mock_should_use_mtls, mock_mds_mtls_adapter, mock_metrics_header_value + mock_request, mock_parse_mds_mode, mock_try_mount_adapter, mock_metrics_header_value ): response = mock.create_autospec(transport.Response, instance=True) response.status = http_client.OK @@ -874,8 +914,8 @@ def test_ping_mtls( assert _metadata.ping(mock_request) - mock_should_use_mtls.assert_called_once() - mock_mds_mtls_adapter.assert_called_once() + mock_parse_mds_mode.assert_called_once() + mock_try_mount_adapter.assert_called_once() mock_request.assert_called_once_with( url="https://169.254.169.254", method="GET", @@ -884,10 +924,15 @@ def test_ping_mtls( ) -@mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -@mock.patch("google.auth.compute_engine._mtls.should_use_mds_mtls", return_value=True) +@mock.patch( + "google.auth.compute_engine._metadata._try_mount_mds_mtls_adapter", return_value=True +) +@mock.patch( + "google.auth.compute_engine._mtls._parse_mds_mode", + return_value=_metadata._mtls.MdsMtlsMode.STRICT, +) @mock.patch("google.auth.transport.requests.Request") -def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): +def test_get_mtls(mock_request, mock_parse_mds_mode, mock_try_mount_adapter): response = mock.create_autospec(transport.Response, instance=True) response.status = http_client.OK response.data = _helpers.to_bytes("{}") @@ -896,8 +941,8 @@ def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): _metadata.get(mock_request, "some/path") - mock_should_use_mtls.assert_called_once() - mock_mds_mtls_adapter.assert_called_once() + mock_parse_mds_mode.assert_called_once() + mock_try_mount_adapter.assert_called_once() mock_request.assert_called_once_with( url="https://metadata.google.internal/computeMetadata/v1/some/path", method="GET", @@ -907,58 +952,82 @@ def test_get_mtls(mock_request, mock_should_use_mtls, mock_mds_mtls_adapter): @pytest.mark.parametrize( - "mds_mode, metadata_host, expect_exception", + "mds_mode, metadata_url, expect_exception", [ - (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_HOST, False), - (_metadata._mtls.MdsMtlsMode.STRICT, _metadata._GCE_DEFAULT_MDS_IP, False), - (_metadata._mtls.MdsMtlsMode.STRICT, "custom.host", True), - (_metadata._mtls.MdsMtlsMode.NONE, "custom.host", False), - (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_HOST, False), - (_metadata._mtls.MdsMtlsMode.DEFAULT, _metadata._GCE_DEFAULT_MDS_IP, False), + ( + _metadata._mtls.MdsMtlsMode.STRICT, + "https://" + _metadata._GCE_DEFAULT_HOST, + False, + ), + ( + _metadata._mtls.MdsMtlsMode.STRICT, + "https://" + _metadata._GCE_DEFAULT_MDS_IP, + False, + ), + (_metadata._mtls.MdsMtlsMode.STRICT, "https://custom.host", True), + (_metadata._mtls.MdsMtlsMode.STRICT, "http://metadata.google.internal", True), + (_metadata._mtls.MdsMtlsMode.NONE, "https://custom.host", False), + ( + _metadata._mtls.MdsMtlsMode.DEFAULT, + "https://" + _metadata._GCE_DEFAULT_HOST, + False, + ), + ( + _metadata._mtls.MdsMtlsMode.DEFAULT, + "https://" + _metadata._GCE_DEFAULT_MDS_IP, + False, + ), ], ) -@mock.patch("google.auth.compute_engine._mtls._parse_mds_mode") def test_validate_gce_mds_configured_environment( - mock_parse_mds_mode, mds_mode, metadata_host, expect_exception + mds_mode, metadata_url, expect_exception ): - mock_parse_mds_mode.return_value = mds_mode - with mock.patch( - "google.auth.compute_engine._metadata._GCE_METADATA_HOST", new=metadata_host - ): - if expect_exception: - with pytest.raises(exceptions.MutualTLSChannelError): - _metadata._validate_gce_mds_configured_environment() - else: - _metadata._validate_gce_mds_configured_environment() - mock_parse_mds_mode.assert_called_once() + if expect_exception: + with pytest.raises(exceptions.MutualTLSChannelError): + _metadata._validate_gce_mds_configured_environment(mds_mode, metadata_url) + else: + _metadata._validate_gce_mds_configured_environment(mds_mode, metadata_url) +@mock.patch("google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls_session_exists(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls_session_exists( + mock_mds_mtls_adapter, mock_certs_exist +): mock_session = mock.create_autospec(requests.Session) request = google_auth_requests.Request(mock_session) - _metadata._prepare_request_for_mds(request, use_mtls=True) + assert _metadata._try_mount_mds_mtls_adapter( + request, mode=_metadata._mtls.MdsMtlsMode.STRICT + ) mock_mds_mtls_adapter.assert_called_once() assert mock_session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) +@mock.patch("google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls_no_session(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls_no_session( + mock_mds_mtls_adapter, mock_certs_exist +): request = google_auth_requests.Request(None) # Explicitly set session to None to avoid a session being created in the Request constructor. request.session = None with mock.patch("requests.Session") as mock_session_class: - _metadata._prepare_request_for_mds(request, use_mtls=True) + assert _metadata._try_mount_mds_mtls_adapter( + request, mode=_metadata._mtls.MdsMtlsMode.STRICT + ) mock_session_class.assert_called_once() mock_mds_mtls_adapter.assert_called_once() assert request.session.mount.call_count == len(_metadata._GCE_DEFAULT_MDS_HOSTS) +@mock.patch("google.auth.compute_engine._mtls.mds_mtls_certificates_exist", return_value=True) @mock.patch("google.auth.compute_engine._mtls.MdsMtlsAdapter") -def test__prepare_request_for_mds_mtls_http_request(mock_mds_mtls_adapter): +def test__try_mount_mds_mtls_adapter_mtls_http_request( + mock_mds_mtls_adapter, mock_certs_exist +): """ http requests should be ignored. Regression test for https://github.com/googleapis/google-cloud-python/issues/16035 @@ -966,6 +1035,8 @@ def test__prepare_request_for_mds_mtls_http_request(mock_mds_mtls_adapter): from google.auth.transport import _http_client request = _http_client.Request() - _metadata._prepare_request_for_mds(request, use_mtls=True) + assert not _metadata._try_mount_mds_mtls_adapter( + request, mode=_metadata._mtls.MdsMtlsMode.STRICT + ) assert mock_mds_mtls_adapter.call_count == 0 diff --git a/packages/google-auth/tests/compute_engine/test__mtls.py b/packages/google-auth/tests/compute_engine/test__mtls.py index 6b40b6682869..dbf51996716c 100644 --- a/packages/google-auth/tests/compute_engine/test__mtls.py +++ b/packages/google-auth/tests/compute_engine/test__mtls.py @@ -79,40 +79,39 @@ def test__parse_mds_mode_invalid(monkeypatch): @mock.patch("os.path.exists") -def test__certs_exist_true(mock_exists, mock_mds_mtls_config): +def test_mds_mtls_certificates_exist_true(mock_exists, mock_mds_mtls_config): mock_exists.return_value = True - assert _mtls._certs_exist(mock_mds_mtls_config) is True + assert _mtls.mds_mtls_certificates_exist(mock_mds_mtls_config) is True @mock.patch("os.path.exists") -def test__certs_exist_false(mock_exists, mock_mds_mtls_config): +def test_mds_mtls_certificates_exist_false(mock_exists, mock_mds_mtls_config): mock_exists.return_value = False - assert _mtls._certs_exist(mock_mds_mtls_config) is False + assert _mtls.mds_mtls_certificates_exist(mock_mds_mtls_config) is False @pytest.mark.parametrize( "mtls_mode, certs_exist, expected_result", [ - ("strict", True, True), - ("strict", False, exceptions.MutualTLSChannelError), - ("none", True, False), - ("none", False, False), - ("default", True, True), - ("default", False, False), + (_mtls.MdsMtlsMode.STRICT, True, True), + (_mtls.MdsMtlsMode.STRICT, False, exceptions.MutualTLSChannelError), + (_mtls.MdsMtlsMode.NONE, True, False), + (_mtls.MdsMtlsMode.NONE, False, False), + (_mtls.MdsMtlsMode.DEFAULT, True, True), + (_mtls.MdsMtlsMode.DEFAULT, False, False), ], ) @mock.patch("os.path.exists") def test_should_use_mds_mtls( - mock_exists, monkeypatch, mtls_mode, certs_exist, expected_result + mock_exists, mtls_mode, certs_exist, expected_result, mock_mds_mtls_config ): - monkeypatch.setenv(environment_vars.GCE_METADATA_MTLS_MODE, mtls_mode) mock_exists.return_value = certs_exist if isinstance(expected_result, type) and issubclass(expected_result, Exception): with pytest.raises(expected_result): - _mtls.should_use_mds_mtls() + _mtls.should_use_mds_mtls(mtls_mode, mock_mds_mtls_config) else: - assert _mtls.should_use_mds_mtls() is expected_result + assert _mtls.should_use_mds_mtls(mtls_mode, mock_mds_mtls_config) is expected_result @mock.patch("ssl.create_default_context")