Skip to content

Commit ed43383

Browse files
authored
Introduce PyUtf8Str and fix(sqlite): validate surrogates in SQL statements (#5969)
* fix(sqlite): validate surrogates in SQL statements
* Add `PyUtf8Str` wrapper for safe conversion
1 parent fd35c7a commit ed43383

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

Lib/test/test_sqlite3/test_regression.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -343,8 +343,6 @@ def test_null_character(self):
343343
self.assertRaisesRegex(sqlite.ProgrammingError, "null char",
344344
cur.execute, query)
345345

346-
# TODO: RUSTPYTHON
347-
@unittest.expectedFailure
348346
def test_surrogates(self):
349347
con = sqlite.connect(":memory:")
350348
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")

stdlib/src/sqlite.rs

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -844,7 +844,7 @@ mod _sqlite {
844844
type Args = (PyStrRef,);
845845

846846
fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
847-
if let Some(stmt) = Statement::new(zelf, &args.0, vm)? {
847+
if let Some(stmt) = Statement::new(zelf, args.0, vm)? {
848848
Ok(stmt.into_ref(&vm.ctx).into())
849849
} else {
850850
Ok(vm.ctx.none())
@@ -1480,7 +1480,7 @@ mod _sqlite {
14801480
stmt.lock().reset();
14811481
}
14821482

1483-
let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else {
1483+
let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else {
14841484
drop(inner);
14851485
return Ok(zelf);
14861486
};
@@ -1552,7 +1552,7 @@ mod _sqlite {
15521552
stmt.lock().reset();
15531553
}
15541554

1555-
let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else {
1555+
let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else {
15561556
drop(inner);
15571557
return Ok(zelf);
15581558
};
@@ -2291,9 +2291,10 @@ mod _sqlite {
22912291
impl Statement {
22922292
fn new(
22932293
connection: &Connection,
2294-
sql: &PyStr,
2294+
sql: PyStrRef,
22952295
vm: &VirtualMachine,
22962296
) -> PyResult<Option<Self>> {
2297+
let sql = sql.try_into_utf8(vm)?;
22972298
let sql_cstr = sql.to_cstring(vm)?;
22982299
let sql_len = sql.byte_len() + 1;
22992300

vm/src/builtins/str.rs

Lines changed: 43 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,8 @@ use rustpython_common::{
3737
str::DeduceStrKind,
3838
wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk},
3939
};
40-
use std::sync::LazyLock;
4140
use std::{borrow::Cow, char, fmt, ops::Range};
41+
use std::{mem, sync::LazyLock};
4242
use unic_ucd_bidi::BidiClass;
4343
use unic_ucd_category::GeneralCategory;
4444
use unic_ucd_ident::{is_xid_continue, is_xid_start};
@@ -80,6 +80,30 @@ impl fmt::Debug for PyStr {
8080
}
8181
}
8282

83+
#[repr(transparent)]
84+
#[derive(Debug)]
85+
pub struct PyUtf8Str(PyStr);
86+
87+
// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str
88+
impl std::ops::Deref for PyUtf8Str {
89+
type Target = PyStr;
90+
fn deref(&self) -> &Self::Target {
91+
&self.0
92+
}
93+
}
94+
95+
impl PyUtf8Str {
96+
/// Returns the underlying string slice.
97+
pub fn as_str(&self) -> &str {
98+
debug_assert!(
99+
self.0.is_utf8(),
100+
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
101+
);
102+
// Safety: This is safe because the type invariant guarantees UTF-8 validity.
103+
unsafe { self.0.to_str().unwrap_unchecked() }
104+
}
105+
}
106+
83107
impl AsRef<str> for PyStr {
84108
#[track_caller] // <- can remove this once it doesn't panic
85109
fn as_ref(&self) -> &str {
@@ -433,21 +457,29 @@ impl PyStr {
433457
self.data.as_str()
434458
}
435459

436-
pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
437-
self.to_str().ok_or_else(|| {
460+
fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
461+
if self.is_utf8() {
462+
Ok(())
463+
} else {
438464
let start = self
439465
.as_wtf8()
440466
.code_points()
441467
.position(|c| c.to_char().is_none())
442468
.unwrap();
443-
vm.new_unicode_encode_error_real(
469+
Err(vm.new_unicode_encode_error_real(
444470
identifier!(vm, utf_8).to_owned(),
445471
vm.ctx.new_str(self.data.clone()),
446472
start,
447473
start + 1,
448474
vm.ctx.new_str("surrogates not allowed"),
449-
)
450-
})
475+
))
476+
}
477+
}
478+
479+
pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
480+
self.ensure_valid_utf8(vm)?;
481+
// SAFETY: ensure_valid_utf8 passed, so unwrap is safe.
482+
Ok(unsafe { self.to_str().unwrap_unchecked() })
451483
}
452484

453485
pub fn to_string_lossy(&self) -> Cow<'_, str> {
@@ -1486,6 +1518,11 @@ impl PyStrRef {
14861518
s.push_wtf8(other);
14871519
*self = PyStr::from(s).into_ref(&vm.ctx);
14881520
}
1521+
1522+
pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult<PyRef<PyUtf8Str>> {
1523+
self.ensure_valid_utf8(vm)?;
1524+
Ok(unsafe { mem::transmute::<PyRef<PyStr>, PyRef<PyUtf8Str>>(self) })
1525+
}
14891526
}
14901527

14911528
impl Representable for PyStr {

0 commit comments

Comments
 (0)