@@ -212,6 +212,34 @@ class BertSpacesTests: XCTestCase {
     }
 }
 
+class RobertaTests: XCTestCase {
+    func testEncodeDecode() async throws {
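+        // Expected tokens and ids below are meant to match the upstream `FacebookAI/roberta-base` tokenizer.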
+        guard let tokenizer = try await AutoTokenizer.from(pretrained: "FacebookAI/roberta-base") as? PreTrainedTokenizer else {
+            XCTFail()
+            return
+        }
+
+        XCTAssertEqual(tokenizer.tokenize(text: "l'eure"), ["l", "'", "e", "ure"])
+        XCTAssertEqual(tokenizer.encode(text: "l'eure"), [0, 462, 108, 242, 2407, 2])
+        XCTAssertEqual(tokenizer.decode(tokens: tokenizer.encode(text: "l'eure"), skipSpecialTokens: true), "l'eure")
+
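+        // Non-ASCII characters are encoded as byte-level BPE symbols, so "ą" spans two tokens here.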
226+ XCTAssertEqual ( tokenizer. tokenize ( text: " mąka " ) , [ " m " , " Ä " , " ħ " , " ka " ] )
227+ XCTAssertEqual ( tokenizer. encode ( text: " mąka " ) , [ 0 , 119 , 649 , 5782 , 2348 , 2 ] )
228+
229+ XCTAssertEqual ( tokenizer. tokenize ( text: " département " ) , [ " d " , " é " , " part " , " ement " ] )
230+ XCTAssertEqual ( tokenizer. encode ( text: " département " ) , [ 0 , 417 , 1140 , 7755 , 6285 , 2 ] )
231+
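+        // "Ġ" marks a token preceded by a space; with no leading space the first word gets no marker.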
232+ XCTAssertEqual ( tokenizer. tokenize ( text: " Who are you? " ) , [ " Who " , " Ġare " , " Ġyou " , " ? " ] )
233+ XCTAssertEqual ( tokenizer. encode ( text: " Who are you? " ) , [ 0 , 12375 , 32 , 47 , 116 , 2 ] )
234+
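+        // Leading/trailing spaces are preserved: the first word becomes "ĠWho" and a lone "Ġ" (id 1437) trails.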
235+ XCTAssertEqual ( tokenizer. tokenize ( text: " Who are you? " ) , [ " ĠWho " , " Ġare " , " Ġyou " , " ? " , " Ġ " ] )
236+ XCTAssertEqual ( tokenizer. encode ( text: " Who are you? " ) , [ 0 , 3394 , 32 , 47 , 116 , 1437 , 2 ] )
237+
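+        // Special tokens typed in the input are kept as single tokens, in addition to the <s>/</s> added automatically (hence the repeated 0 and 2).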
238+ XCTAssertEqual ( tokenizer. tokenize ( text: " <s>Who are you?</s> " ) , [ " <s> " , " Who " , " Ġare " , " Ġyou " , " ? " , " </s> " ] )
239+ XCTAssertEqual ( tokenizer. encode ( text: " <s>Who are you?</s> " ) , [ 0 , 0 , 12375 , 32 , 47 , 116 , 2 , 2 ] )
240+ }
241+ }
242+
215243struct EncodedTokenizerSamplesDataset : Decodable {
216244 let text : String
217245 // Bad naming, not just for bpe.
@@ -239,16 +267,16 @@ struct EncodedData: Decodable {
 class TokenizerTester {
     let encodedSamplesFilename: String
     let unknownTokenId: Int?
-
+
     private var configuration: LanguageModelConfigurationFromHub?
     private var edgeCases: [EdgeCase]?
     private var _tokenizer: Tokenizer?
-
+
     init(hubModelName: String, encodedSamplesFilename: String, unknownTokenId: Int?, hubApi: HubApi) {
         configuration = LanguageModelConfigurationFromHub(modelName: hubModelName, hubApi: hubApi)
         self.encodedSamplesFilename = encodedSamplesFilename
         self.unknownTokenId = unknownTokenId
-
+
         // Read the edge cases dataset
         edgeCases = {
             let url = Bundle.module.url(forResource: "tokenizer_tests", withExtension: "json")!
@@ -259,15 +287,15 @@ class TokenizerTester {
             return cases[hubModelName]
         }()
     }
-
+
     lazy var dataset: EncodedTokenizerSamplesDataset = {
         let url = Bundle.module.url(forResource: encodedSamplesFilename, withExtension: "json")!
         let json = try! Data(contentsOf: url)
         let decoder = JSONDecoder()
         let dataset = try! decoder.decode(EncodedTokenizerSamplesDataset.self, from: json)
         return dataset
     }()
-
+
     var tokenizer: Tokenizer? {
         get async {
             guard _tokenizer == nil else { return _tokenizer! }
@@ -283,39 +311,39 @@ class TokenizerTester {
             return _tokenizer
         }
     }
-
+
     var tokenizerModel: TokenizingModel? {
         get async {
             // The model is not usually accessible; maybe it should
             guard let tokenizer = await tokenizer else { return nil }
             return (tokenizer as! PreTrainedTokenizer).model
         }
     }
-
+
     func testTokenize() async {
         let tokenized = await tokenizer?.tokenize(text: dataset.text)
         XCTAssertEqual(
             tokenized,
             dataset.bpe_tokens
         )
     }
-
+
     func testEncode() async {
         let encoded = await tokenizer?.encode(text: dataset.text)
         XCTAssertEqual(
             encoded,
             dataset.token_ids
         )
     }
-
+
     func testDecode() async {
         let decoded = await tokenizer?.decode(tokens: dataset.token_ids)
         XCTAssertEqual(
             decoded,
             dataset.decoded_text
         )
     }
-
+
     /// Test encode and decode for a few edge cases
     func testEdgeCases() async {
         guard let edgeCases else {
@@ -339,7 +367,7 @@ class TokenizerTester {
             )
         }
     }
-
+
     func testUnknownToken() async {
         guard let model = await tokenizerModel else { return }
         XCTAssertEqual(model.unknownTokenId, unknownTokenId)
@@ -361,10 +389,10 @@ class TokenizerTester {
 class TokenizerTests: XCTestCase {
     /// Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem
     static var _tester: TokenizerTester? = nil
-
+
     class var hubModelName: String? { nil }
     class var encodedSamplesFilename: String? { nil }
-
+
     /// Known id retrieved from Python, to verify it was parsed correctly
     class var unknownTokenId: Int? { nil }
 
@@ -399,25 +427,25 @@ class TokenizerTests: XCTestCase {
             await tester.testTokenize()
         }
     }
-
+
     func testEncode() async {
         if let tester = Self._tester {
             await tester.testEncode()
         }
     }
-
+
     func testDecode() async {
         if let tester = Self._tester {
             await tester.testDecode()
         }
     }
-
+
     func testEdgeCases() async {
         if let tester = Self._tester {
             await tester.testEdgeCases()
         }
     }
-
+
     func testUnknownToken() async {
         if let tester = Self._tester {
             await tester.testUnknownToken()