diff --git a/frontend/src/api/apiClient.ts b/frontend/src/api/apiClient.ts index 26a6ab8a..857c1520 100644 --- a/frontend/src/api/apiClient.ts +++ b/frontend/src/api/apiClient.ts @@ -3,7 +3,9 @@ import { FormValues } from "../pages/Feedback/FeedbackForm"; import { Conversation } from "../components/Header/Chat"; const baseURL = import.meta.env.VITE_API_BASE_URL; -export const api = axios.create({ +export const publicApi = axios.create({ baseURL }); + +export const adminApi = axios.create({ baseURL, headers: { Authorization: `JWT ${localStorage.getItem("access")}`, @@ -11,7 +13,7 @@ export const api = axios.create({ }); // Request interceptor to set the Authorization header -api.interceptors.request.use( +adminApi.interceptors.request.use( (configuration) => { const token = localStorage.getItem("access"); if (token) { @@ -42,9 +44,14 @@ const handleSubmitFeedback = async ( } }; -const handleSendDrugSummary = async (message: FormValues["message"], guid: string) => { +const handleSendDrugSummary = async ( + message: FormValues["message"], + guid: string, +) => { try { - const endpoint = guid ? `/v1/api/embeddings/ask_embeddings?guid=${guid}` : '/v1/api/embeddings/ask_embeddings'; + const endpoint = guid + ? `/v1/api/embeddings/ask_embeddings?guid=${guid}` + : "/v1/api/embeddings/ask_embeddings"; const response = await api.post(endpoint, { message, }); @@ -58,7 +65,9 @@ const handleSendDrugSummary = async (message: FormValues["message"], guid: strin const handleRuleExtraction = async (guid: string) => { try { - const response = await api.get(`/v1/api/rule_extraction_openai?guid=${guid}`); + const response = await api.get( + `/v1/api/rule_extraction_openai?guid=${guid}`, + ); // console.log("Rule extraction response:", JSON.stringify(response.data, null, 2)); return response.data; } catch (error) { @@ -67,7 +76,10 @@ const handleRuleExtraction = async (guid: string) => { } }; -const fetchRiskDataWithSources = async (medication: string, source: "include" | "diagnosis" | "diagnosis_depressed" = "include") => { +const fetchRiskDataWithSources = async ( + medication: string, + source: "include" | "diagnosis" | "diagnosis_depressed" = "include", +) => { try { const response = await api.post(`/v1/api/riskWithSources`, { drug: medication, @@ -90,7 +102,7 @@ interface StreamCallbacks { const handleSendDrugSummaryStream = async ( message: string, guid: string, - callbacks: StreamCallbacks + callbacks: StreamCallbacks, ): Promise => { const token = localStorage.getItem("access"); const endpoint = `/v1/api/embeddings/ask_embeddings?stream=true${ @@ -165,12 +177,18 @@ const handleSendDrugSummaryStream = async ( } } } catch (parseError) { - console.error("Failed to parse SSE data:", parseError, "Raw line:", line); + console.error( + "Failed to parse SSE data:", + parseError, + "Raw line:", + line, + ); } } } } catch (error) { - const errorMessage = error instanceof Error ? error.message : "Unknown error"; + const errorMessage = + error instanceof Error ? error.message : "Unknown error"; console.error("Error in stream:", errorMessage); callbacks.onError?.(errorMessage); throw error; @@ -186,7 +204,7 @@ const handleSendDrugSummaryStreamLegacy = async ( return handleSendDrugSummaryStream(message, guid, { onContent: onChunk, onError: (error) => console.error("Stream error:", error), - onComplete: () => console.log("Stream completed") + onComplete: () => console.log("Stream completed"), }); }; @@ -255,11 +273,16 @@ const deleteConversation = async (id: string) => { const updateConversationTitle = async ( id: Conversation["id"], newTitle: Conversation["title"], -): Promise<{status: string, title: Conversation["title"]} | {error: string}> => { +): Promise< + { status: string; title: Conversation["title"] } | { error: string } +> => { try { - const response = await api.patch(`/chatgpt/conversations/${id}/update_title/`, { - title: newTitle, - }); + const response = await api.patch( + `/chatgpt/conversations/${id}/update_title/`, + { + title: newTitle, + }, + ); return response.data; } catch (error) { console.error("Error(s) during getConversation: ", error); @@ -268,9 +291,12 @@ const updateConversationTitle = async ( }; // Assistant API functions -const sendAssistantMessage = async (message: string, previousResponseId?: string) => { +const sendAssistantMessage = async ( + message: string, + previousResponseId?: string, +) => { try { - const response = await api.post(`/v1/api/assistant`, { + const response = await publicApi.post(`/v1/api/assistant`, { message, previous_response_id: previousResponseId, }); @@ -294,5 +320,5 @@ export { handleSendDrugSummaryStream, handleSendDrugSummaryStreamLegacy, fetchRiskDataWithSources, - sendAssistantMessage -}; \ No newline at end of file + sendAssistantMessage, +}; diff --git a/server/api/services/embedding_services.py b/server/api/services/embedding_services.py index 6fd34d35..b50dd750 100644 --- a/server/api/services/embedding_services.py +++ b/server/api/services/embedding_services.py @@ -1,5 +1,4 @@ -# services/embedding_services.py - +from django.db.models import Q from pgvector.django import L2Distance from .sentencetTransformer_model import TransformerModel @@ -39,17 +38,29 @@ def get_closest_embeddings( - file_id: GUID of the source file """ - # transformerModel = TransformerModel.get_instance().model embedding_message = transformerModel.encode(message_data) - # Start building the query based on the message's embedding - closest_embeddings_query = ( - Embeddings.objects.filter(upload_file__uploaded_by=user) - .annotate( - distance=L2Distance("embedding_sentence_transformers", embedding_message) + + if user.is_authenticated: + # User sees their own files + files uploaded by superusers + closest_embeddings_query = ( + Embeddings.objects.filter( + Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True) + ) + .annotate( + distance=L2Distance("embedding_sentence_transformers", embedding_message) + ) + .order_by("distance") + ) + else: + # Unauthenticated users only see superuser-uploaded files + closest_embeddings_query = ( + Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True) + .annotate( + distance=L2Distance("embedding_sentence_transformers", embedding_message) + ) + .order_by("distance") ) - .order_by("distance") - ) # Filter by GUID if provided, otherwise filter by document name if provided if guid: diff --git a/server/api/views/assistant/views.py b/server/api/views/assistant/views.py index 32089c58..67ba8a56 100644 --- a/server/api/views/assistant/views.py +++ b/server/api/views/assistant/views.py @@ -7,7 +7,7 @@ from rest_framework.views import APIView from rest_framework.response import Response from rest_framework import status -from rest_framework.permissions import IsAuthenticated +from rest_framework.permissions import AllowAny from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt @@ -111,7 +111,7 @@ def invoke_functions_from_response( @method_decorator(csrf_exempt, name="dispatch") class Assistant(APIView): - permission_classes = [IsAuthenticated] + permission_classes = [AllowAny] def post(self, request): try: