Skip to content

Commit 5bca746

Browse files
Mini256NG85
authored and committed
feat: add update api for embedding/reranker models (pingcap#677)
part of pingcap#667
1 parent 7f149cb commit 5bca746

File tree

45 files changed

+334
-304
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

45 files changed

+334
-304
lines changed

backend/Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,6 @@ ENV PYTHONPATH=/app
2727
COPY . /app/
2828

2929
# Default number of workers
30-
ENV FASTAPI_WORKERS=1
30+
ENV WEB_CONCURRENCY=4
3131

32-
CMD ["sh", "-c", "fastapi run app/api_server.py --host 0.0.0.0 --port 80 --workers ${FASTAPI_WORKERS}"]
32+
CMD ["sh", "-c", "fastapi run app/api_server.py --host 0.0.0.0 --port 80 --workers ${WEB_CONCURRENCY}"]

backend/Makefile

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,12 @@ test:
2525
@uv run pytest -v tests/
2626

2727
dev_backend:
28-
@echo "Running development backend server..."
29-
@uv run fastapi dev app/api_server.py --host 0.0.0.0 --port 5001
28+
@echo "Running backend server in development mode..."
29+
@uv run fastapi dev app/api_server.py --host 127.0.0.1 --port 5001
30+
31+
run_backend:
32+
@echo "Running backend server..."
33+
@uv run fastapi run app/api_server.py --host 0.0.0.0 --port 5001 --workers 4
3034

3135
dev_celery_flower:
3236
@echo "Running Celery Flower..."

backend/app/api/admin_routes/embedding_model/models.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,8 +27,7 @@ def vector_dimension_must_gt_1(cls, v: int) -> int:
2727
class EmbeddingModelUpdate(BaseModel):
2828
name: Optional[str] = None
2929
config: Optional[dict | list] = None
30-
credentials: Optional[Any] = None
31-
is_default: Optional[bool] = False
30+
credentials: Optional[str | dict] = None
3231

3332

3433
class EmbeddingModelItem(BaseModel):
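(File header missing from the extracted page — presumably backend/app/api/admin_routes/embedding_model/routes.py; verify against the commit.)

Lines changed: 32 additions & 57 deletions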
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import logging
21
from typing import List
32

43
from fastapi import APIRouter, Depends
@@ -12,43 +11,29 @@
1211
EmbeddingModelCreate,
1312
)
1413
from app.api.deps import CurrentSuperuserDep, SessionDep
15-
from app.exceptions import EmbeddingModelNotFound, InternalServerError
16-
from app.repositories.embedding_model import embed_model_repo
14+
from app.repositories.embedding_model import embedding_model_repo
1715
from app.rag.embeddings.provider import (
1816
EmbeddingProviderOption,
1917
embedding_provider_options,
2018
)
2119
from app.rag.embeddings.resolver import resolve_embed_model
20+
from app.logger import logger
2221

2322
router = APIRouter()
24-
logger = logging.getLogger(__name__)
2523

2624

27-
@router.get("/admin/embedding-models/provider/options")
25+
@router.get("/admin/embedding-models/providers/options")
2826
def list_embedding_model_provider_options(
2927
user: CurrentSuperuserDep,
3028
) -> List[EmbeddingProviderOption]:
3129
return embedding_provider_options
3230

3331

34-
@router.get("/admin/embedding-models/options", deprecated=True)
35-
def get_embedding_model_options(
36-
user: CurrentSuperuserDep,
37-
) -> List[EmbeddingProviderOption]:
38-
return embedding_provider_options
39-
40-
41-
@router.post("/admin/embedding-models")
42-
def create_embedding_model(
43-
session: SessionDep,
44-
user: CurrentSuperuserDep,
45-
create: EmbeddingModelCreate,
46-
) -> EmbeddingModelDetail:
47-
try:
48-
return embed_model_repo.create(session, create)
49-
except Exception as e:
50-
logger.exception(e)
51-
raise InternalServerError()
32+
@router.get("/admin/embedding-models")
33+
def list_embedding_models(
34+
db_session: SessionDep, user: CurrentSuperuserDep, params: Params = Depends()
35+
) -> Page[EmbeddingModelItem]:
36+
return embedding_model_repo.paginate(db_session, params)
5237

5338

5439
@router.post("/admin/embedding-models/test")
@@ -72,60 +57,50 @@ def test_embedding_model(
7257
success = True
7358
error = ""
7459
except Exception as e:
60+
logger.info(f"Failed to test embedding model: {e}")
7561
success = False
7662
error = str(e)
7763
return EmbeddingModelTestResult(success=success, error=error)
7864

7965

80-
@router.get("/admin/embedding-models")
81-
def list_embedding_models(
82-
session: SessionDep, user: CurrentSuperuserDep, params: Params = Depends()
83-
) -> Page[EmbeddingModelItem]:
84-
return embed_model_repo.paginate(session, params)
66+
@router.post("/admin/embedding-models")
67+
def create_embedding_model(
68+
db_session: SessionDep,
69+
user: CurrentSuperuserDep,
70+
create: EmbeddingModelCreate,
71+
) -> EmbeddingModelDetail:
72+
return embedding_model_repo.create(db_session, create)
8573

8674

8775
@router.get("/admin/embedding-models/{model_id}")
8876
def get_embedding_model_detail(
89-
session: SessionDep, user: CurrentSuperuserDep, model_id: int
77+
db_session: SessionDep, user: CurrentSuperuserDep, model_id: int
9078
) -> EmbeddingModelDetail:
91-
try:
92-
return embed_model_repo.must_get(session, model_id)
93-
except EmbeddingModelNotFound as e:
94-
raise e
95-
except Exception as e:
96-
logger.exception(e)
97-
raise InternalServerError()
79+
return embedding_model_repo.must_get(db_session, model_id)
9880

9981

10082
@router.put("/admin/embedding-models/{model_id}")
10183
def update_embedding_model(
102-
session: SessionDep,
84+
db_session: SessionDep,
10385
user: CurrentSuperuserDep,
10486
model_id: int,
10587
update: EmbeddingModelUpdate,
10688
) -> EmbeddingModelDetail:
107-
try:
108-
embed_model = embed_model_repo.must_get(session, model_id)
109-
embed_model_repo.update(session, embed_model, update)
110-
return embed_model
111-
except EmbeddingModelNotFound as e:
112-
raise e
113-
except Exception as e:
114-
logger.exception(e)
115-
raise InternalServerError()
89+
embed_model = embedding_model_repo.must_get(db_session, model_id)
90+
return embedding_model_repo.update(db_session, embed_model, update)
91+
92+
93+
@router.delete("/admin/embedding-models/{model_id}")
94+
def delete_embedding_model(
95+
db_session: SessionDep, user: CurrentSuperuserDep, model_id: int
96+
) -> None:
97+
embedding_model = embedding_model_repo.must_get(db_session, model_id)
98+
embedding_model_repo.delete(db_session, embedding_model)
11699

117100

118101
@router.put("/admin/embedding-models/{model_id}/set_default")
119102
def set_default_embedding_model(
120-
session: SessionDep, user: CurrentSuperuserDep, model_id: int
103+
db_session: SessionDep, user: CurrentSuperuserDep, model_id: int
121104
) -> EmbeddingModelDetail:
122-
try:
123-
embed_model = embed_model_repo.must_get(session, model_id)
124-
embed_model_repo.set_default_model(session, model_id)
125-
session.refresh(embed_model)
126-
return embed_model
127-
except EmbeddingModelNotFound as e:
128-
raise e
129-
except Exception as e:
130-
logger.exception(e)
131-
raise InternalServerError()
105+
embed_model = embedding_model_repo.must_get(db_session, model_id)
106+
return embedding_model_repo.set_default(db_session, embed_model)

backend/app/api/admin_routes/knowledge_base/routes.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
KnowledgeBase,
3030
)
3131
from app.repositories import (
32-
embed_model_repo,
32+
embedding_model_repo,
3333
llm_repo,
3434
data_source_repo,
3535
knowledge_base_repo,
@@ -77,7 +77,9 @@ def create_knowledge_base(
7777
create.llm_id = llm_repo.must_get_default(session).id
7878

7979
if not create.embedding_model_id:
80-
create.embedding_model_id = embed_model_repo.must_get_default(session).id
80+
create.embedding_model_id = embedding_model_repo.must_get_default(
81+
session
82+
).id
8183

8284
knowledge_base = KnowledgeBase(
8385
name=create.name,

backend/app/api/admin_routes/llm/routes.py

Lines changed: 14 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,10 @@
44
from fastapi_pagination import Page, Params
55
from llama_index.core.base.llms.types import ChatMessage
66
from pydantic import BaseModel
7-
from sqlalchemy import update
87

98
from app.api.deps import CurrentSuperuserDep, SessionDep
109
from app.logger import logger
11-
from app.models import AdminLLM, ChatEngine, KnowledgeBase, LLM, LLMUpdate
10+
from app.models import AdminLLM, LLM, LLMUpdate
1211
from app.rag.llms.provider import LLMProviderOption, llm_provider_options
1312
from app.rag.llms.resolver import resolve_llm
1413
from app.repositories.llm import llm_repo
@@ -17,7 +16,7 @@
1716
router = APIRouter()
1817

1918

20-
@router.get("/admin/llms/provider/options")
19+
@router.get("/admin/llms/providers/options")
2120
def list_llm_provider_options(user: CurrentSuperuserDep) -> List[LLMProviderOption]:
2221
return llm_provider_options
2322

@@ -31,15 +30,6 @@ def list_llms(
3130
return llm_repo.paginate(db_session, params)
3231

3332

34-
@router.post("/admin/llms")
35-
def create_llm(
36-
db_session: SessionDep,
37-
user: CurrentSuperuserDep,
38-
llm: LLM,
39-
) -> AdminLLM:
40-
return llm_repo.create(db_session, llm)
41-
42-
4333
class LLMTestResult(BaseModel):
4434
success: bool
4535
error: str = ""
@@ -72,12 +62,21 @@ def test_llm(
7262
success = True
7363
error = ""
7464
except Exception as e:
75-
logger.error(f"Failed to test LLM: {e}")
65+
logger.info(f"Failed to test LLM: {e}")
7666
success = False
7767
error = str(e)
7868
return LLMTestResult(success=success, error=error)
7969

8070

71+
@router.post("/admin/llms")
72+
def create_llm(
73+
db_session: SessionDep,
74+
user: CurrentSuperuserDep,
75+
llm: LLM,
76+
) -> AdminLLM:
77+
return llm_repo.create(db_session, llm)
78+
79+
8180
@router.get("/admin/llms/{llm_id}")
8281
def get_llm(
8382
db_session: SessionDep,
@@ -103,26 +102,9 @@ def delete_llm(
103102
db_session: SessionDep,
104103
user: CurrentSuperuserDep,
105104
llm_id: int,
106-
) -> AdminLLM:
105+
) -> None:
107106
llm = llm_repo.must_get(db_session, llm_id)
108-
109-
# TODO: Support to specify a new LLM to replace the current LLM.
110-
db_session.exec(
111-
update(ChatEngine).where(ChatEngine.llm_id == llm_id).values(llm_id=None)
112-
)
113-
db_session.exec(
114-
update(ChatEngine)
115-
.where(ChatEngine.fast_llm_id == llm_id)
116-
.values(fast_llm_id=None)
117-
)
118-
db_session.exec(
119-
update(KnowledgeBase).where(KnowledgeBase.llm_id == llm_id).values(llm_id=None)
120-
)
121-
122-
# TODO: Should using soft deletion.
123-
db_session.delete(llm)
124-
db_session.commit()
125-
return llm
107+
llm_repo.delete(db_session, llm)
126108

127109

128110
@router.put("/admin/llms/{llm_id}/set_default")

0 commit comments

Comments
 (0)