13
13
DATABASE_URLS = [url .strip () for url in os .environ ["TEST_DATABASE_URLS" ].split ("," )]
14
14
15
15
16
+ class MyEpochType (sqlalchemy .types .TypeDecorator ):
17
+ impl = sqlalchemy .Integer
18
+
19
+ epoch = datetime .date (1970 , 1 , 1 )
20
+
21
+ def process_bind_param (self , value , dialect ):
22
+ return (value - self .epoch ).days
23
+
24
+ def process_result_value (self , value , dialect ):
25
+ return self .epoch + datetime .timedelta (days = value )
26
+
27
+
16
28
metadata = sqlalchemy .MetaData ()
17
29
18
30
notes = sqlalchemy .Table (
23
35
sqlalchemy .Column ("completed" , sqlalchemy .Boolean ),
24
36
)
25
37
38
+ # Used to test DateTime
26
39
articles = sqlalchemy .Table (
27
40
"articles" ,
28
41
metadata ,
31
44
sqlalchemy .Column ("published" , sqlalchemy .DateTime ),
32
45
)
33
46
47
+ # Used to test JSON
34
48
session = sqlalchemy .Table (
35
49
"session" ,
36
50
metadata ,
37
51
sqlalchemy .Column ("id" , sqlalchemy .Integer , primary_key = True ),
38
52
sqlalchemy .Column ("data" , sqlalchemy .JSON ),
39
53
)
40
54
55
+ # Used to test custom column types
56
+ custom_date = sqlalchemy .Table (
57
+ "custom_date" ,
58
+ metadata ,
59
+ sqlalchemy .Column ("id" , sqlalchemy .Integer , primary_key = True ),
60
+ sqlalchemy .Column ("title" , sqlalchemy .String (length = 100 )),
61
+ sqlalchemy .Column ("published" , MyEpochType ),
62
+ )
63
+
41
64
42
65
@pytest .fixture (autouse = True , scope = "module" )
43
66
def create_test_database ():
@@ -226,7 +249,7 @@ async def test_transaction_rollback_low_level(database_url):
226
249
@async_adapter
227
250
async def test_datetime_field (database_url ):
228
251
"""
229
- Test DataTime fields , to ensure records are coerced to proper Python types.
252
+ Test DataTime columns , to ensure records are coerced to/from proper Python types.
230
253
"""
231
254
232
255
async with Database (database_url ) as database :
@@ -250,7 +273,7 @@ async def test_datetime_field(database_url):
250
273
@async_adapter
251
274
async def test_json_field (database_url ):
252
275
"""
253
- Test JSON fields , to ensure correct cross-database support.
276
+ Test JSON columns , to ensure correct cross-database support.
254
277
"""
255
278
256
279
async with Database (database_url ) as database :
@@ -265,3 +288,27 @@ async def test_json_field(database_url):
265
288
results = await database .fetch_all (query = query )
266
289
assert len (results ) == 1
267
290
assert results [0 ]["data" ] == {"text" : "hello" , "boolean" : True , "int" : 1 }
291
+
292
+
293
+ @pytest .mark .parametrize ("database_url" , DATABASE_URLS )
294
+ @async_adapter
295
+ async def test_custom_field (database_url ):
296
+ """
297
+ Test custom column types.
298
+ """
299
+
300
+ async with Database (database_url ) as database :
301
+ async with database .transaction (force_rollback = True ):
302
+ today = datetime .date .today ()
303
+
304
+ # execute()
305
+ query = custom_date .insert ()
306
+ values = {"title" : "Hello, world" , "published" : today }
307
+ await database .execute (query , values )
308
+
309
+ # fetch_all()
310
+ query = custom_date .select ()
311
+ results = await database .fetch_all (query = query )
312
+ assert len (results ) == 1
313
+ assert results [0 ]["title" ] == "Hello, world"
314
+ assert results [0 ]["published" ] == today
0 commit comments