Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion medcat-trainer/webapp/api/api/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from medcat.vocab import Vocab
from medcat.utils.legacy.convert_cdb import get_cdb_from_old

from api.models import ConceptDB
from api.models import ConceptDB, ModelPack

"""
Module level caches for CDBs, Vocabs and CAT instances.
Expand Down Expand Up @@ -163,6 +163,22 @@ def get_medcat_from_model_pack(project, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
return cat


def get_medcat_from_model_pack_id(modelpack_id: int, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
"""
Load (and cache) a MedCAT model pack directly from a ModelPack id.
"""
cat_id = f'mp{modelpack_id}'
if cat_id in cat_map:
return cat_map[cat_id]

model_pack_obj = ModelPack.objects.get(id=modelpack_id)
logger.info('Loading model pack from:%s', model_pack_obj.model_pack.path)
cat = CAT.load_model_pack(model_pack_obj.model_pack.path)
cat_map[cat_id] = cat
_clear_models(cat_map=cat_map)
return cat


def get_medcat(project,
cdb_map: Dict[str, CDB]=CDB_MAP,
vocab_map: Dict[str, Vocab]=VOCAB_MAP,
Expand Down Expand Up @@ -204,6 +220,16 @@ def clear_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP):
del cat_map[cat_id]


def is_model_pack_loaded(modelpack_id: int, cat_map: Dict[str, CAT]=CAT_MAP) -> bool:
return f'mp{modelpack_id}' in cat_map


def clear_cached_medcat_by_model_pack_id(modelpack_id: int, cat_map: Dict[str, CAT]=CAT_MAP) -> None:
cat_id = f'mp{modelpack_id}'
if cat_id in cat_map:
del cat_map[cat_id]


def get_cached_cdb(cdb_id: str, cdb_map: Dict[str, CDB]=CDB_MAP) -> CDB:
from api.utils import clear_cdb_cnf_addons
if cdb_id not in cdb_map:
Expand Down
95 changes: 84 additions & 11 deletions medcat-trainer/webapp/api/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import_concepts_from_cdb
from .data_utils import upload_projects_export
from .metrics import calculate_metrics
from .model_cache import get_medcat, get_cached_cdb, VOCAB_MAP, clear_cached_medcat, CAT_MAP, CDB_MAP, is_model_loaded
from .model_cache import get_medcat, get_medcat_from_model_pack_id, get_cached_cdb, VOCAB_MAP, clear_cached_medcat, clear_cached_medcat_by_model_pack_id, is_model_pack_loaded, CAT_MAP, CDB_MAP, is_model_loaded
from .permissions import *
from .serializers import *
from .solr_utils import collections_available, search_collection, ensure_concept_searchable
Expand Down Expand Up @@ -637,32 +637,87 @@

@api_view(http_method_names=['POST'])
def annotate_text(request):
p_id = request.data['project_id']
message = request.data['message']
cuis = request.data['cuis']
if message is None or p_id is None:
return HttpResponseBadRequest('No message to annotate')
message = request.data.get('message')
cuis = request.data.get('cuis', [])
p_id = request.data.get('project_id')
modelpack_id = request.data.get('modelpack_id')
include_sub_concepts = request.data.get('include_sub_concepts', False)

project = ProjectAnnotateEntities.objects.get(id=p_id)
if message is None or (p_id is None and modelpack_id is None):
return HttpResponseBadRequest('No message to annotate')

cat = get_medcat(project=project)
cat.config.components.linking.filters.cuis = set(cuis)
if modelpack_id is not None:
try:
cat = get_medcat_from_model_pack_id(int(modelpack_id))
except (ValueError, TypeError):
logger.warning(f'Invalid modelpack_id received for project:{p_id}')
return HttpResponseBadRequest('Invalid modelpack_id for project')
except ModelPack.DoesNotExist:
logger.warning(f'ModelPack does not exist received for project:{p_id}')
return HttpResponseBadRequest('ModelPack does not exist for project')
else:
project = ProjectAnnotateEntities.objects.get(id=p_id)
cat = get_medcat(project=project)

# Normalise cuis to a set[str]
if isinstance(cuis, str):
cuis_set = {c.strip() for c in cuis.split(',') if c.strip()}
elif isinstance(cuis, (list, tuple, set)):
cuis_set = {str(c).strip() for c in cuis if str(c).strip()}
else:
cuis_set = set()

# Expand CUIs to include sub-concepts if requested
if include_sub_concepts and cuis_set and cat.cdb:
expanded_cuis = set(cuis_set)
for parent_cui in cuis_set:
try:
child_cuis = get_all_ch(parent_cui, cat.cdb)
expanded_cuis.update(child_cuis)
except Exception as e:
logger.warning(f'Failed to get children for CUI {parent_cui}: {e}')
cuis_set = expanded_cuis

curr_cuis = cat.config.components.linking.filters
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a context manager would be better here?
I.e using https://docs.cogstack.org/projects/nlp/en/latest/autoapi/medcat/utils/config_utils/index.html#medcat.utils.config_utils.temp_changed_config
something like:

with temp_changed_config(cat.config.components.linking, 'filters', cuis_set):
    spacy_doc = cat(message)

cat.config.components.linking.filters.cuis = cuis_set
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this was always the case, but could this potentially cause issues? I.e since the same model packs instances are (potentially) used across multiple projects at once, I could see this creating unexpected filtering.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah good point, fixed

spacy_doc = cat(message)
cat.config.components.linking.filters = curr_cuis

ents = []
anno_tkns = []
for ent in spacy_doc.linked_ents:
cnt = Entity.objects.filter(label=ent.cui).count()
inc_ent = all(tkn not in anno_tkns for tkn in ent)
if inc_ent and cnt != 0:
meta_annotations = []
if 'meta_cat_meta_anns' in ent.get_available_addon_paths():
meta_anns = ent.get_addon_data('meta_cat_meta_anns')
for meta_ann_task, pred in meta_anns.items():
# Extract value and confidence from pred
# pred can be a dict, object, or string
if isinstance(pred, dict):
pred_value = pred.get('value', str(pred))
pred_confidence = pred.get('confidence', None)
elif hasattr(pred, 'value'):
pred_value = pred.value
pred_confidence = getattr(pred, 'confidence', None)
else:
pred_value = str(pred)
pred_confidence = None
meta_annotations.append({
'task': meta_ann_task,
'value': pred_value,
'confidence': pred_confidence
})
anno_tkns.extend([tkn for tkn in ent])
entity = Entity.objects.get(label=ent.cui)
ents.append({
'entity': entity.id,
'value': ent.base.text,
'start_ind': ent.base.start_char_index,
'end_ind': ent.base.end_char_index,
'acc': ent.context_similarity
'acc': ent.context_similarity,
'meta_annotations': meta_annotations
})

ents.sort(key=lambda e: e['start_ind'])
Expand Down Expand Up @@ -714,7 +769,7 @@

@api_view(http_method_names=['GET'])
def search_solr(request):
query = request.GET.get('search')

Check warning

Code scanning / CodeQL

Information exposure through an exception Medium

Stack trace information
flows to this location and may be exposed to an external user.
cdbs = request.GET.get('cdbs').split(',')
return search_collection(cdbs, query)

Expand Down Expand Up @@ -752,7 +807,7 @@


@api_view(http_method_names=['GET', 'DELETE'])
def cache_model(request, project_id):
def cache_project_model(request, project_id):
try:
project = ProjectAnnotateEntities.objects.get(id=project_id)
is_loaded = is_model_loaded(project)
Expand All @@ -772,6 +827,24 @@
return Response({'message': f'{str(e)}'}, 500)


@api_view(http_method_names=['GET', 'DELETE'])
def cache_modelpack(request, modelpack_id: int):
try:
if request.method == 'GET':
if not is_model_pack_loaded(modelpack_id):
get_medcat_from_model_pack_id(modelpack_id)
return Response('success', 200)
elif request.method == 'DELETE':
clear_cached_medcat_by_model_pack_id(modelpack_id)
return Response('success', 200)
else:
return Response(f'Invalid method', 404)
except ModelPack.DoesNotExist:
return Response(f'ModelPack with id:{modelpack_id} does not exist', 404)
except Exception as e:
return Response({'message': f'{str(e)}'}, 500)



@api_view(http_method_names=['GET'])
def model_loaded(_):
Expand Down
3 changes: 2 additions & 1 deletion medcat-trainer/webapp/api/core/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
path('api/project-progress/', api.views.project_progress),
path('api/concept-db-search-index-created/', api.views.concept_search_index_available),
path('api/model-loaded/', api.views.model_loaded),
path('api/cache-model/<int:project_id>/', api.views.cache_model),
path('api/cache-project-model/<int:project_id>/', api.views.cache_project_model),
path('api/cache-modelpack/<int:modelpack_id>/', api.views.cache_modelpack),
path('api/upload-deployment/', api.views.upload_deployment),
path('api/model-concept-children/<int:cdb_id>/', api.views.cdb_cui_children),
path('api/metrics/<int:report_id>/', api.views.view_metrics),
Expand Down
6 changes: 6 additions & 0 deletions medcat-trainer/webapp/frontend/env.d.ts
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
/// <reference types="vite/client" />

declare module '*.vue' {
import type { DefineComponent } from 'vue'
const component: DefineComponent<object, object, any>
export default component
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ export default {
this.loading = false
this.errorMessage = err.response.data.message || 'Error loading model.'
})

},
cancel () {
this.$emit('request:addAnnotationComplete')
Expand Down
Loading