Skip to content

Commit bae4e1c

Browse files
authored
Merge pull request #399 from sahilds1/393-make-the-chatbot-public
[#393] Allow unauthenticated users to use the chatbot
2 parents 4161ecc + ffe1d57 commit bae4e1c

File tree

3 files changed

+67
-30
lines changed

3 files changed

+67
-30
lines changed

frontend/src/api/apiClient.ts

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@ import { FormValues } from "../pages/Feedback/FeedbackForm";
33
import { Conversation } from "../components/Header/Chat";
44
const baseURL = import.meta.env.VITE_API_BASE_URL;
55

6-
export const api = axios.create({
6+
export const publicApi = axios.create({ baseURL });
7+
8+
export const adminApi = axios.create({
79
baseURL,
810
headers: {
911
Authorization: `JWT ${localStorage.getItem("access")}`,
1012
},
1113
});
1214

1315
// Request interceptor to set the Authorization header
14-
api.interceptors.request.use(
16+
adminApi.interceptors.request.use(
1517
(configuration) => {
1618
const token = localStorage.getItem("access");
1719
if (token) {
@@ -42,9 +44,14 @@ const handleSubmitFeedback = async (
4244
}
4345
};
4446

45-
const handleSendDrugSummary = async (message: FormValues["message"], guid: string) => {
47+
const handleSendDrugSummary = async (
48+
message: FormValues["message"],
49+
guid: string,
50+
) => {
4651
try {
47-
const endpoint = guid ? `/v1/api/embeddings/ask_embeddings?guid=${guid}` : '/v1/api/embeddings/ask_embeddings';
52+
const endpoint = guid
53+
? `/v1/api/embeddings/ask_embeddings?guid=${guid}`
54+
: "/v1/api/embeddings/ask_embeddings";
4855
const response = await api.post(endpoint, {
4956
message,
5057
});
@@ -58,7 +65,9 @@ const handleSendDrugSummary = async (message: FormValues["message"], guid: strin
5865

5966
const handleRuleExtraction = async (guid: string) => {
6067
try {
61-
const response = await api.get(`/v1/api/rule_extraction_openai?guid=${guid}`);
68+
const response = await api.get(
69+
`/v1/api/rule_extraction_openai?guid=${guid}`,
70+
);
6271
// console.log("Rule extraction response:", JSON.stringify(response.data, null, 2));
6372
return response.data;
6473
} catch (error) {
@@ -67,7 +76,10 @@ const handleRuleExtraction = async (guid: string) => {
6776
}
6877
};
6978

70-
const fetchRiskDataWithSources = async (medication: string, source: "include" | "diagnosis" | "diagnosis_depressed" = "include") => {
79+
const fetchRiskDataWithSources = async (
80+
medication: string,
81+
source: "include" | "diagnosis" | "diagnosis_depressed" = "include",
82+
) => {
7183
try {
7284
const response = await api.post(`/v1/api/riskWithSources`, {
7385
drug: medication,
@@ -90,7 +102,7 @@ interface StreamCallbacks {
90102
const handleSendDrugSummaryStream = async (
91103
message: string,
92104
guid: string,
93-
callbacks: StreamCallbacks
105+
callbacks: StreamCallbacks,
94106
): Promise<void> => {
95107
const token = localStorage.getItem("access");
96108
const endpoint = `/v1/api/embeddings/ask_embeddings?stream=true${
@@ -165,12 +177,18 @@ const handleSendDrugSummaryStream = async (
165177
}
166178
}
167179
} catch (parseError) {
168-
console.error("Failed to parse SSE data:", parseError, "Raw line:", line);
180+
console.error(
181+
"Failed to parse SSE data:",
182+
parseError,
183+
"Raw line:",
184+
line,
185+
);
169186
}
170187
}
171188
}
172189
} catch (error) {
173-
const errorMessage = error instanceof Error ? error.message : "Unknown error";
190+
const errorMessage =
191+
error instanceof Error ? error.message : "Unknown error";
174192
console.error("Error in stream:", errorMessage);
175193
callbacks.onError?.(errorMessage);
176194
throw error;
@@ -186,7 +204,7 @@ const handleSendDrugSummaryStreamLegacy = async (
186204
return handleSendDrugSummaryStream(message, guid, {
187205
onContent: onChunk,
188206
onError: (error) => console.error("Stream error:", error),
189-
onComplete: () => console.log("Stream completed")
207+
onComplete: () => console.log("Stream completed"),
190208
});
191209
};
192210

@@ -255,11 +273,16 @@ const deleteConversation = async (id: string) => {
255273
const updateConversationTitle = async (
256274
id: Conversation["id"],
257275
newTitle: Conversation["title"],
258-
): Promise<{status: string, title: Conversation["title"]} | {error: string}> => {
276+
): Promise<
277+
{ status: string; title: Conversation["title"] } | { error: string }
278+
> => {
259279
try {
260-
const response = await api.patch(`/chatgpt/conversations/${id}/update_title/`, {
261-
title: newTitle,
262-
});
280+
const response = await api.patch(
281+
`/chatgpt/conversations/${id}/update_title/`,
282+
{
283+
title: newTitle,
284+
},
285+
);
263286
return response.data;
264287
} catch (error) {
265288
console.error("Error(s) during getConversation: ", error);
@@ -268,9 +291,12 @@ const updateConversationTitle = async (
268291
};
269292

270293
// Assistant API functions
271-
const sendAssistantMessage = async (message: string, previousResponseId?: string) => {
294+
const sendAssistantMessage = async (
295+
message: string,
296+
previousResponseId?: string,
297+
) => {
272298
try {
273-
const response = await api.post(`/v1/api/assistant`, {
299+
const response = await publicApi.post(`/v1/api/assistant`, {
274300
message,
275301
previous_response_id: previousResponseId,
276302
});
@@ -294,5 +320,5 @@ export {
294320
handleSendDrugSummaryStream,
295321
handleSendDrugSummaryStreamLegacy,
296322
fetchRiskDataWithSources,
297-
sendAssistantMessage
298-
};
323+
sendAssistantMessage,
324+
};

server/api/services/embedding_services.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# services/embedding_services.py
2-
1+
from django.db.models import Q
32
from pgvector.django import L2Distance
43

54
from .sentencetTransformer_model import TransformerModel
@@ -39,17 +38,29 @@ def get_closest_embeddings(
3938
- file_id: GUID of the source file
4039
"""
4140

42-
#
4341
transformerModel = TransformerModel.get_instance().model
4442
embedding_message = transformerModel.encode(message_data)
45-
# Start building the query based on the message's embedding
46-
closest_embeddings_query = (
47-
Embeddings.objects.filter(upload_file__uploaded_by=user)
48-
.annotate(
49-
distance=L2Distance("embedding_sentence_transformers", embedding_message)
43+
44+
if user.is_authenticated:
45+
# User sees their own files + files uploaded by superusers
46+
closest_embeddings_query = (
47+
Embeddings.objects.filter(
48+
Q(upload_file__uploaded_by=user) | Q(upload_file__uploaded_by__is_superuser=True)
49+
)
50+
.annotate(
51+
distance=L2Distance("embedding_sentence_transformers", embedding_message)
52+
)
53+
.order_by("distance")
54+
)
55+
else:
56+
# Unauthenticated users only see superuser-uploaded files
57+
closest_embeddings_query = (
58+
Embeddings.objects.filter(upload_file__uploaded_by__is_superuser=True)
59+
.annotate(
60+
distance=L2Distance("embedding_sentence_transformers", embedding_message)
61+
)
62+
.order_by("distance")
5063
)
51-
.order_by("distance")
52-
)
5364

5465
# Filter by GUID if provided, otherwise filter by document name if provided
5566
if guid:

server/api/views/assistant/views.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from rest_framework.views import APIView
88
from rest_framework.response import Response
99
from rest_framework import status
10-
from rest_framework.permissions import IsAuthenticated
10+
from rest_framework.permissions import AllowAny
1111
from django.utils.decorators import method_decorator
1212
from django.views.decorators.csrf import csrf_exempt
1313

@@ -111,7 +111,7 @@ def invoke_functions_from_response(
111111

112112
@method_decorator(csrf_exempt, name="dispatch")
113113
class Assistant(APIView):
114-
permission_classes = [IsAuthenticated]
114+
permission_classes = [AllowAny]
115115

116116
def post(self, request):
117117
try:

0 commit comments

Comments
 (0)