Skip to content
17 changes: 0 additions & 17 deletions src/databricks/sql/auth/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,6 @@
logger = logging.getLogger(__name__)


def parse_hostname(hostname: str) -> str:
    """
    Normalize the hostname to include scheme and trailing slash.

    The scheme check is case-insensitive, so an input such as
    ``"HTTPS://example.com"`` is recognized as already having a scheme
    (and the scheme is lowercased) instead of being prefixed a second time.

    Args:
        hostname: The hostname to normalize

    Returns:
        Normalized hostname with scheme and trailing slash
    """
    scheme, separator, remainder = hostname.partition("://")
    if separator and scheme.lower() in ("http", "https"):
        # Already has an http(s) scheme: keep it, but in canonical lowercase.
        hostname = f"{scheme.lower()}://{remainder}"
    else:
        # No recognized scheme: default to https.
        hostname = f"https://{hostname}"

    if not hostname.endswith("/"):
        hostname = f"{hostname}/"
    return hostname


def decode_token(access_token: str) -> Optional[Dict]:
"""
Decode a JWT token without verification to extract claims.
Expand Down
6 changes: 3 additions & 3 deletions src/databricks/sql/auth/token_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from databricks.sql.auth.authenticators import AuthProvider
from databricks.sql.auth.auth_utils import (
parse_hostname,
decode_token,
is_same_host,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol
from databricks.sql.common.http import HttpMethod

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
if not http_client:
raise ValueError("http_client is required for TokenFederationProvider")

self.hostname = parse_hostname(hostname)
self.hostname = normalize_host_with_protocol(hostname)
self.external_provider = external_provider
self.http_client = http_client
self.identity_federation_client_id = identity_federation_client_id
Expand Down Expand Up @@ -164,7 +164,7 @@ def _should_exchange_token(self, access_token: str) -> bool:

def _exchange_token(self, access_token: str) -> Token:
"""Exchange the external token for a Databricks token."""
token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}"
token_url = f"{self.hostname}{self.TOKEN_EXCHANGE_ENDPOINT}"

data = {
"grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE,
Expand Down
6 changes: 4 additions & 2 deletions src/databricks/sql/backend/sea/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from databricks.sql.common.http_utils import (
detect_and_parse_proxy,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -66,8 +67,9 @@ def __init__(
self.auth_provider = auth_provider
self.ssl_options = ssl_options

# Build base URL
self.base_url = f"https://{server_hostname}:{self.port}"
# Build base URL using url_utils for consistent normalization
normalized_host = normalize_host_with_protocol(server_hostname)
self.base_url = f"{normalized_host}:{self.port}"

# Parse URL for proxy handling
parsed_url = urllib.parse.urlparse(self.base_url)
Expand Down
4 changes: 3 additions & 1 deletion src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Dict, Optional, List, Any, TYPE_CHECKING

from databricks.sql.common.http import HttpMethod
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -67,7 +68,8 @@ def __init__(

endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
self._feature_flag_endpoint = (
f"https://{self._connection.session.host}{endpoint_suffix}"
normalize_host_with_protocol(self._connection.session.host)
+ endpoint_suffix
)

# Use the provided HTTP client
Expand Down
45 changes: 45 additions & 0 deletions src/databricks/sql/common/url_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
URL utility functions for the Databricks SQL connector.
"""


