Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions examples/python/remote_storage_example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,21 @@ The system automatically selects the best available storage backend:
1. Initiator sends memory descriptors to target
2. Target performs storage-to-memory or memory-to-storage operations
3. Data is transferred between initiator and target memory

Remote reads are implemented as a read from storage followed by a network write.

Remote writes are implemented as a read from the network followed by a storage write.

### Pipelining

To improve the performance of the remote storage server, we can pipeline network and storage operations, allowing multiple threads to work on each request. To maintain correctness, however, the network and storage stages of each individual remote storage operation must still execute in order. We therefore implemented a simple pipelining scheme: for remote writes, the pipeline reads into NIXL descriptors from the network and then writes them to storage (also through NIXL, but via a different plugin). A remote read is similar, reading into NIXL descriptors from storage and then writing to the network.

![Remote Operation Pipelines](storage_pipelines.png)

### Performance Tips

For high-speed storage and network hardware, you may need to tweak performance with a couple of environment variables.

First, for optimal GDS performance, ensure you are using the GDS_MT backend with default concurrency. Additionally, you can use the cufile options described in the [GDS README](https://github.com/ai-dynamo/nixl/blob/main/src/plugins/cuda_gds/README.md). Also a reminder to check that your GDS setup is running true GPU-direct IO and not in compatibility mode.

On the network side, remote reads from VRAM to DRAM can be limited by UCX rail selection. This can be tweaked by setting UCX_MAX_RMA_RAILS=1. However, with larger batch or message sizes, this might limit bandwidth and a higher number of rails might be needed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
268 changes: 218 additions & 50 deletions examples/python/remote_storage_example/nixl_p2p_storage_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Demonstrates peer-to-peer storage transfers using NIXL with initiator and target modes.
"""

import concurrent.futures
import time

import nixl_storage_utils as nsu
Expand All @@ -27,14 +28,20 @@
logger = get_logger(__name__)


def execute_transfer(my_agent, local_descs, remote_descs, remote_name, operation):
handle = my_agent.initialize_xfer(operation, local_descs, remote_descs, remote_name)
def execute_transfer(
    my_agent, local_descs, remote_descs, remote_name, operation, use_backends=None
):
    """Run a single NIXL transfer to completion and release its handle.

    Args:
        my_agent: local NIXL agent that initiates the transfer.
        local_descs: local (source or destination) descriptor list.
        remote_descs: remote descriptor list on ``remote_name``.
        remote_name: name of the remote agent (may be ``my_agent.name``
            for local storage transfers).
        operation: transfer direction, ``"READ"`` or ``"WRITE"``.
        use_backends: optional list of backend names (e.g. ``["GDS_MT"]``)
            to restrict plugin selection; ``None``/empty lets NIXL choose.
    """
    # None sentinel instead of a mutable [] default argument; an empty
    # list is passed through to keep the original "no restriction" behavior.
    handle = my_agent.initialize_xfer(
        operation,
        local_descs,
        remote_descs,
        remote_name,
        backends=use_backends if use_backends is not None else [],
    )
    my_agent.transfer(handle)
    nsu.wait_for_transfer(my_agent, handle)
    my_agent.release_xfer_handle(handle)


def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name):
def remote_storage_transfer(
my_agent, my_mem_descs, operation, remote_agent_name, iterations
):
"""Initiate remote memory transfer."""
if operation != "READ" and operation != "WRITE":
logger.error("Invalid operation, exiting")
Expand All @@ -45,14 +52,24 @@ def remote_storage_transfer(my_agent, my_mem_descs, operation, remote_agent_name
else:
operation = b"READ"

iterations_str = bytes(f"{iterations:04d}", "utf-8")
# Send the descriptors that you want to read into or write from
logger.info(f"Sending {operation} request to {remote_agent_name}")
logger.info(
"Sending %s request to %s", operation.decode("utf-8"), remote_agent_name
)
test_descs_str = my_agent.get_serialized_descs(my_mem_descs)
my_agent.send_notif(remote_agent_name, operation + test_descs_str)

start_time = time.time()

my_agent.send_notif(remote_agent_name, operation + iterations_str + test_descs_str)

while not my_agent.check_remote_xfer_done(remote_agent_name, b"COMPLETE"):
continue

elapsed = time.time() - start_time

logger.info("Time for %d iterations: %f seconds", iterations, elapsed)


def connect_to_agents(my_agent, agents_file):
target_agents = []
Expand All @@ -66,26 +83,154 @@ def connect_to_agents(my_agent, agents_file):
my_agent.fetch_remote_metadata(parts[0], parts[1], int(parts[2]))

while my_agent.check_remote_metadata(parts[0]) is False:
logger.info(f"Waiting for remote metadata for {parts[0]}...")
logger.info("Waiting for remote metadata for %s...", parts[0])
time.sleep(0.2)

logger.info(f"Remote metadata for {parts[0]} fetched")
logger.info("Remote metadata for %s fetched", parts[0])
else:
logger.error(f"Invalid line in {agents_file}: {line}")
logger.error("Invalid line in %s: %s", agents_file, line)
exit(-1)

logger.info("All remote metadata fetched")

return target_agents


def pipeline_reads(
    my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
):
    """Serve a remote READ request as a two-stage pipeline.

    Each iteration is a storage READ into local memory followed by a
    network WRITE of that memory to the requesting agent. The storage
    stage is kept one step ahead of the network stage so the two stages
    can overlap on the executor's worker threads.

    Args:
        my_agent: local NIXL agent acting as the storage server.
        req_agent: name of the remote agent that requested the data.
        my_mem_descs: local memory descriptors used as the staging buffer.
        my_file_descs: local storage (file) descriptors to read from.
        sent_descs: remote memory descriptors received from the initiator.
        iterations: number of READ+WRITE iterations to perform.

    Raises:
        RuntimeError: if any submitted transfer fails to complete.
        Exception: any exception raised inside a worker transfer is
            re-raised here instead of being silently dropped.
    """
    # NOTE(review): the overlapped READ/WRITE pair shares my_mem_descs;
    # correctness of the overlap depends on the intended pipeline scheme
    # described in the README — confirm buffer reuse is safe here.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        net_sent = 0  # network WRITE stages submitted
        storage_read = 0  # storage READ stages submitted
        futures = []

        while net_sent < iterations or storage_read < iterations:
            if storage_read == 0:
                # Prime the pipeline: first storage READ, no WRITE yet.
                futures.append(
                    executor.submit(
                        execute_transfer,
                        my_agent,
                        my_mem_descs,
                        my_file_descs,
                        my_agent.name,
                        "READ",
                    )
                )
                storage_read += 1
                continue

            if storage_read == iterations:
                # Drain the pipeline: storage done, flush remaining WRITEs.
                futures.append(
                    executor.submit(
                        execute_transfer,
                        my_agent,
                        my_mem_descs,
                        sent_descs,
                        req_agent,
                        "WRITE",
                    )
                )
                net_sent += 1
                continue

            # Steady state: overlap the next storage READ with the
            # network WRITE of the previous iteration.
            futures.append(
                executor.submit(
                    execute_transfer,
                    my_agent,
                    my_mem_descs,
                    my_file_descs,
                    my_agent.name,
                    "READ",
                )
            )
            futures.append(
                executor.submit(
                    execute_transfer,
                    my_agent,
                    my_mem_descs,
                    sent_descs,
                    req_agent,
                    "WRITE",
                )
            )
            storage_read += 1
            net_sent += 1

        done, not_done = concurrent.futures.wait(
            futures, return_when=concurrent.futures.ALL_COMPLETED
        )
        # Explicit raise instead of assert (asserts vanish under -O).
        if not_done:
            raise RuntimeError("pipeline_reads: transfers did not complete")
        # Re-raise any worker exception; previously the futures were
        # never inspected, so failures were silently swallowed.
        for fut in done:
            fut.result()


def pipeline_writes(
    my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
):
    """Serve a remote WRITE request as a two-stage pipeline.

    Each iteration is a network READ from the initiator into local memory
    followed by a storage WRITE of that memory. The network stage is kept
    one step ahead of the storage stage so the two stages can overlap on
    the executor's worker threads.

    Args:
        my_agent: local NIXL agent acting as the storage server.
        req_agent: name of the remote agent that requested the write.
        my_mem_descs: local memory descriptors used as the staging buffer.
        my_file_descs: local storage (file) descriptors to write to.
        sent_descs: remote memory descriptors received from the initiator.
        iterations: number of READ+WRITE iterations to perform.

    Raises:
        RuntimeError: if any submitted transfer fails to complete.
        Exception: any exception raised inside a worker transfer is
            re-raised here instead of being silently dropped.
    """
    # NOTE(review): the overlapped READ/WRITE pair shares my_mem_descs;
    # correctness of the overlap depends on the intended pipeline scheme
    # described in the README — confirm buffer reuse is safe here.
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        storage_written = 0  # storage WRITE stages submitted
        net_read = 1  # network READ stages submitted (primed below)
        futures = []

        # Prime the pipeline: first network READ, no storage WRITE yet.
        futures.append(
            executor.submit(
                execute_transfer,
                my_agent,
                my_mem_descs,
                sent_descs,
                req_agent,
                "READ",
            )
        )
        while storage_written < iterations or net_read < iterations:
            if net_read == iterations:
                # Drain the pipeline: network done, flush remaining WRITEs.
                futures.append(
                    executor.submit(
                        execute_transfer,
                        my_agent,
                        my_mem_descs,
                        my_file_descs,
                        my_agent.name,
                        "WRITE",
                    )
                )
                storage_written += 1
                continue

            # Steady state: overlap the next network READ with the
            # storage WRITE of the previous iteration.
            futures.append(
                executor.submit(
                    execute_transfer,
                    my_agent,
                    my_mem_descs,
                    sent_descs,
                    req_agent,
                    "READ",
                )
            )
            futures.append(
                executor.submit(
                    execute_transfer,
                    my_agent,
                    my_mem_descs,
                    my_file_descs,
                    my_agent.name,
                    "WRITE",
                )
            )
            net_read += 1
            storage_written += 1

        done, not_done = concurrent.futures.wait(
            futures, return_when=concurrent.futures.ALL_COMPLETED
        )
        # Explicit raise instead of assert (asserts vanish under -O).
        if not_done:
            raise RuntimeError("pipeline_writes: transfers did not complete")
        # Re-raise any worker exception; previously the futures were
        # never inspected, so failures were silently swallowed.
        for fut in done:
            fut.result()


def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
"""Handle remote memory and storage transfers as target."""
# Wait for initiator to send list of memory descriptors
notifs = my_agent.get_new_notifs()

logger.info("Waiting for a remote transfer request...")

while len(notifs) == 0:
notifs = my_agent.get_new_notifs()

Expand All @@ -101,57 +246,65 @@ def handle_remote_transfer_request(my_agent, my_mem_descs, my_file_descs):
logger.error("Invalid operation, exiting")
exit(-1)

sent_descs = my_agent.deserialize_descs(recv_msg[4:])
iterations = int(recv_msg[4:8])

logger.info("Checking to ensure metadata is loaded...")
while my_agent.check_remote_metadata(req_agent, sent_descs) is False:
continue
logger.info("Performing %s with %d iterations", operation, iterations)

if operation == "READ":
logger.info("Starting READ operation")
sent_descs = my_agent.deserialize_descs(recv_msg[8:])

# Read from file first
execute_transfer(
my_agent, my_mem_descs, my_file_descs, my_agent.name, "READ"
if operation == "READ":
pipeline_reads(
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
)
# Send to client
execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "WRITE")

elif operation == "WRITE":
logger.info("Starting WRITE operation")

# Read from client first
execute_transfer(my_agent, my_mem_descs, sent_descs, req_agent, "READ")
# Write to storage
execute_transfer(
my_agent, my_mem_descs, my_file_descs, my_agent.name, "WRITE"
pipeline_writes(
my_agent, req_agent, my_mem_descs, my_file_descs, sent_descs, iterations
)

# Send completion notification to initiator
my_agent.send_notif(req_agent, b"COMPLETE")

logger.info("One transfer test complete.")


def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
def run_client(
my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file, iterations
):
logger.info("Client initialized, ready for local transfer test...")

# For sample purposes, write to and then read from local storage
logger.info("Starting local transfer test...")
execute_transfer(
my_agent,
nixl_mem_reg_descs.trim(),
nixl_file_reg_descs.trim(),
my_agent.name,
"WRITE",
)
execute_transfer(
my_agent,
nixl_mem_reg_descs.trim(),
nixl_file_reg_descs.trim(),
my_agent.name,
"READ",
)

start_time = time.time()

for i in range(1, iterations):
execute_transfer(
my_agent,
nixl_mem_reg_descs.trim(),
nixl_file_reg_descs.trim(),
my_agent.name,
"WRITE",
["GDS_MT"],
)

elapsed = time.time() - start_time

logger.info("Time for %d WRITE iterations: %f seconds", iterations, elapsed)

start_time = time.time()

for i in range(1, iterations):
execute_transfer(
my_agent,
nixl_mem_reg_descs.trim(),
nixl_file_reg_descs.trim(),
my_agent.name,
"READ",
["GDS_MT"],
)

elapsed = time.time() - start_time

logger.info("Time for %d READ iterations: %f seconds", iterations, elapsed)

logger.info("Local transfer test complete")

logger.info("Starting remote transfer test...")
Expand All @@ -161,10 +314,10 @@ def run_client(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, agents_file):
# For sample purposes, write to and then read from each target agent
for target_agent in target_agents:
remote_storage_transfer(
my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent
my_agent, nixl_mem_reg_descs.trim(), "WRITE", target_agent, iterations
)
remote_storage_transfer(
my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent
my_agent, nixl_mem_reg_descs.trim(), "READ", target_agent, iterations
)

logger.info("Remote transfer test complete")
Expand Down Expand Up @@ -199,8 +352,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
type=str,
help="File containing list of target agents (only needed for client)",
)
parser.add_argument(
"--iterations",
type=int,
default=100,
help="Number of iterations for each transfer",
)
args = parser.parse_args()

mem = "DRAM"

if args.role == "client":
mem = "VRAM"

my_agent = nsu.create_agent_with_plugins(args.name, args.port)

(
Expand All @@ -209,15 +373,19 @@ def run_storage_server(my_agent, nixl_mem_reg_descs, nixl_file_reg_descs):
nixl_mem_reg_descs,
nixl_file_reg_descs,
) = nsu.setup_memory_and_files(
my_agent, args.batch_size, args.buf_size, args.fileprefix
my_agent, args.batch_size, args.buf_size, args.fileprefix, mem
)

if args.role == "client":
if not args.agents_file:
parser.error("--agents_file is required when role is client")
try:
run_client(
my_agent, nixl_mem_reg_descs, nixl_file_reg_descs, args.agents_file
my_agent,
nixl_mem_reg_descs,
nixl_file_reg_descs,
args.agents_file,
args.iterations,
)
finally:
nsu.cleanup_resources(
Expand Down
Loading