Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ cvxpy = { version = ">=1.1.0", optional = true }
graphviz = { version = ">=0.15", optional = true }
matplotlib = { version = ">=3.0.0", optional = true }
numpy = { version = ">=1.19.0", optional = true }
networkx = { version = ">=2.5", optional = true }

# gateway dependencies
flask = { version = "^2.1.2", optional = true }
Expand All @@ -70,7 +71,7 @@ gcp = ["google-api-python-client", "google-auth", "google-cloud-compute", "googl
ibm = ["ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
all = ["boto3", "azure-identity", "azure-mgmt-authorization", "azure-mgmt-compute", "azure-mgmt-network", "azure-mgmt-resource", "azure-mgmt-storage", "azure-mgmt-subscription", "azure-storage-blob", "google-api-python-client", "google-auth", "google-cloud-compute", "google-cloud-storage", "ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
gateway = ["flask", "lz4", "pynacl", "pyopenssl", "werkzeug"]
solver = ["cvxpy", "graphviz", "matplotlib", "numpy"]
solver = ["networkx", "cvxpy", "graphviz", "matplotlib", "numpy"]

[tool.poetry.dev-dependencies]
pytest = ">=6.0.0"
Expand Down
1 change: 1 addition & 0 deletions scripts/requirements-gateway.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ numpy
pandas
pyarrow
typer
networkx
10 changes: 6 additions & 4 deletions skyplane/api/dataplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import nacl.secret
import nacl.utils
import typer
import urllib3
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional
Expand Down Expand Up @@ -89,7 +90,6 @@ def _start_gateway(
gateway_server: compute.Server,
gateway_log_dir: Optional[PathLike],
authorize_ssh_pub_key: Optional[str] = None,
e2ee_key_bytes: Optional[str] = None,
):
# map outgoing ports
setup_args = {}
Expand Down Expand Up @@ -119,9 +119,7 @@ def _start_gateway(
gateway_docker_image=gateway_docker_image,
gateway_program_path=str(gateway_program_filename),
gateway_info_path=f"{gateway_log_dir}/gateway_info.json",
e2ee_key_bytes=e2ee_key_bytes, # TODO: remove
use_bbr=self.transfer_config.use_bbr, # TODO: remove
use_compression=self.transfer_config.use_compression,
use_socket_tls=self.transfer_config.use_socket_tls,
)

Expand Down Expand Up @@ -202,6 +200,10 @@ def provision(
# todo: move server.py:start_gateway here
logger.fs.info(f"Using docker image {gateway_docker_image}")
e2ee_key_bytes = nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE)
# save E2EE keys
e2ee_key_file = "e2ee_key"
with open(f"/tmp/{e2ee_key_file}", 'wb') as f:
f.write(e2ee_key_bytes)

# create gateway logging dir
gateway_program_dir = f"{self.log_dir}/programs"
Expand All @@ -218,7 +220,7 @@ def provision(
jobs = []
for node, server in gateway_bound_nodes.items():
jobs.append(
partial(self._start_gateway, gateway_docker_image, node, server, gateway_program_dir, authorize_ssh_pub_key, e2ee_key_bytes)
partial(self._start_gateway, gateway_docker_image, node, server, gateway_program_dir, authorize_ssh_pub_key)
)
logger.fs.debug(f"[Dataplane.provision] Starting gateways on {len(jobs)} servers")
try:
Expand Down
28 changes: 24 additions & 4 deletions skyplane/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.config import TransferConfig

from skyplane.planner.planner import MulticastDirectPlanner, DirectPlannerSourceOneSided, DirectPlannerDestOneSided
from skyplane.planner.planner import (
MulticastDirectPlanner,
DirectPlannerSourceOneSided,
DirectPlannerDestOneSided,
UnicastDirectPlanner,
UnicastILPPlanner,
MulticastILPPlanner,
MulticastMDSTPlanner,
)
from skyplane.planner.topology import TopologyPlanGateway
from skyplane.utils import logger
from skyplane.utils.definitions import tmp_log_dir
Expand Down Expand Up @@ -61,12 +69,23 @@ def __init__(

# planner
self.planning_algorithm = planning_algorithm

if self.planning_algorithm == "direct":
self.planner = MulticastDirectPlanner(self.max_instances, self.n_connections, self.transfer_config)
self.planner = MulticastDirectPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "src_one_sided":
self.planner = DirectPlannerSourceOneSided(self.max_instances, self.n_connections, self.transfer_config)
self.planner = DirectPlannerSourceOneSided(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "dst_one_sided":
self.planner = DirectPlannerDestOneSided(self.max_instances, self.n_connections, self.transfer_config)
self.planner = DirectPlannerDestOneSided(self.transfer_config, self.max_instances, self.n_connections)
# TODO: find a way to merge the direct / Ndirect planners
self.planner = UnicastDirectPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "multi_direct":
self.planner = MulticastDirectPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "multi_dst":
self.planner = MulticastMDSTPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "multi_ilp":
self.planning_algorithm = MulticastILPPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "uni_ilp":
self.planning_algorithm = UnicastILPPlanner(self.transfer_config, self.max_instances, self.n_connections)
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

Expand Down Expand Up @@ -185,3 +204,4 @@ def estimate_total_cost(self):

# return size
return total_size * topo.cost_per_gb

13 changes: 9 additions & 4 deletions skyplane/api/transfer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def _run_multipart_chunk_thread(
src_object = transfer_pair.src_obj
dest_objects = transfer_pair.dst_objs
dest_key = transfer_pair.dst_key
print("dest_key: ", dest_key)
if isinstance(self.src_iface, ObjectStoreInterface):
mime_type = self.src_iface.get_obj_mime_type(src_object.key)
# create multipart upload request per destination
Expand Down Expand Up @@ -283,10 +284,10 @@ def transfer_pair_generator(
dest_provider, dest_region = dst_iface.region_tag().split(":")
try:
dest_key = self.map_object_key_prefix(src_prefix, obj.key, dst_prefix, recursive=recursive)
assert (
dest_key[: len(dst_prefix)] == dst_prefix
), f"Destination key {dest_key} does not start with destination prefix {dst_prefix}"
dest_keys.append(dest_key[len(dst_prefix) :])
# TODO: why is it changed here?
# dest_keys.append(dest_key[len(dst_prefix) :])

dest_keys.append(dest_key)
except exceptions.MissingObjectException as e:
logger.fs.exception(e)
raise e from None
Expand Down Expand Up @@ -508,8 +509,12 @@ def dst_prefixes(self) -> List[str]:
if not hasattr(self, "_dst_prefix"):
if self.transfer_type == "unicast":
self._dst_prefix = [str(parse_path(self.dst_paths[0])[2])]
print("return dst_prefixes for unicast", self._dst_prefix)
else:
for path in self.dst_paths:
print("Parsing result for multicast", parse_path(path))
self._dst_prefix = [str(parse_path(path)[2]) for path in self.dst_paths]
print("return dst_prefixes for multicast", self._dst_prefix)
return self._dst_prefix

@property
Expand Down
2 changes: 1 addition & 1 deletion skyplane/api/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import requests
from rich import print as rprint
from typing import Optional, Dict, List
from typing import Optional, Dict

import skyplane
from skyplane.utils.definitions import tmp_log_dir
Expand Down
11 changes: 3 additions & 8 deletions skyplane/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ class Chunk:
part_number: Optional[int] = None
upload_id: Optional[str] = None # TODO: for broadcast, this is not used

def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, raw_wire_length: int, is_compressed: bool = False):
def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, raw_wire_length: int):
return WireProtocolHeader(
chunk_id=self.chunk_id,
data_len=wire_length,
raw_data_len=raw_wire_length,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

Expand Down Expand Up @@ -99,7 +98,6 @@ class WireProtocolHeader:
chunk_id: str # 128bit UUID
data_len: int # long
raw_data_len: int # long (uncompressed, unencrypted)
is_compressed: bool # char
n_chunks_left_on_socket: int # long

@staticmethod
Expand All @@ -115,8 +113,8 @@ def protocol_version():

@staticmethod
def length_bytes():
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + raw_data_len(8) + is_compressed (1) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 8 + 1 + 8
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + raw_data_len(8) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 8 + 8

@staticmethod
def from_bytes(data: bytes):
Expand All @@ -130,13 +128,11 @@ def from_bytes(data: bytes):
chunk_id = data[12:28].hex()
chunk_len = int.from_bytes(data[28:36], byteorder="big")
raw_chunk_len = int.from_bytes(data[36:44], byteorder="big")
is_compressed = bool(int.from_bytes(data[44:45], byteorder="big"))
n_chunks_left_on_socket = int.from_bytes(data[45:53], byteorder="big")
return WireProtocolHeader(
chunk_id=chunk_id,
data_len=chunk_len,
raw_data_len=raw_chunk_len,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

Expand All @@ -149,7 +145,6 @@ def to_bytes(self):
out_bytes += chunk_id_bytes
out_bytes += self.data_len.to_bytes(8, byteorder="big")
out_bytes += self.raw_data_len.to_bytes(8, byteorder="big")
out_bytes += self.is_compressed.to_bytes(1, byteorder="big")
out_bytes += self.n_chunks_left_on_socket.to_bytes(8, byteorder="big")
assert len(out_bytes) == WireProtocolHeader.length_bytes(), f"{len(out_bytes)} != {WireProtocolHeader.length_bytes()}"
return out_bytes
Expand Down
14 changes: 2 additions & 12 deletions skyplane/compute/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,6 @@ def start_gateway(
gateway_info_path: str,
log_viewer_port=8888,
use_bbr=False,
use_compression=False,
e2ee_key_bytes=None,
use_socket_tls=False,
):
def check_stderr(tup):
Expand Down Expand Up @@ -338,13 +336,6 @@ def check_stderr(tup):
if self.provider == "aws":
docker_envs["AWS_DEFAULT_REGION"] = self.region_tag.split(":")[1]

# copy E2EE keys
if e2ee_key_bytes is not None:
e2ee_key_file = "e2ee_key"
self.write_file(e2ee_key_bytes, f"/tmp/{e2ee_key_file}")
docker_envs["E2EE_KEY_FILE"] = f"/pkg/data/{e2ee_key_file}"
docker_run_flags += f" -v /tmp/{e2ee_key_file}:/pkg/data/{e2ee_key_file}"

# upload gateway programs and gateway info
gateway_program_file = os.path.basename(gateway_program_path).replace(":", "_")
gateway_info_file = os.path.basename(gateway_info_path).replace(":", "_")
Expand All @@ -359,8 +350,7 @@ def check_stderr(tup):
# update docker flags
docker_run_flags += " " + " ".join(f"--env {k}={v}" for k, v in docker_envs.items())

gateway_daemon_cmd += f" --region {self.region_tag} {'--use-compression' if use_compression else ''}"
gateway_daemon_cmd += f" {'--disable-e2ee' if e2ee_key_bytes is None else ''}"
gateway_daemon_cmd += f" --region {self.region_tag}"
gateway_daemon_cmd += f" {'--disable-tls' if not use_socket_tls else ''}"
escaped_gateway_daemon_cmd = gateway_daemon_cmd.replace('"', '\\"')
docker_launch_cmd = (
Expand All @@ -378,7 +368,7 @@ def check_stderr(tup):
logger.fs.debug(f"{self.uuid()} gateway_api_url = {self.gateway_api_url}")

# wait for gateways to start (check status API)
http_pool = urllib3.PoolManager()
http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=10))

def is_api_ready():
try:
Expand Down
Loading