Skip to content

Commit 55bcb5d

Browse files
RockingJavaBean, facebook-github-bot
authored and committed Jun 17, 2020
Fix inconsistent results of string split func on JIT mode (pytorch#38772)
Summary: Resolves pytorch#38207. Below is the description of the split function according to the [Python doc](https://docs.python.org/3.8/library/stdtypes.html?highlight=split#str.split). ``` If sep is not specified or is None, a different splitting algorithm is applied: runs of consecutive whitespace are regarded as a single separator, and the result will contain no empty strings at the start or end if the string has leading or trailing whitespace. ``` The fix adds logic to register_string_ops.cpp that handles both a None separator and an empty separator. Signed-off-by: Xiong Wei <[email protected]> Pull Request resolved: pytorch#38772 Differential Revision: D21789612 Pulled By: suo fbshipit-source-id: 4dfd74eda71e0bfd757378daedc927a4a63ec0e4
1 parent 5e77999 commit 55bcb5d

File tree

4 files changed

+61
-5
lines changed

4 files changed

+61
-5
lines changed
 

‎aten/src/ATen/templates/TypeDefault.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ TORCH_LIBRARY(aten, m) {
6262
m.def("replace(str self, str old, str new, int max=-1) -> str");
6363
m.def("partition(str self, str separator) -> (str, str, str)");
6464
m.def("rpartition(str self, str separator) -> (str, str, str)");
65-
m.def("split.str(str self, str separator=' ', int max=-1) -> str[]");
65+
m.def("split.str(str self, str? separator=None, int max=-1) -> str[]");
6666
m.def("rsplit(str self, str separator=' ', int max=-1) -> str[]");
6767
m.def("join(str self, str[] values) -> str");
6868

‎test/backward_compatibility/check_backward_compatibility.py

+1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
('aten::__and__', datetime.date(2020, 6, 30)),
102102
('aten::__or__', datetime.date(2020, 6, 30)),
103103
('aten::__xor__', datetime.date(2020, 6, 30)),
104+
('aten::split', datetime.date(2020, 6, 30)),
104105
]
105106

106107

‎test/test_jit_string.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,14 @@ def test_rpartition():
263263
self.checkScript(test_rpartition, ())
264264

265265
def test_split():
266-
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
266+
"""
267+
type: () -> Tuple[List[str], List[str], List[str], List[str], List[str],
268+
List[str], List[str], List[str], List[str], List[str], List[str]]
269+
"""
267270
return (
268271
"a a a a a".split(),
272+
"a a a a a".split(),
273+
" a a\ta \v a \v\f\n a \t ".split(),
269274
" a a a a a ".split(" "),
270275
"a a a a a ".split(" ", 10),
271276
"a a a a a ".split(" ", -1),
@@ -277,6 +282,14 @@ def test_split():
277282
)
278283
self.checkScript(test_split, ())
279284

285+
# test raising error for empty separator
286+
def test_split_empty_separator():
287+
s = "test"
288+
return s.split("")
289+
290+
self.checkScriptRaisesRegex(test_split_empty_separator, (), Exception,
291+
"empty separator")
292+
280293
def test_rsplit():
281294
# type: () -> Tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str], List[str]]
282295
return (

‎torch/csrc/jit/runtime/register_string_ops.cpp

+45-3
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,33 @@ RegisterOperators reg_str_ops({
129129

130130
});
131131

132+
// consecutive whitespace are regarded as a single separator,
133+
// the result will contain no empty strings at the start or end
134+
// if the string has leading or trailing whitespace.
135+
c10::List<std::string> splitNoneSeparator(const std::string& string) {
136+
c10::List<std::string> splits;
137+
// whitespaces includes tab, space and
138+
// the delimiters defined in the implementation of splitlines
139+
std::string whitespaces =
140+
" \t\n\r\r\n\v\x0b\f\x0c\x1c\x1d\x1e\x85\u2028\u2029";
141+
std::string::size_type prev_pos = 0;
142+
std::string::size_type pos = 0;
143+
144+
while ((pos = string.find_first_of(whitespaces, pos)) != std::string::npos) {
145+
auto substr = string.substr(prev_pos, pos - prev_pos);
146+
// skip the whitespaces as the Python split() method
147+
if (!substr.empty()) {
148+
splits.emplace_back(substr);
149+
}
150+
pos++;
151+
prev_pos = pos;
152+
}
153+
if (prev_pos != string.size()) {
154+
splits.emplace_back(string.substr(prev_pos));
155+
}
156+
return splits;
157+
}
158+
132159
// String Ops
133160
// Implementations located in torch/csrc/jit/runtime/register_string_ops.cpp
134161
TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
@@ -546,19 +573,34 @@ TORCH_LIBRARY_IMPL(aten, CatchAll, m) {
546573
});
547574

548575
m.impl(
549-
"split.str", [](std::string string, std::string separator, int64_t max) {
576+
"split.str",
577+
[](const std::string& string,
578+
c10::optional<std::string> separator,
579+
int64_t max) {
580+
if (!separator.has_value()) {
581+
// if separator is not specified,
582+
// a different splitting algorithm is applied as Python
583+
return splitNoneSeparator(string);
584+
;
585+
}
586+
if (separator.value().empty()) {
587+
throw std::runtime_error("ValueError: empty separator");
588+
}
589+
550590
std::string::size_type prev_pos = 0;
551591
std::string::size_type pos = 0;
552592
c10::List<std::string> splits;
553593
auto count = 0;
554-
while ((pos = string.find(separator, pos)) != std::string::npos) {
594+
595+
while ((pos = string.find(separator.value(), pos)) !=
596+
std::string::npos) {
555597
count++;
556598
if (max >= 0 && count > max) {
557599
break;
558600
} else {
559601
splits.emplace_back(string.substr(prev_pos, pos - prev_pos));
560602
}
561-
pos += separator.size();
603+
pos += separator.value().size();
562604
prev_pos = pos;
563605
}
564606
splits.emplace_back(string.substr(prev_pos, string.size() - prev_pos));

0 commit comments

Comments
 (0)
Please sign in to comment.