Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions next_cvat/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,26 @@ class AccessToken(BaseModel):
@classmethod
def from_client_cookies(cls, cookies: Dict, headers: Dict) -> AccessToken:
"""Create an AccessToken from CVAT client cookies and headers."""
# For basic auth, there might not be an Authorization header
# In this case, we'll use the session-based authentication
api_key = headers.get("Authorization", "session-based-auth")

# Extract expires from sessionid cookie (Morsel object)
sessionid_cookie = cookies["sessionid"]
expires_str = sessionid_cookie.get("expires")
if expires_str:
expires_at = parsedate_to_datetime(expires_str)
else:
# Fallback to a reasonable default if no expires is set
from datetime import datetime, timedelta

expires_at = datetime.now() + timedelta(days=14)

return cls(
sessionid=str(cookies["sessionid"]),
csrftoken=str(cookies["csrftoken"]),
api_key=headers["Authorization"],
expires_at=parsedate_to_datetime(cookies["sessionid"]["expires"]),
sessionid=str(sessionid_cookie.value),
csrftoken=str(cookies["csrftoken"].value),
api_key=api_key,
expires_at=expires_at,
)

def serialize(self) -> str:
Expand Down
9 changes: 6 additions & 3 deletions next_cvat/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ def token_cvat_client(self) -> Generator[CVATClient, None, None]:
with make_client(host="app.cvat.ai") as client:
token = AccessToken.deserialize(self.token)

client.api_client.set_default_header(
"Authorization", f"Token {token.api_key}"
)
# Only set Authorization header if we have a real API key (not session-based)
if token.api_key != "session-based-auth":
client.api_client.set_default_header(
"Authorization", f"Token {token.api_key}"
)

client.api_client.cookies["sessionid"] = token.sessionid
client.api_client.cookies["csrftoken"] = token.csrftoken

Expand Down
20 changes: 17 additions & 3 deletions next_cvat/client/job_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,22 @@ def add_polyline_(

return self

def add_polygon_(
self,
polygon: next_cvat.Polygon,
image_name: str,
group: int = 0,
) -> JobAnnotations:
label = self.job.task.project.label(name=polygon.label)

frame = self.job.task.frame(image_name=image_name)

self.annotations["shapes"].append(
polygon.request(frame=frame.id, label_id=label.id, group=group)
)

return self

def add_tag_(
self,
tag: next_cvat.Tag,
Expand All @@ -70,7 +86,5 @@ def add_tag_(
def request(self) -> models.LabeledDataRequest:
request = models.LabeledDataRequest()
request.version = self.annotations["version"]
request.tags = self.annotations["tags"]
request.shapes = self.annotations["shapes"]
request.tracks = self.annotations["tracks"]

return request
38 changes: 33 additions & 5 deletions next_cvat/types/polygon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Tuple

import numpy as np
from cvat_sdk.api_client import models
from PIL import Image, ImageDraw
from pydantic import BaseModel, field_validator

Expand All @@ -11,10 +12,10 @@

class Polygon(BaseModel):
"""A polygon annotation in CVAT.

Polygons are used to define regions in images using a series of connected points.
They can be converted to segmentation masks and support various geometric operations.

Attributes:
label: The label/class name for this polygon
source: The source of this annotation (e.g. "manual", "automatic")
Expand All @@ -23,6 +24,7 @@ class Polygon(BaseModel):
z_order: The z-order/layer of this polygon
attributes: List of additional attributes for this polygon
"""

label: str
source: str
occluded: int
Expand All @@ -33,7 +35,7 @@ class Polygon(BaseModel):
@field_validator("points", mode="before")
def parse_points(cls, v):
"""Parse points from string format if needed.

Handles conversion from CVAT's string format ("x1,y1;x2,y2;...") to list of tuples.
"""
if isinstance(v, str):
Expand Down Expand Up @@ -84,11 +86,37 @@ def translate(self, dx: int, dy: int) -> Polygon:

def polygon(self) -> Polygon:
"""Get this polygon.

This method exists for compatibility with other shape types that can be
converted to polygons.

Returns:
This polygon instance
"""
return self

def request(
self, frame: int, label_id: int, group: int = 0
) -> models.LabeledShapeRequest:
"""Convert the polygon to a CVAT shape format.

Args:
frame: The frame number this polygon appears in
label_id: The ID of the label this polygon is associated with
group: The group ID for this shape (default: 0)

Returns:
LabeledShapeRequest object for CVAT API
"""
return models.LabeledShapeRequest(
type="polygon",
occluded=bool(self.occluded),
points=np.array(self.points).flatten().tolist(),
rotation=0.0,
outside=False,
attributes=[attr.model_dump() for attr in self.attributes],
group=group,
source=self.source,
frame=frame,
label_id=label_id,
)
Loading