def normalize_host_with_protocol(host: str) -> str:
    """
    Normalize a connection hostname by ensuring it has a protocol.

    This is useful for handling cases where users may provide hostnames with or without protocols
    (common with dbt-databricks users copying URLs from their browser).

    Args:
        host: Connection hostname which may or may not include a protocol prefix (https:// or http://)
            and may or may not have a trailing slash. Surrounding whitespace is ignored.

    Returns:
        Normalized hostname with a lowercase protocol prefix and no trailing slashes

    Examples:
        normalize_host_with_protocol("myserver.com") -> "https://myserver.com"
        normalize_host_with_protocol("https://myserver.com") -> "https://myserver.com"
        normalize_host_with_protocol("HTTPS://myserver.com/") -> "https://myserver.com"
        normalize_host_with_protocol("http://localhost:8080/") -> "http://localhost:8080"
        normalize_host_with_protocol("  myserver.com/  ") -> "https://myserver.com"

    Raises:
        ValueError: If host is None, empty, or contains nothing besides
            whitespace, slashes and/or a bare protocol prefix
    """
    # NOTE(review): per the PR review discussion, the hostname normalization
    # formerly in auth_utils was consolidated into this helper and the SEA
    # HTTP client was switched to it. thrift_backend reportedly keeps its own
    # inline handling and was left out of scope -- confirm before consolidating.

    # Handle None or empty host
    if not host or not host.strip():
        raise ValueError("Host cannot be None or empty")

    # Drop surrounding whitespace so it cannot end up inside the URL
    host = host.strip()

    # Detect the protocol BEFORE trimming slashes; otherwise a bare "https://"
    # would be trimmed to "https:" and then prefixed a second time.
    scheme, separator, remainder = host.partition("://")
    if separator and scheme.lower() in ("http", "https"):
        # Normalize protocol to lowercase
        scheme = scheme.lower()
    else:
        # No recognized protocol: treat the whole input as the host part
        scheme, remainder = "https", host

    # Remove trailing slashes
    remainder = remainder.rstrip("/")
    if not remainder:
        # Nothing left besides slashes and/or a protocol prefix
        raise ValueError("Host cannot be None or empty")

    return f"{scheme}://{remainder}"
3 changes: 2 additions & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
TelemetryPushClient,
CircuitBreakerTelemetryPushClient,
)
from databricks.sql.common.url_utils import normalize_host_with_protocol

if TYPE_CHECKING:
from databricks.sql.client import Connection
Expand Down Expand Up @@ -278,7 +279,7 @@ def _send_telemetry(self, events):
if self._auth_provider
else self.TELEMETRY_UNAUTHENTICATED_PATH
)
url = f"https://{self._host_url}{path}"
url = normalize_host_with_protocol(self._host_url) + path

headers = {"Accept": "application/json", "Content-Type": "application/json"}

Expand Down
79 changes: 58 additions & 21 deletions tests/e2e/test_circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,34 @@
from databricks.sql.telemetry.circuit_breaker_manager import CircuitBreakerManager


def wait_for_circuit_state(circuit_breaker, expected_states, timeout=5):
    """
    Wait for circuit breaker to reach one of the expected states with polling.

    The state is always checked at least once (even with ``timeout=0``) and
    once more when the deadline is reached, so a transition that happens
    during the final sleep is still observed. A monotonic clock is used so
    wall-clock adjustments cannot shorten or extend the wait.

    Args:
        circuit_breaker: The circuit breaker instance to monitor
            (its ``current_state`` attribute is polled)
        expected_states: List of acceptable states
            (STATE_OPEN, STATE_CLOSED, STATE_HALF_OPEN)
        timeout: Maximum time to wait in seconds

    Returns:
        True if state reached, False if timeout

    Examples:
        # Single state - pass list with one element
        wait_for_circuit_state(cb, [STATE_OPEN])

        # Multiple states
        wait_for_circuit_state(cb, [STATE_CLOSED, STATE_HALF_OPEN])
    """
    poll_interval = 0.1  # Poll every 100ms
    deadline = time.monotonic() + timeout
    while True:
        if circuit_breaker.current_state in expected_states:
            return True
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        # Never sleep past the deadline
        time.sleep(min(poll_interval, remaining))


