62 changes: 44 additions & 18 deletions frontend/src/api/apiClient.ts
@@ -3,15 +3,17 @@ import { FormValues } from "../pages/Feedback/FeedbackForm";
import { Conversation } from "../components/Header/Chat";
const baseURL = import.meta.env.VITE_API_BASE_URL;

export const api = axios.create({
export const publicApi = axios.create({ baseURL });

export const adminApi = axios.create({
baseURL,
headers: {
Authorization: `JWT ${localStorage.getItem("access")}`,
},
});

// Request interceptor to set the Authorization header
api.interceptors.request.use(
adminApi.interceptors.request.use(
(configuration) => {
const token = localStorage.getItem("access");
if (token) {
@@ -42,9 +44,14 @@ const handleSubmitFeedback = async (
}
};

const handleSendDrugSummary = async (message: FormValues["message"], guid: string) => {
const handleSendDrugSummary = async (
message: FormValues["message"],
guid: string,
) => {
try {
const endpoint = guid ? `/v1/api/embeddings/ask_embeddings?guid=${guid}` : '/v1/api/embeddings/ask_embeddings';
const endpoint = guid
? `/v1/api/embeddings/ask_embeddings?guid=${guid}`
: "/v1/api/embeddings/ask_embeddings";
const response = await api.post(endpoint, {
message,
});
@@ -58,7 +65,9 @@ const handleSendDrugSummary = async (message: FormValues["message"], guid: strin

const handleRuleExtraction = async (guid: string) => {
try {
const response = await api.get(`/v1/api/rule_extraction_openai?guid=${guid}`);
const response = await api.get(
`/v1/api/rule_extraction_openai?guid=${guid}`,
);
// console.log("Rule extraction response:", JSON.stringify(response.data, null, 2));
return response.data;
} catch (error) {
@@ -67,7 +76,10 @@ const handleRuleExtraction = async (guid: string) => {
}
};

const fetchRiskDataWithSources = async (medication: string, source: "include" | "diagnosis" | "diagnosis_depressed" = "include") => {
const fetchRiskDataWithSources = async (
medication: string,
source: "include" | "diagnosis" | "diagnosis_depressed" = "include",
) => {
try {
const response = await api.post(`/v1/api/riskWithSources`, {
drug: medication,
@@ -90,7 +102,7 @@ interface StreamCallbacks {
const handleSendDrugSummaryStream = async (
message: string,
guid: string,
callbacks: StreamCallbacks
callbacks: StreamCallbacks,
): Promise<void> => {
const token = localStorage.getItem("access");
const endpoint = `/v1/api/embeddings/ask_embeddings?stream=true${
@@ -165,12 +177,18 @@ const handleSendDrugSummaryStream = async (
}
}
} catch (parseError) {
console.error("Failed to parse SSE data:", parseError, "Raw line:", line);
console.error(
"Failed to parse SSE data:",
parseError,
"Raw line:",
line,
);
}
}
}
} catch (error) {
const errorMessage = error instanceof Error ? error.message : "Unknown error";
const errorMessage =
error instanceof Error ? error.message : "Unknown error";
console.error("Error in stream:", errorMessage);
callbacks.onError?.(errorMessage);
throw error;
@@ -186,7 +204,7 @@ const handleSendDrugSummaryStreamLegacy = async (
return handleSendDrugSummaryStream(message, guid, {
onContent: onChunk,
onError: (error) => console.error("Stream error:", error),
onComplete: () => console.log("Stream completed")
onComplete: () => console.log("Stream completed"),
});
};

@@ -255,11 +273,16 @@ const deleteConversation = async (id: string) => {
const updateConversationTitle = async (
id: Conversation["id"],
newTitle: Conversation["title"],
): Promise<{status: string, title: Conversation["title"]} | {error: string}> => {
): Promise<
{ status: string; title: Conversation["title"] } | { error: string }
> => {
try {
const response = await api.patch(`/chatgpt/conversations/${id}/update_title/`, {
title: newTitle,
});
const response = await api.patch(
`/chatgpt/conversations/${id}/update_title/`,
{
title: newTitle,
},
);
return response.data;
} catch (error) {
console.error("Error(s) during getConversation: ", error);
@@ -268,9 +291,12 @@ const updateConversationTitle = async (
};

// Assistant API functions
const sendAssistantMessage = async (message: string, previousResponseId?: string) => {
const sendAssistantMessage = async (
message: string,
previousResponseId?: string,
) => {
try {
const response = await api.post(`/v1/api/assistant`, {
const response = await publicApi.post(`/v1/api/assistant`, {
message,
previous_response_id: previousResponseId,
});
@@ -294,5 +320,5 @@ export {
handleSendDrugSummaryStream,
handleSendDrugSummaryStreamLegacy,
fetchRiskDataWithSources,
sendAssistantMessage
};
sendAssistantMessage,
};
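
The apiClient change splits the old shared `api` instance into `publicApi` (bare `baseURL`, no auth header) and `adminApi` (JWT header plus a request interceptor), and routes `sendAssistantMessage` through `publicApi`. Below is a minimal usage sketch of the exported helpers; the relative import path, the callback parameter types, and the assistant response shape are assumptions, since the diff only shows `response.data` being returned.

```typescript
// Sketch only: consume the streaming summary helper and the now-public
// assistant helper. The names come from the diff's export block; the import
// path and types are illustrative assumptions.
import {
  handleSendDrugSummaryStream,
  sendAssistantMessage,
} from "./apiClient";

async function demo(guid: string): Promise<void> {
  let summary = "";

  // Streamed chunks arrive through onContent (see the legacy wrapper above).
  await handleSendDrugSummaryStream("Summarise this document", guid, {
    onContent: (chunk: string) => {
      summary += chunk;
    },
    onError: (err: string) => console.error("Stream failed:", err),
    onComplete: () => console.log("Stream done:", summary),
  });

  // sendAssistantMessage goes through publicApi, so no JWT is attached.
  const reply = await sendAssistantMessage("What did the summary say?");
  console.log(reply);
}
```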
6 changes: 3 additions & 3 deletions frontend/src/services/actions/auth.tsx
@@ -170,8 +170,8 @@ export const login =

export const logout = () => async (dispatch: AppDispatch) => {
// Clear chat conversation data on logout for security
sessionStorage.removeItem('currentConversation');
sessionStorage.removeItem("currentConversation");

dispatch({
type: LOGOUT,
});
@@ -207,7 +207,7 @@ export const reset_password_confirm =
uid: string,
token: string,
new_password: string,
re_new_password: string
re_new_password: string,
): ThunkType =>
async (dispatch: AppDispatch) => {
const config = {
31 changes: 21 additions & 10 deletions server/api/services/embedding_services.py
@@ -1,5 +1,4 @@
# services/embedding_services.py

from django.db.models import Q
from pgvector.django import L2Distance

from .sentencetTransformer_model import TransformerModel
@@ -39,17 +38,29 @@ def get_closest_embeddings(
- file_id: GUID of the source file
"""

#
transformerModel = TransformerModel.get_instance().model
embedding_message = transformerModel.encode(message_data)
# Start building the query based on the message's embedding
closest_embeddings_query = (
Embeddings.objects.filter(upload_file__uploaded_by=user)
.annotate(
distance=L2Distance("embedding_sentence_transformers", embedding_message)

if user and user.is_authenticated:
# User sees their own files + files uploaded by superusers
closest_embeddings_query = (
Embeddings.objects.filter(
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
)
.annotate(
distance=L2Distance("embedding_sentence_transformers", embedding_message)
)
.order_by("distance")
)
else:
# Unauthenticated users only see superuser-uploaded files
closest_embeddings_query = (
Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)
.annotate(
distance=L2Distance("embedding_sentence_transformers", embedding_message)
)
.order_by("distance")
)
.order_by("distance")
)

# Filter by GUID if provided, otherwise filter by document name if provided
if guid:
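
The new visibility rule in `get_closest_embeddings` repeats the annotate/order_by pipeline in both branches. Purely as an illustration (not the PR's code), the same rule can be expressed with a single `Q` filter; the model and field names are taken from the diff above, and everything else is an assumption.

```python
# Sketch only: one queryset expressing the same visibility rule as the
# if/else above -- superuser uploads are always searchable, and an
# authenticated user's own uploads are added on top. `Embeddings` is the
# model queried in the diff; its import path is omitted here.
from django.db.models import Q
from pgvector.django import L2Distance


def closest_visible_embeddings(user, embedding_message):
    visibility = Q(upload_file__uploaded_by__is_superuser=True)
    if user and user.is_authenticated:
        visibility |= Q(upload_file__uploaded_by=user)
    return (
        Embeddings.objects.filter(visibility)
        .annotate(
            distance=L2Distance("embedding_sentence_transformers", embedding_message)
        )
        .order_by("distance")
    )
```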
4 changes: 2 additions & 2 deletions server/api/views/assistant/views.py
@@ -7,7 +7,7 @@
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import IsAuthenticated
from rest_framework.permissions import AllowAny
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt

@@ -111,7 +111,7 @@ def invoke_functions_from_response(

@method_decorator(csrf_exempt, name="dispatch")
class Assistant(APIView):
permission_classes = [IsAuthenticated]
permission_classes = [AllowAny]

def post(self, request):
try:
Expand Down