@@ -212,6 +212,34 @@ class BertSpacesTests: XCTestCase {
212
212
}
213
213
}
214
214
215
class RobertaTests: XCTestCase {
    /// Round-trips a handful of strings through the RoBERTa byte-level BPE tokenizer
    /// and compares tokens, ids, and decoded text against reference values.
    func testEncodeDecode() async throws {
        guard let tokenizer = try await AutoTokenizer.from(pretrained: "FacebookAI/roberta-base") as? PreTrainedTokenizer else {
            XCTFail()
            return
        }

        // Apostrophe splitting and decode round-trip.
        XCTAssertEqual(tokenizer.tokenize(text: "l'eure"), ["l", "'", "e", "ure"])
        XCTAssertEqual(tokenizer.encode(text: "l'eure"), [0, 462, 108, 242, 2407, 2])
        XCTAssertEqual(tokenizer.decode(tokens: tokenizer.encode(text: "l'eure"), skipSpecialTokens: true), "l'eure")

        // Non-ASCII input: multi-byte characters split into byte-level tokens.
        XCTAssertEqual(tokenizer.tokenize(text: "mąka"), ["m", "Ä", "ħ", "ka"])
        XCTAssertEqual(tokenizer.encode(text: "mąka"), [0, 119, 649, 5782, 2348, 2])

        XCTAssertEqual(tokenizer.tokenize(text: "département"), ["d", "é", "part", "ement"])
        XCTAssertEqual(tokenizer.encode(text: "département"), [0, 417, 1140, 7755, 6285, 2])

        // No leading space: first word is not prefixed with Ġ.
        XCTAssertEqual(tokenizer.tokenize(text: "Who are you?"), ["Who", "Ġare", "Ġyou", "?"])
        XCTAssertEqual(tokenizer.encode(text: "Who are you?"), [0, 12375, 32, 47, 116, 2])

        // Leading and trailing spaces are preserved as Ġ tokens.
        XCTAssertEqual(tokenizer.tokenize(text: " Who are you? "), ["ĠWho", "Ġare", "Ġyou", "?", "Ġ"])
        XCTAssertEqual(tokenizer.encode(text: " Who are you? "), [0, 3394, 32, 47, 116, 1437, 2])

        // Explicit special tokens in the text are kept, and the template adds its own pair.
        XCTAssertEqual(tokenizer.tokenize(text: "<s>Who are you?</s>"), ["<s>", "Who", "Ġare", "Ġyou", "?", "</s>"])
        XCTAssertEqual(tokenizer.encode(text: "<s>Who are you?</s>"), [0, 0, 12375, 32, 47, 116, 2, 2])
    }
}
242
+
215
243
struct EncodedTokenizerSamplesDataset : Decodable {
216
244
let text : String
217
245
// Bad naming, not just for bpe.
@@ -239,16 +267,16 @@ struct EncodedData: Decodable {
239
267
class TokenizerTester {
240
268
let encodedSamplesFilename : String
241
269
let unknownTokenId : Int ?
242
-
270
+
243
271
private var configuration : LanguageModelConfigurationFromHub ?
244
272
private var edgeCases : [ EdgeCase ] ?
245
273
private var _tokenizer : Tokenizer ?
246
-
274
+
247
275
init ( hubModelName: String , encodedSamplesFilename: String , unknownTokenId: Int ? , hubApi: HubApi ) {
248
276
configuration = LanguageModelConfigurationFromHub ( modelName: hubModelName, hubApi: hubApi)
249
277
self . encodedSamplesFilename = encodedSamplesFilename
250
278
self . unknownTokenId = unknownTokenId
251
-
279
+
252
280
// Read the edge cases dataset
253
281
edgeCases = {
254
282
let url = Bundle . module. url ( forResource: " tokenizer_tests " , withExtension: " json " ) !
@@ -259,15 +287,15 @@ class TokenizerTester {
259
287
return cases [ hubModelName]
260
288
} ( )
261
289
}
262
-
290
+
263
291
/// Reference samples decoded once from the bundled JSON resource named by `encodedSamplesFilename`.
lazy var dataset: EncodedTokenizerSamplesDataset = {
    // Force-unwrap / try! mirror the original: a missing or malformed fixture is a test-setup bug.
    let fileURL = Bundle.module.url(forResource: encodedSamplesFilename, withExtension: "json")!
    let contents = try! Data(contentsOf: fileURL)
    return try! JSONDecoder().decode(EncodedTokenizerSamplesDataset.self, from: contents)
}()
270
-
298
+
271
299
var tokenizer : Tokenizer ? {
272
300
get async {
273
301
guard _tokenizer == nil else { return _tokenizer! }
@@ -283,39 +311,39 @@ class TokenizerTester {
283
311
return _tokenizer
284
312
}
285
313
}
286
-
314
+
287
315
/// The underlying tokenizing model, for assertions on model internals.
/// The model is not usually accessible; maybe it should be.
var tokenizerModel: TokenizingModel? {
    get async {
        if let tokenizer = await tokenizer {
            // Force-cast kept from the original: in these tests the tokenizer is always a PreTrainedTokenizer.
            return (tokenizer as! PreTrainedTokenizer).model
        }
        return nil
    }
}
294
-
322
+
295
323
/// Tokenizing the sample text must reproduce the reference `bpe_tokens`.
func testTokenize() async {
    let produced = await tokenizer?.tokenize(text: dataset.text)
    XCTAssertEqual(produced, dataset.bpe_tokens)
}
302
-
330
+
303
331
/// Encoding the sample text must reproduce the reference `token_ids`.
func testEncode() async {
    let ids = await tokenizer?.encode(text: dataset.text)
    XCTAssertEqual(ids, dataset.token_ids)
}
310
-
338
+
311
339
/// Decoding the reference ids must reproduce the reference `decoded_text`.
func testDecode() async {
    let text = await tokenizer?.decode(tokens: dataset.token_ids)
    XCTAssertEqual(text, dataset.decoded_text)
}
318
-
346
+
319
347
/// Test encode and decode for a few edge cases
320
348
func testEdgeCases( ) async {
321
349
guard let edgeCases else {
@@ -339,7 +367,7 @@ class TokenizerTester {
339
367
)
340
368
}
341
369
}
342
-
370
+
343
371
func testUnknownToken( ) async {
344
372
guard let model = await tokenizerModel else { return }
345
373
XCTAssertEqual ( model. unknownTokenId, unknownTokenId)
@@ -361,10 +389,10 @@ class TokenizerTester {
361
389
class TokenizerTests : XCTestCase {
362
390
/// Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem
363
391
static var _tester : TokenizerTester ? = nil
364
-
392
+
365
393
class var hubModelName : String ? { nil }
366
394
class var encodedSamplesFilename : String ? { nil }
367
-
395
+
368
396
/// Known id retrieved from Python, to verify it was parsed correctly
369
397
class var unknownTokenId : Int ? { nil }
370
398
@@ -399,25 +427,25 @@ class TokenizerTests: XCTestCase {
399
427
await tester. testTokenize ( )
400
428
}
401
429
}
402
-
430
+
403
431
/// Delegates to the shared tester; a nil tester means this base class ran directly, which is a no-op.
func testEncode() async {
    guard let tester = Self._tester else { return }
    await tester.testEncode()
}
408
-
436
+
409
437
/// Delegates to the shared tester; a nil tester means this base class ran directly, which is a no-op.
func testDecode() async {
    guard let tester = Self._tester else { return }
    await tester.testDecode()
}
414
-
442
+
415
443
/// Delegates to the shared tester; a nil tester means this base class ran directly, which is a no-op.
func testEdgeCases() async {
    guard let tester = Self._tester else { return }
    await tester.testEdgeCases()
}
420
-
448
+
421
449
func testUnknownToken( ) async {
422
450
if let tester = Self . _tester {
423
451
await tester. testUnknownToken ( )
0 commit comments