Skip to content

Commit 210ced1

Browse files
jonorthwashvykliuk
andauthored
Embeddings (#252)
Adding support for embeddings mode. --------- Authored-by: vykliuk <[email protected]> Co-authored-by: vykliuk <[email protected]>
1 parent 175e521 commit 210ced1

File tree

7 files changed

+110
-17
lines changed

7 files changed

+110
-17
lines changed

apertium_apy/apy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
BillookupHandler,
4444
BilsearchHandler,
4545
CoverageHandler,
46+
EmbeddingsHandler,
4647
GenerateHandler,
4748
GuesserHandler,
4849
IdentifyLangHandler,
@@ -142,6 +143,8 @@ def setup_handler(
142143
handler.bilsearch[lang_pair] = (dirpath, modename)
143144
for dirpath, modename, lang_pair in modes['billookup']:
144145
handler.billookup[lang_pair] = (dirpath, modename)
146+
for dirpath, modename, lang_pair in modes['embeddings']:
147+
handler.embeddings[lang_pair] = (dirpath, modename)
145148

146149
handler.init_pairs_graph()
147150
handler.init_paths()
@@ -293,6 +296,7 @@ def setup_application(args):
293296
(r'/pipedebug', PipeDebugHandler),
294297
(r'/bilsearch', BilsearchHandler),
295298
(r'/billookup', BillookupHandler),
299+
(r'/embeddings', EmbeddingsHandler),
296300
] # type: List[Tuple[str, Type[tornado.web.RequestHandler]]]
297301

298302
if importlib_util.find_spec('streamparser'):

apertium_apy/handlers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,4 @@
1919
from apertium_apy.handlers.translate_webpage import TranslateWebpageHandler # noqa: F401
2020
from apertium_apy.handlers.bilsearch import BilsearchHandler # noqa: F401
2121
from apertium_apy.handlers.billookup import BillookupHandler # noqa: F401
22+
from apertium_apy.handlers.embeddings import EmbeddingsHandler # noqa: F401

apertium_apy/handlers/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class BaseHandler(tornado.web.RequestHandler):
3838
guessers = {} # type: Dict[str, Tuple[str, str]]
3939
bilsearch = {}
4040
billookup = {}
41+
embeddings = {}
4142
pairprefs = {} # type: Dict[str, Dict[str, Dict[str, str]]]
4243
# (l1, l2): [translation.Pipeline], only contains flushing pairs!
4344
pipelines = {} # type: Dict[Tuple[str, str], List[Union[FlushingPipeline, SimplePipeline]]]

apertium_apy/handlers/billookup.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,34 +58,50 @@ def normalize(form):
5858
word = form.split("<", 1)[0]
5959
tags = re.findall(r"<([^>]+)>", form)
6060
if not tags:
61-
return word
61+
return word, []
6262
pos = tags[0]
6363
subcats = allowed.get(pos, [])
6464
filtered = []
65+
extra = []
6566
for t in tags[1:]:
66-
if t in subcats and t not in filtered:
67-
filtered.append(t)
67+
if t in subcats:
68+
if t not in filtered:
69+
filtered.append(t)
70+
else:
71+
tag_with_brackets = f"<{t}>"
72+
if tag_with_brackets not in extra:
73+
extra.append(tag_with_brackets)
6874
tag_str = f"<{pos}>" + "".join(f"<{t}>" for t in filtered)
69-
return f"{word}{tag_str}"
75+
return f"{word}{tag_str}", extra
7076

7177
consolidated = {}
7278
for item in raw_results:
7379
for src in item:
74-
norm_src = normalize(src)
80+
norm_src, extra_src = normalize(src)
7581
if norm_src not in consolidated:
76-
consolidated[norm_src] = []
82+
consolidated[norm_src] = {"targets": {}, "extra_tags": []}
83+
for tag in extra_src:
84+
if tag not in consolidated[norm_src]["extra_tags"]:
85+
consolidated[norm_src]["extra_tags"].append(tag)
7786
for tgt in item[src]:
78-
norm_tgt = normalize(tgt)
79-
if norm_tgt not in consolidated[norm_src]:
80-
consolidated[norm_src].append(norm_tgt)
87+
norm_tgt, _ = normalize(tgt)
88+
if norm_tgt not in consolidated[norm_src]["targets"]:
89+
consolidated[norm_src]["targets"][norm_tgt] = True
8190

8291
results = []
83-
for src, tgts in consolidated.items():
84-
tgt_list = []
85-
for t in tgts:
86-
tgt_list.append(t)
87-
entry = {}
88-
entry[src] = tgt_list
92+
for src, data in consolidated.items():
93+
tgt_list = list(data["targets"].keys())
94+
if data["extra_tags"]:
95+
extra_combined = "".join(data["extra_tags"])
96+
entry = {
97+
src: tgt_list,
98+
"extra-tags": [extra_combined]
99+
}
100+
else:
101+
entry = {
102+
src: tgt_list,
103+
"extra-tags": []
104+
}
89105
results.append(entry)
90106

91107
self.send_response({
@@ -102,6 +118,6 @@ def normalize(form):
102118
@gen.coroutine
103119
def get(self):
104120
pair = self.get_pair_or_error(self.get_argument('langpair'))
105-
121+
106122
if pair is not None:
107123
yield self.lookup_and_respond(pair, self.get_argument('q'))
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import logging
2+
import os
3+
from tornado import gen
4+
5+
from apertium_apy.handlers.base import BaseHandler
6+
from apertium_apy.utils.translation import translate_simple
7+
from apertium_apy.utils import to_alpha3_code
8+
9+
class EmbeddingsHandler(BaseHandler):
10+
def get_pair_or_error(self, langpair):
11+
try:
12+
l1, l2 = map(to_alpha3_code, langpair.split('|'))
13+
in_mode = f"{l1}-{l2}"
14+
except ValueError:
15+
self.send_error(400, explanation='That pair is invalid, use e.g. eng|spa')
16+
return None
17+
18+
in_mode = self.find_fallback_mode(in_mode, self.pairs)
19+
if in_mode not in self.pairs:
20+
self.send_error(400, explanation='That pair is not installed')
21+
return None
22+
return tuple(in_mode.split('-'))
23+
24+
@gen.coroutine
25+
def embed_and_respond(self, pair, query):
26+
try:
27+
raw_path, mode = self.embeddings["-".join(pair)]
28+
path = os.path.abspath(raw_path)
29+
30+
commands = [['apertium', '-d', path, '-f', 'none', mode]]
31+
raw = yield translate_simple(query + '\n', commands)
32+
33+
segments = [seg for seg in raw.strip().split('^') if seg and '/' in seg]
34+
results = []
35+
36+
for seg in segments:
37+
entry = seg.rstrip('~').rstrip('$')
38+
parts = entry.split('/')
39+
src = parts[0]
40+
forms = []
41+
for t in parts[1:]:
42+
t_clean = t.split('~', 1)[0]
43+
if t_clean:
44+
forms.append(t_clean)
45+
if forms:
46+
results.append({src: forms})
47+
48+
self.send_response({
49+
'responseData': {'embeddingResults': results},
50+
'responseDetails': None,
51+
'responseStatus': 200,
52+
})
53+
except Exception:
54+
logging.exception('Embedding error in %s-%s', *pair)
55+
self.send_error(503, explanation='internal error')
56+
57+
@gen.coroutine
58+
def get(self):
59+
pair = self.get_pair_or_error(self.get_argument('langpair'))
60+
if pair is not None:
61+
yield self.embed_and_respond(pair, self.get_argument('q'))

apertium_apy/handlers/list_modes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ def get(self):
4747
if self.get_arguments('include_deprecated_codes'):
4848
response_data.append({'sourceLanguage': to_alpha2_code(l1), 'targetLanguage': to_alpha2_code(l2)})
4949
self.send_response({'responseData': response_data, 'responseDetails': None, 'responseStatus': 200})
50+
elif query == 'embeddings':
51+
response_data = []
52+
for pair, (path, modename) in self.embeddings.items():
53+
l1, l2 = pair.split('-')
54+
response_data.append({'sourceLanguage': l1, 'targetLanguage': l2})
55+
if self.get_arguments('include_deprecated_codes'):
56+
response_data.append({'sourceLanguage': to_alpha2_code(l1), 'targetLanguage': to_alpha2_code(l2)})
57+
self.send_response({'responseData': response_data, 'responseDetails': None, 'responseStatus': 200})
5058

5159
else:
5260
self.send_error(400, explanation='Expecting q argument to be one of analysers, generators, guessers, spellers, disambiguators, or pairs')

apertium_apy/mode_search.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def search_path(rootpath, include_pairs=True, verbosity=1):
4444
'guesser': re.compile(r'(({0}(-{0})?)-guess(er)?)\.mode'.format(lang_code)),
4545
'bilsearch': re.compile(r'({0})-({0})-bilsearch\.mode'.format(lang_code)),
4646
'billookup': re.compile(r'({0})-({0})-billookup\.mode'.format(lang_code)),
47+
'embeddings': re.compile(r'({0})-({0})-embeddings\.mode'.format(lang_code)),
4748
}
4849
modes = {
4950
'pair': [],
@@ -55,6 +56,7 @@ def search_path(rootpath, include_pairs=True, verbosity=1):
5556
'guesser': [],
5657
'bilsearch': [],
5758
'billookup': [],
59+
'embeddings': [],
5860
} # type: Dict[str, List[Tuple[str, str, str]]]
5961

6062
real_root = os.path.abspath(os.path.realpath(rootpath))
@@ -67,7 +69,7 @@ def search_path(rootpath, include_pairs=True, verbosity=1):
6769
for mtype, regex in type_re.items():
6870
m = regex.match(filename)
6971
if m:
70-
if mtype == 'bilsearch' or mtype == 'billookup':
72+
if mtype == 'bilsearch' or mtype == 'billookup' or mtype == 'embeddings':
7173
lang_src = to_alpha3_code(m.group(1))
7274
lang_trg = to_alpha3_code(m.group(2))
7375
lang_pair = f"{lang_src}-{lang_trg}"

0 commit comments

Comments
 (0)