Skip to content

Commit 155f197

Browse files
committed
Handle escape characters correctly
1 parent 53e0b4f commit 155f197

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

postgresql_csv_loader/csv_loader.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ class CsvLoader(object):
1515

1616
DEFAULT_DELIMITER = ','
1717
DEFAULT_QUOTE_CHAR = '"'
18-
DEFAULT_ESCAPE_CHAR = '"'
18+
DEFAULT_ESCAPE_CHAR = None
1919
DEFAULT_TABLE_PREFIX = "csv_"
20+
DEFAULT_DOUBLE_QUOTE = True
2021
DEFAULT_DATA_TYPE = "varchar"
2122

2223
CREATE_STMT = "CREATE TABLE {} ({});"
@@ -59,10 +60,13 @@ def load_data(self, file_path, delimiter=DEFAULT_DELIMITER, quote_char=DEFAULT_Q
5960
:param create_table: if True, table will be created
6061
:param encoding: file encoding
6162
"""
63+
# doublequote=True by default
64+
# don't define escape char if it's the same as quote char
65+
escape_char = None if (escape_char == quote_char) else escape_char
66+
6267
original_headers = self._read_headers(file_path, delimiter, quote_char, escape_char, encoding)
6368
headers = self._normalize_headers(original_headers)
6469
table_name = self._generate_table_name(file_path)
65-
6670
logging.getLogger('CsvLoader').info('Connecting to database "{}"...'.format(self._database_name))
6771
connection = connect(dbname=self._database_name, user=self._user, password= self._password,
6872
host=self._database_host, port=self._database_port)
@@ -148,8 +152,10 @@ def _copy_from_csv(self, connection, file_path, table_name, headers, delimiter,
148152
"""
149153
columns = ['"{}"'.format(column) for column in headers]
150154
columns_def = ",".join(columns)
155+
156+
copy_from_escape_char = escape_char or quote_char # use quote if escape is None
151157
command = self.COPY_STMT.format(table_name, columns_def, delimiter,
152-
quote_char, escape_char)
158+
quote_char, copy_from_escape_char)
153159
# https://www.postgresql.org/docs/current/static/sql-copy.html
154160

155161
cursor = connection.cursor()

test/resources/quoted_headers.csv

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"Respondent","Professional","Country"
2+
"1","Student","United States ""America"""
3+
2,Student,United Kingdom
4+
3,Professional developer,United Kingdom
5+
4,Professional developer,United States
6+
5,Professional developer,Switzerland

test/test_csv_loader.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,19 @@ class TestCsvLoader(unittest.TestCase):
1919
CSV_FILENAME_3 = "resources/illegal_column_names.csv"
2020
CSV_FILENAME_4 = "resources/polish_characters.csv"
2121
CSV_FILENAME_5 = "resources/weird_format.csv"
22+
CSV_FILENAME_6 = "resources/quoted_headers.csv"
2223
CSV_1_RECORD_COUNT = 30
2324
CSV_2_RECORD_COUNT = 5
2425
CSV_3_RECORD_COUNT = 5
2526
CSV_4_RECORD_COUNT = 1
2627
CSV_5_RECORD_COUNT = 1
28+
CSV_6_RECORD_COUNT = 5
2729
TABLE_NAME_1 = "csv_stackoverflow_survey_results_public_sample"
2830
TABLE_NAME_2 = "csv_simple_table"
2931
TABLE_NAME_3 = "csv_illegal_column_names"
3032
TABLE_NAME_4 = "csv_polish_characters"
3133
TABLE_NAME_5 = "csv_weird_format"
34+
TABLE_NAME_6 = "csv_quoted_headers"
3235

3336
SELECT_COUNT_STMT = "SELECT count(*) from {};"
3437
DROP_STMT = "DROP TABLE {};"
@@ -135,6 +138,15 @@ def test_load_weird_format(self):
135138
self._drop(self.TABLE_NAME_5)
136139
self.assertEqual(result, self.CSV_5_RECORD_COUNT)
137140

141+
def test_headers_quotes(self):
142+
loader = self._get_loader()
143+
loader.load_data(self.CSV_FILENAME_6)
144+
145+
result = self._check_count(self.TABLE_NAME_6)
146+
self._drop(self.TABLE_NAME_6)
147+
148+
self.assertEqual(result, self.CSV_6_RECORD_COUNT)
149+
138150
def _get_loader(self):
139151
return CsvLoader(self.database_host, self.database_port, self.database_name, self.database_user)
140152

0 commit comments

Comments
 (0)