@pytest.fixture(autouse=True)
def aggressive_circuit_breaker_config():
"""
Expand Down Expand Up @@ -65,12 +93,17 @@ def create_mock_response(self, status_code):
}.get(status_code, b"Response")
return response

@pytest.mark.parametrize("status_code,should_trigger", [
(429, True),
(503, True),
(500, False),
])
def test_circuit_breaker_triggers_for_rate_limit_codes(self, status_code, should_trigger):
@pytest.mark.parametrize(
"status_code,should_trigger",
[
(429, True),
(503, True),
(500, False),
],
)
def test_circuit_breaker_triggers_for_rate_limit_codes(
self, status_code, should_trigger
):
"""
Verify circuit breaker opens for rate-limit codes (429/503) but not others (500).
"""
Expand Down Expand Up @@ -107,9 +140,14 @@ def mock_request(*args, **kwargs):
time.sleep(0.5)

if should_trigger:
# Circuit should be OPEN after 2 rate-limit failures
# Wait for circuit to open (async telemetry may take time)
assert wait_for_circuit_state(
circuit_breaker, [STATE_OPEN], timeout=5
), f"Circuit didn't open within 5s, state: {circuit_breaker.current_state}"

# Circuit should be OPEN after rate-limit failures
assert circuit_breaker.current_state == STATE_OPEN
assert circuit_breaker.fail_counter == 2
assert circuit_breaker.fail_counter >= 2 # At least 2 failures

# Track requests before another query
requests_before = request_count["count"]
Expand Down Expand Up @@ -197,7 +235,10 @@ def mock_conditional_request(*args, **kwargs):
cursor.fetchone()
time.sleep(2)

assert circuit_breaker.current_state == STATE_OPEN
# Wait for circuit to open
assert wait_for_circuit_state(
circuit_breaker, [STATE_OPEN], timeout=5
), f"Circuit didn't open, state: {circuit_breaker.current_state}"

# Wait for reset timeout (5 seconds in test)
time.sleep(6)
Expand All @@ -208,24 +249,20 @@ def mock_conditional_request(*args, **kwargs):
# Execute query to trigger HALF_OPEN state
cursor.execute("SELECT 3")
cursor.fetchone()
time.sleep(1)

# Circuit should be recovering
assert circuit_breaker.current_state in [
STATE_HALF_OPEN,
STATE_CLOSED,
], f"Circuit should be recovering, but is {circuit_breaker.current_state}"
# Wait for circuit to start recovering
assert wait_for_circuit_state(
circuit_breaker, [STATE_HALF_OPEN, STATE_CLOSED], timeout=5
), f"Circuit didn't recover, state: {circuit_breaker.current_state}"

# Execute more queries to fully recover
cursor.execute("SELECT 4")
cursor.fetchone()
time.sleep(1)

current_state = circuit_breaker.current_state
assert current_state in [
STATE_CLOSED,
STATE_HALF_OPEN,
], f"Circuit should recover to CLOSED or HALF_OPEN, got {current_state}"
# Wait for full recovery
assert wait_for_circuit_state(
circuit_breaker, [STATE_CLOSED, STATE_HALF_OPEN], timeout=5
), f"Circuit didn't fully recover, state: {circuit_breaker.current_state}"


if __name__ == "__main__":
Expand Down
36 changes: 32 additions & 4 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,13 +646,31 @@ class TransactionTestSuite(unittest.TestCase):
"access_token": "tok",
}

def _setup_mock_session_with_http_client(self, mock_session):
"""
Helper to configure a mock session with HTTP client mocks.
This prevents feature flag network requests during Connection initialization.
"""
mock_session.host = "foo"

# Mock HTTP client to prevent feature flag network requests
mock_http_client = Mock()
mock_session.http_client = mock_http_client

# Mock feature flag response to prevent blocking HTTP calls
mock_ff_response = Mock()
mock_ff_response.status = 200
mock_ff_response.data = b'{"flags": [], "ttl_seconds": 900}'
mock_http_client.request.return_value = mock_ff_response

def _create_mock_connection(self, mock_session_class):
"""Helper to create a mocked connection for transaction tests."""
# Mock session
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"
mock_session.get_autocommit.return_value = True

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=False to test actual transaction functionality
Expand Down Expand Up @@ -736,9 +754,7 @@ def test_autocommit_setter_preserves_exception_chain(self, mock_session_class):
conn = self._create_mock_connection(mock_session_class)

mock_cursor = Mock()
original_error = DatabaseError(
"Original error", host_url="test-host"
)
original_error = DatabaseError("Original error", host_url="test-host")
mock_cursor.execute.side_effect = original_error

with patch.object(conn, "cursor", return_value=mock_cursor):
Expand Down Expand Up @@ -927,6 +943,8 @@ def test_fetch_autocommit_from_server_queries_server(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -959,6 +977,8 @@ def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -986,6 +1006,8 @@ def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_cla
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

conn = client.Connection(
Expand Down Expand Up @@ -1015,6 +1037,8 @@ def test_commit_is_noop_when_ignore_transactions_true(self, mock_session_class):
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down Expand Up @@ -1043,6 +1067,8 @@ def test_rollback_raises_not_supported_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand All @@ -1068,6 +1094,8 @@ def test_autocommit_setter_is_noop_when_ignore_transactions_true(
mock_session = Mock()
mock_session.is_open = True
mock_session.guid_hex = "test-session-id"

self._setup_mock_session_with_http_client(mock_session)
mock_session_class.return_value = mock_session

# Create connection with ignore_transactions=True (default)
Expand Down
Loading
Loading