Skip to content

Commit 41f26a3

Browse files
authored
Added RobertaTokenizer (#188)
* Added RobertaTokenizer
* Made parsing more restrictive, fixed tests
1 parent: b71fb0f · commit: 41f26a3

File tree

5 files changed

+82
-31
lines changed

5 files changed

+82
-31
lines changed

Sources/Hub/Hub.swift

+9-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,15 @@ public struct Config {
133133
}
134134

135135
/// Tuple of token identifier and string value
136-
public var tokenValue: (UInt, String)? { value as? (UInt, String) }
136+
public var tokenValue: (UInt, String)? {
137+
guard let value = value as? [Any] else {
138+
return nil
139+
}
140+
guard let stringValue = value.first as? String, let intValue = value.dropFirst().first as? UInt else {
141+
return nil
142+
}
143+
return (intValue, stringValue)
144+
}
137145
}
138146

139147
public class LanguageModelConfigurationFromHub {

Sources/Tokenizers/Tokenizer.swift

+2-1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ struct TokenizerModel {
101101
"BertTokenizer": BertTokenizer.self,
102102
"DistilbertTokenizer": BertTokenizer.self,
103103
"DistilBertTokenizer": BertTokenizer.self,
104+
"RobertaTokenizer": BPETokenizer.self,
104105
"CodeGenTokenizer": CodeGenTokenizer.self,
105106
"CodeLlamaTokenizer": CodeLlamaTokenizer.self,
106107
"FalconTokenizer": FalconTokenizer.self,
@@ -230,7 +231,7 @@ public extension Tokenizer {
230231
func callAsFunction(_ text: String, addSpecialTokens: Bool = true) -> [Int] {
231232
encode(text: text, addSpecialTokens: addSpecialTokens)
232233
}
233-
234+
234235
func decode(tokens: [Int]) -> String {
235236
decode(tokens: tokens, skipSpecialTokens: false)
236237
}

Tests/HubTests/HubTests.swift

+14
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,18 @@ class HubTests: XCTestCase {
117117
let vocab_dict = config.dictionary["vocab"] as! [String: Int]
118118
XCTAssertNotEqual(vocab_dict.count, 2)
119119
}
120+
121+
func testConfigTokenValue() throws {
122+
let config1 = Config(["cls": ["str" as String, 100 as UInt] as [Any]])
123+
let tokenValue1 = config1.cls?.tokenValue
124+
XCTAssertEqual(tokenValue1?.0, 100)
125+
XCTAssertEqual(tokenValue1?.1, "str")
126+
127+
let data = #"{"cls": ["str", 100]}"#.data(using: .utf8)!
128+
let dict = try JSONSerialization.jsonObject(with: data, options: []) as! [NSString: Any]
129+
let config2 = Config(dict)
130+
let tokenValue2 = config2.cls?.tokenValue
131+
XCTAssertEqual(tokenValue2?.0, 100)
132+
XCTAssertEqual(tokenValue2?.1, "str")
133+
}
120134
}

Tests/PostProcessorTests/PostProcessorTests.swift

+12-12
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ class PostProcessorTests: XCTestCase {
77
let testCases: [(Config, [String], [String]?, [String])] = [
88
// Should keep spaces; uneven spaces; ignore `addPrefixSpace`.
99
(
10-
Config(["cls": (0, "[HEAD]") as (UInt, String),
11-
"sep": (0, "[END]") as (UInt, String),
10+
Config(["cls": ["[HEAD]", 0 as UInt],
11+
"sep": ["[END]", 0 as UInt],
1212
"trimOffset": false,
1313
"addPrefixSpace": true]),
1414
[" The", " sun", "sets ", " in ", " the ", "west"],
@@ -17,8 +17,8 @@ class PostProcessorTests: XCTestCase {
1717
),
1818
// Should leave only one space around each token.
1919
(
20-
Config(["cls": (0, "[START]") as (UInt, String),
21-
"sep": (0, "[BREAK]") as (UInt, String),
20+
Config(["cls": ["[START]", 0 as UInt],
21+
"sep": ["[BREAK]", 0 as UInt],
2222
"trimOffset": true,
2323
"addPrefixSpace": true]),
2424
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -27,8 +27,8 @@ class PostProcessorTests: XCTestCase {
2727
),
2828
// Should ignore empty tokens pair.
2929
(
30-
Config(["cls": (0, "[START]") as (UInt, String),
31-
"sep": (0, "[BREAK]") as (UInt, String),
30+
Config(["cls": ["[START]", 0 as UInt],
31+
"sep": ["[BREAK]", 0 as UInt],
3232
"trimOffset": true,
3333
"addPrefixSpace": true]),
3434
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -37,8 +37,8 @@ class PostProcessorTests: XCTestCase {
3737
),
3838
// Should trim all whitespace.
3939
(
40-
Config(["cls": (0, "[CLS]") as (UInt, String),
41-
"sep": (0, "[SEP]") as (UInt, String),
40+
Config(["cls": ["[CLS]", 0 as UInt],
41+
"sep": ["[SEP]", 0 as UInt],
4242
"trimOffset": true,
4343
"addPrefixSpace": false]),
4444
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -47,8 +47,8 @@ class PostProcessorTests: XCTestCase {
4747
),
4848
// Should add tokens.
4949
(
50-
Config(["cls": (0, "[CLS]") as (UInt, String),
51-
"sep": (0, "[SEP]") as (UInt, String),
50+
Config(["cls": ["[CLS]", 0 as UInt],
51+
"sep": ["[SEP]", 0 as UInt],
5252
"trimOffset": true,
5353
"addPrefixSpace": true]),
5454
[" The ", " sun", "sets ", " in ", " the ", "west"],
@@ -58,8 +58,8 @@ class PostProcessorTests: XCTestCase {
5858
"mat", "[SEP]"]
5959
),
6060
(
61-
Config(["cls": (0, "[CLS]") as (UInt, String),
62-
"sep": (0, "[SEP]") as (UInt, String),
61+
Config(["cls": ["[CLS]", 0 as UInt],
62+
"sep": ["[SEP]", 0 as UInt],
6363
"trimOffset": true,
6464
"addPrefixSpace": true]),
6565
["", "", ","],

Tests/TokenizersTests/TokenizerTests.swift

+45-17
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,34 @@ class BertSpacesTests: XCTestCase {
212212
}
213213
}
214214

215+
class RobertaTests: XCTestCase {
216+
func testEncodeDecode() async throws {
217+
guard let tokenizer = try await AutoTokenizer.from(pretrained: "FacebookAI/roberta-base") as? PreTrainedTokenizer else {
218+
XCTFail()
219+
return
220+
}
221+
222+
XCTAssertEqual(tokenizer.tokenize(text: "l'eure"), ["l", "'", "e", "ure"])
223+
XCTAssertEqual(tokenizer.encode(text: "l'eure"), [0, 462, 108, 242, 2407, 2])
224+
XCTAssertEqual(tokenizer.decode(tokens: tokenizer.encode(text: "l'eure"), skipSpecialTokens: true), "l'eure")
225+
226+
XCTAssertEqual(tokenizer.tokenize(text: "mąka"), ["m", "Ä", "ħ", "ka"])
227+
XCTAssertEqual(tokenizer.encode(text: "mąka"), [0, 119, 649, 5782, 2348, 2])
228+
229+
XCTAssertEqual(tokenizer.tokenize(text: "département"), ["d", "é", "part", "ement"])
230+
XCTAssertEqual(tokenizer.encode(text: "département"), [0, 417, 1140, 7755, 6285, 2])
231+
232+
XCTAssertEqual(tokenizer.tokenize(text: "Who are you?"), ["Who", "Ġare", "Ġyou", "?"])
233+
XCTAssertEqual(tokenizer.encode(text: "Who are you?"), [0, 12375, 32, 47, 116, 2])
234+
235+
XCTAssertEqual(tokenizer.tokenize(text: " Who are you? "), ["ĠWho", "Ġare", "Ġyou", "?", "Ġ"])
236+
XCTAssertEqual(tokenizer.encode(text: " Who are you? "), [0, 3394, 32, 47, 116, 1437, 2])
237+
238+
XCTAssertEqual(tokenizer.tokenize(text: "<s>Who are you?</s>"), ["<s>", "Who", "Ġare", "Ġyou", "?", "</s>"])
239+
XCTAssertEqual(tokenizer.encode(text: "<s>Who are you?</s>"), [0, 0, 12375, 32, 47, 116, 2, 2])
240+
}
241+
}
242+
215243
struct EncodedTokenizerSamplesDataset: Decodable {
216244
let text: String
217245
// Bad naming, not just for bpe.
@@ -239,16 +267,16 @@ struct EncodedData: Decodable {
239267
class TokenizerTester {
240268
let encodedSamplesFilename: String
241269
let unknownTokenId: Int?
242-
270+
243271
private var configuration: LanguageModelConfigurationFromHub?
244272
private var edgeCases: [EdgeCase]?
245273
private var _tokenizer: Tokenizer?
246-
274+
247275
init(hubModelName: String, encodedSamplesFilename: String, unknownTokenId: Int?, hubApi: HubApi) {
248276
configuration = LanguageModelConfigurationFromHub(modelName: hubModelName, hubApi: hubApi)
249277
self.encodedSamplesFilename = encodedSamplesFilename
250278
self.unknownTokenId = unknownTokenId
251-
279+
252280
// Read the edge cases dataset
253281
edgeCases = {
254282
let url = Bundle.module.url(forResource: "tokenizer_tests", withExtension: "json")!
@@ -259,15 +287,15 @@ class TokenizerTester {
259287
return cases[hubModelName]
260288
}()
261289
}
262-
290+
263291
lazy var dataset: EncodedTokenizerSamplesDataset = {
264292
let url = Bundle.module.url(forResource: encodedSamplesFilename, withExtension: "json")!
265293
let json = try! Data(contentsOf: url)
266294
let decoder = JSONDecoder()
267295
let dataset = try! decoder.decode(EncodedTokenizerSamplesDataset.self, from: json)
268296
return dataset
269297
}()
270-
298+
271299
var tokenizer: Tokenizer? {
272300
get async {
273301
guard _tokenizer == nil else { return _tokenizer! }
@@ -283,39 +311,39 @@ class TokenizerTester {
283311
return _tokenizer
284312
}
285313
}
286-
314+
287315
var tokenizerModel: TokenizingModel? {
288316
get async {
289317
// The model is not usually accessible; maybe it should
290318
guard let tokenizer = await tokenizer else { return nil }
291319
return (tokenizer as! PreTrainedTokenizer).model
292320
}
293321
}
294-
322+
295323
func testTokenize() async {
296324
let tokenized = await tokenizer?.tokenize(text: dataset.text)
297325
XCTAssertEqual(
298326
tokenized,
299327
dataset.bpe_tokens
300328
)
301329
}
302-
330+
303331
func testEncode() async {
304332
let encoded = await tokenizer?.encode(text: dataset.text)
305333
XCTAssertEqual(
306334
encoded,
307335
dataset.token_ids
308336
)
309337
}
310-
338+
311339
func testDecode() async {
312340
let decoded = await tokenizer?.decode(tokens: dataset.token_ids)
313341
XCTAssertEqual(
314342
decoded,
315343
dataset.decoded_text
316344
)
317345
}
318-
346+
319347
/// Test encode and decode for a few edge cases
320348
func testEdgeCases() async {
321349
guard let edgeCases else {
@@ -339,7 +367,7 @@ class TokenizerTester {
339367
)
340368
}
341369
}
342-
370+
343371
func testUnknownToken() async {
344372
guard let model = await tokenizerModel else { return }
345373
XCTAssertEqual(model.unknownTokenId, unknownTokenId)
@@ -361,10 +389,10 @@ class TokenizerTester {
361389
class TokenizerTests: XCTestCase {
362390
/// Parallel testing in Xcode (when enabled) uses different processes, so this shouldn't be a problem
363391
static var _tester: TokenizerTester? = nil
364-
392+
365393
class var hubModelName: String? { nil }
366394
class var encodedSamplesFilename: String? { nil }
367-
395+
368396
/// Known id retrieved from Python, to verify it was parsed correctly
369397
class var unknownTokenId: Int? { nil }
370398

@@ -399,25 +427,25 @@ class TokenizerTests: XCTestCase {
399427
await tester.testTokenize()
400428
}
401429
}
402-
430+
403431
func testEncode() async {
404432
if let tester = Self._tester {
405433
await tester.testEncode()
406434
}
407435
}
408-
436+
409437
func testDecode() async {
410438
if let tester = Self._tester {
411439
await tester.testDecode()
412440
}
413441
}
414-
442+
415443
func testEdgeCases() async {
416444
if let tester = Self._tester {
417445
await tester.testEdgeCases()
418446
}
419447
}
420-
448+
421449
func testUnknownToken() async {
422450
if let tester = Self._tester {
423451
await tester.testUnknownToken()

0 commit comments

Comments (0)