Skip to content

Commit 4e12625

Browse files
authored
[EAGLE-6413]: Add instancetype tests (#230)
Need to detect if and when cloud provider instance type names change or are removed.
1 parent c26d9b9 commit 4e12625

File tree

1 file changed

+270
-0
lines changed

1 file changed

+270
-0
lines changed
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
from typing import Dict, List
2+
3+
import pytest
4+
import requests
5+
6+
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
7+
from clarifai_grpc.grpc.api.status import status_code_pb2
8+
from tests.common import get_channel, grpc_channel, metadata
9+
10+
11+
def fetch_csv_from_github(cloud_provider: str, catalog_version: str = "v7") -> str:
    """
    Download one provider's vms.csv from the skypilot-catalog GitHub repository.

    The file is read over HTTP from the raw-content URL of the master branch,
    so no local clone of the repository is needed.

    Args:
        cloud_provider: Cloud provider name (aws, gcp, azure, etc.); it is
            lower-cased before being placed in the URL
        catalog_version: Catalog version directory (default: v7)

    Returns:
        The response body as a string

    Raises:
        requests.RequestException: If the HTTP request fails
    """
    provider_dir = cloud_provider.lower()
    csv_url = (
        "https://raw.githubusercontent.com/skypilot-org/skypilot-catalog/master"
        f"/catalogs/{catalog_version}/{provider_dir}/vms.csv"
    )
    print(f"Fetching CSV from: {csv_url}")

    reply = requests.get(csv_url, timeout=30)
    # Surface HTTP error statuses as exceptions instead of returning their body.
    reply.raise_for_status()
    return reply.text
33+
34+
35+
def get_instance_types_simple(csv_content: str) -> List[str]:
    """
    Parse CSV content and extract the distinct instance types from the first column.

    The first row is treated as the header and skipped. Parsing goes through the
    standard csv module, so CRLF line endings and quoted cells are handled, and
    surrounding whitespace is removed from each value. Blank rows and rows with
    an empty first cell are ignored.

    Args:
        csv_content: Full text of a CSV file whose first column is the instance type

    Returns:
        Sorted list of the unique, non-empty first-column values
    """
    # Function-scope imports keep this helper self-contained (both are stdlib).
    import csv
    import io

    # newline="" leaves line-ending handling to the csv reader, so a CRLF file
    # does not leave a stray "\r" glued to the value of a single-column row.
    rows = csv.reader(io.StringIO(csv_content.strip(), newline=""))
    next(rows, None)  # skip the header row (no-op on empty input)

    instance_types = set()
    for row in rows:
        if not row:
            # The csv reader yields an empty list for a blank line.
            continue
        instance_type = row[0].strip()
        if instance_type:
            instance_types.add(instance_type)

    return sorted(instance_types)
49+
50+
51+
def fetch_skypilot_instance_types(cloud_providers=None):
    """
    Collect instance types for several cloud providers from skypilot-catalog.

    For every provider the vms.csv file is downloaded over HTTP and its
    instance type IDs are merged into one set. A provider whose CSV cannot be
    fetched or processed only produces a printed warning and is skipped.

    Args:
        cloud_providers: Iterable of provider names; None selects aws, gcp and azure

    Returns:
        Set containing the instance type IDs of all providers that could be read

    Raises:
        RuntimeError: If no instance type was collected for any provider
    """
    providers = ['aws', 'gcp', 'azure'] if cloud_providers is None else cloud_providers
    collected = set()

    for entry in providers:
        name = entry.lower()

        try:
            found = get_instance_types_simple(fetch_csv_from_github(name))
            collected.update(found)
            print(f" Found {len(found)} instance types for {name}")
        except requests.RequestException as err:
            print(f"Warning: Could not fetch CSV for {name}: {err}")
        except Exception as err:
            print(f"Warning: Error processing CSV for {name}: {err}")

    if len(collected) == 0:
        raise RuntimeError("No vms.csv files could be fetched for any cloud provider")

    print(
        f"Successfully fetched {len(collected)} total instance types from skypilot-catalog"
    )
    return collected
85+
86+
87+
def fetch_skypilot_instance_types_by_provider(cloud_provider_id):
    """
    Fetch the skypilot-catalog instance types of a single cloud provider.

    The Clarifai cloud provider ID is lower-cased and translated to a
    skypilot-catalog directory name; IDs without an entry in the mapping are
    used as the directory name unchanged.

    Args:
        cloud_provider_id: Clarifai cloud provider ID

    Returns:
        Set of instance type IDs, or an empty set when the CSV could not be
        fetched or processed (a warning is printed in that case)
    """
    try:
        lookup_key = cloud_provider_id.lower()

        # Clarifai cloud provider ID -> skypilot-catalog directory name
        catalog_dirs = dict(
            aws='aws',
            gcp='gcp',
            azure='azure',
            local='local',  # if local is supported
        )
        directory = catalog_dirs.get(lookup_key, lookup_key)

        names = get_instance_types_simple(fetch_csv_from_github(directory))
        print(f" Found {len(names)} instance types for {cloud_provider_id}")
        return set(names)

    except requests.RequestException as err:
        print(f"Warning: Could not fetch CSV for {cloud_provider_id}: {err}")
        return set()
    except Exception as err:
        print(f"Warning: Error processing CSV for {cloud_provider_id}: {err}")
        return set()
114+
115+
116+
def get_cloud_providers(stub, metadata_tuple) -> List[resources_pb2.CloudProvider]:
    """
    Get all available cloud providers.

    A non-success response status, or any exception raised while making the
    call, is reported through pytest.fail.

    Args:
        stub: Service stub on which ListCloudProviders is called
        metadata_tuple: Call metadata passed through to the RPC

    Returns:
        List of the cloud providers contained in the response
    """
    try:
        response = stub.ListCloudProviders(
            service_pb2.ListCloudProvidersRequest(), metadata=metadata_tuple
        )

        if response.status.code != status_code_pb2.StatusCode.SUCCESS:
            pytest.fail(f"Failed to list cloud providers: {response.status.description}")

        # Return a plain list so the value matches the annotated return type
        # and the list(...) convention used by get_cloud_regions.
        return list(response.cloud_providers)
    except Exception as e:
        pytest.fail(f"Error listing cloud providers: {e}")
129+
130+
131+
def get_cloud_regions(
    stub, metadata_tuple, cloud_provider: resources_pb2.CloudProvider
) -> List[str]:
    """
    List the regions the API reports for one cloud provider.

    A non-success response status, or any exception raised while making the
    call, is reported through pytest.fail.

    Args:
        stub: Service stub on which ListCloudRegions is called
        metadata_tuple: Call metadata passed through to the RPC
        cloud_provider: Provider whose id is sent in the request

    Returns:
        List of the regions contained in the response
    """
    try:
        # Only the provider id is forwarded, not the full provider message.
        provider_ref = resources_pb2.CloudProvider(id=cloud_provider.id)
        reply = stub.ListCloudRegions(
            service_pb2.ListCloudRegionsRequest(cloud_provider=provider_ref),
            metadata=metadata_tuple,
        )

        if reply.status.code == status_code_pb2.StatusCode.SUCCESS:
            # API returns regions as a list of region id strings
            return list(reply.regions)

        pytest.fail(
            f"Failed to list regions for {cloud_provider.id}: {reply.status.description}"
        )
    except Exception as err:
        pytest.fail(f"Error listing regions for {cloud_provider.id}: {err}")
150+
151+
152+
def get_instance_types(
    stub, metadata_tuple, cloud_provider: resources_pb2.CloudProvider, region: str
) -> List[str]:
    """
    List the instance type IDs the API reports for a cloud provider and region.

    A non-success response status, or any exception raised while making the
    call, is reported through pytest.fail.

    Args:
        stub: Service stub on which ListInstanceTypes is called
        metadata_tuple: Call metadata passed through to the RPC
        cloud_provider: Provider whose id is sent in the request
        region: Region sent in the request

    Returns:
        List with the id of every instance type in the response
    """
    try:
        # Only the provider id is forwarded, not the full provider message.
        provider_ref = resources_pb2.CloudProvider(id=cloud_provider.id)
        reply = stub.ListInstanceTypes(
            service_pb2.ListInstanceTypesRequest(cloud_provider=provider_ref, region=region),
            metadata=metadata_tuple,
        )

        if reply.status.code == status_code_pb2.StatusCode.SUCCESS:
            return [entry.id for entry in reply.instance_types]

        pytest.fail(
            f"Failed to list instance types for {cloud_provider.id}/{region}: {reply.status.description}"
        )
    except Exception as err:
        pytest.fail(f"Error listing instance types for {cloud_provider.id}/{region}: {err}")
170+
171+
172+
def collect_all_instance_types(stub, metadata_tuple) -> Dict[str, Dict[str, List[str]]]:
    """
    Gather the instance types of every region of every cloud provider.

    Providers are listed first; for each provider its regions are listed, and
    for each region its instance types.

    Returns: {cloud_provider_id: {region: [instance_type_ids]}}
    """
    by_provider = {}

    for provider in get_cloud_providers(stub, metadata_tuple):
        # One inner mapping per provider: region -> instance type ids
        by_provider[provider.id] = {
            region: get_instance_types(stub, metadata_tuple, provider, region)
            for region in get_cloud_regions(stub, metadata_tuple, provider)
        }

    return by_provider
194+
195+
196+
# Lower-case provider IDs that is_provider_supported() reports as unsupported.
UNSUPPORTED_SKYCATALOG_PROVIDERS = {"oracle", "vultr"}


def is_provider_supported(provider_id: str) -> bool:
    """Return True unless the lower-cased provider ID is in the unsupported set."""
    normalized = provider_id.lower()
    is_unsupported = normalized in UNSUPPORTED_SKYCATALOG_PROVIDERS
    return not is_unsupported
201+
202+
203+
@grpc_channel()
def test_instance_types_exist_and_not_deprecated(channel_key):
    """
    Test that all instance types returned by the API exist in skypilot-catalog.

    This test:
    1. Gets all cloud providers
    2. Gets all regions for each cloud provider
    3. Gets all instance types for each region
    4. Compares each supported provider's instance types with the skypilot-catalog
       list of that same provider
    5. Fails, listing per provider the instance types the catalog does not contain

    Providers for which is_provider_supported() is False are skipped.
    """
    stub = service_pb2_grpc.V2Stub(get_channel(channel_key))
    metadata_tuple = metadata(pat=True)

    # Collect all instance types from the API
    api_instance_types = collect_all_instance_types(stub, metadata_tuple)

    supported_providers = [p for p in api_instance_types if is_provider_supported(p)]

    # Merge the per-region lists into one set per supported provider
    provider_instance_types = {
        provider_id: set().union(*api_instance_types[provider_id].values())
        for provider_id in supported_providers
    }
    all_api_instance_types = set().union(*provider_instance_types.values())

    # Get expected instance types for each supported cloud provider
    provider_expected_types = {
        provider_id: fetch_skypilot_instance_types_by_provider(provider_id)
        for provider_id in supported_providers
    }

    # Log the summary before any failure is raised so it is available when debugging
    print("\nInstance Types Summary:")
    print(f"Total API instance types (supported providers): {len(all_api_instance_types)}")
    print(f"Cloud providers checked: {supported_providers}")

    for cloud_provider_id, regions in api_instance_types.items():
        if not is_provider_supported(cloud_provider_id):
            print(f" {cloud_provider_id}: skipped (unsupported by skypilot-catalog)")
            continue
        total_for_provider = sum(len(instance_types) for instance_types in regions.values())
        expected_for_provider = len(provider_expected_types[cloud_provider_id])
        print(
            f" {cloud_provider_id}: {total_for_provider} instance types across {len(regions)} regions (expected: {expected_for_provider})"
        )

    # Compare provider by provider: a name that only exists in another
    # provider's catalog must not hide a type missing from its own catalog.
    missing_details = []
    total_missing = 0
    for provider_id in supported_providers:
        expected_types = provider_expected_types[provider_id]
        missing = provider_instance_types[provider_id] - expected_types
        if not missing:
            continue
        total_missing += len(missing)
        detail = f"{provider_id}: {sorted(missing)}"
        if not expected_types:
            # fetch_skypilot_instance_types_by_provider returns an empty set
            # when the catalog CSV could not be fetched or processed.
            detail += " (catalog for this provider is empty or could not be loaded)"
        missing_details.append(detail)

    if missing_details:
        pytest.fail(
            f"Found {total_missing} instance types in API that are not in skypilot-catalog: "
            f"{'; '.join(missing_details)}"
        )

    # Assert that we have a reasonable number of instance types among supported providers
    assert len(all_api_instance_types) > 0, (
        "No instance types found in API for supported providers"
    )

0 commit comments

Comments
 (0)