19
19
20
20
use std:: any:: Any ;
21
21
use std:: collections:: HashSet ;
22
- use std:: fmt;
23
- use std:: fmt:: Debug ;
22
+ use std:: fmt:: { self , Debug } ;
24
23
use std:: sync:: Arc ;
25
24
26
- use arrow_array:: RecordBatch ;
27
- use datafusion_common:: { exec_err, not_impl_err, DataFusionError , FileType } ;
28
- use datafusion_execution:: TaskContext ;
29
- use datafusion_physical_expr:: { PhysicalExpr , PhysicalSortRequirement } ;
30
-
31
- use bytes:: { Buf , Bytes } ;
32
- use datafusion_physical_plan:: metrics:: MetricsSet ;
33
- use futures:: stream:: BoxStream ;
34
- use futures:: { pin_mut, Stream , StreamExt , TryStreamExt } ;
35
- use object_store:: { delimited:: newline_delimited_stream, ObjectMeta , ObjectStore } ;
36
-
37
25
use super :: write:: orchestration:: stateless_multipart_put;
38
26
use super :: { FileFormat , DEFAULT_SCHEMA_INFER_MAX_RECORD } ;
39
27
use crate :: datasource:: file_format:: file_compression_type:: FileCompressionType ;
@@ -47,11 +35,20 @@ use crate::physical_plan::insert::{DataSink, FileSinkExec};
47
35
use crate :: physical_plan:: { DisplayAs , DisplayFormatType , Statistics } ;
48
36
use crate :: physical_plan:: { ExecutionPlan , SendableRecordBatchStream } ;
49
37
38
+ use arrow:: array:: RecordBatch ;
50
39
use arrow:: csv:: WriterBuilder ;
51
40
use arrow:: datatypes:: { DataType , Field , Fields , Schema } ;
52
41
use arrow:: { self , datatypes:: SchemaRef } ;
42
+ use datafusion_common:: { exec_err, not_impl_err, DataFusionError , FileType } ;
43
+ use datafusion_execution:: TaskContext ;
44
+ use datafusion_physical_expr:: { PhysicalExpr , PhysicalSortRequirement } ;
45
+ use datafusion_physical_plan:: metrics:: MetricsSet ;
53
46
54
47
use async_trait:: async_trait;
48
+ use bytes:: { Buf , Bytes } ;
49
+ use futures:: stream:: BoxStream ;
50
+ use futures:: { pin_mut, Stream , StreamExt , TryStreamExt } ;
51
+ use object_store:: { delimited:: newline_delimited_stream, ObjectMeta , ObjectStore } ;
55
52
56
53
/// Character Separated Value `FileFormat` implementation.
57
54
#[ derive( Debug ) ]
@@ -400,8 +397,6 @@ impl Default for CsvSerializer {
400
397
pub struct CsvSerializer {
401
398
// CSV writer builder
402
399
builder : WriterBuilder ,
403
- // Inner buffer for avoiding reallocation
404
- buffer : Vec < u8 > ,
405
400
// Flag to indicate whether there will be a header
406
401
header : bool ,
407
402
}
@@ -412,7 +407,6 @@ impl CsvSerializer {
412
407
Self {
413
408
builder : WriterBuilder :: new ( ) ,
414
409
header : true ,
415
- buffer : Vec :: with_capacity ( 4096 ) ,
416
410
}
417
411
}
418
412
@@ -431,21 +425,14 @@ impl CsvSerializer {
431
425
432
426
#[ async_trait]
433
427
impl BatchSerializer for CsvSerializer {
434
- async fn serialize ( & mut self , batch : RecordBatch ) -> Result < Bytes > {
428
+ async fn serialize ( & self , batch : RecordBatch , initial : bool ) -> Result < Bytes > {
429
+ let mut buffer = Vec :: with_capacity ( 4096 ) ;
435
430
let builder = self . builder . clone ( ) ;
436
- let mut writer = builder. with_header ( self . header ) . build ( & mut self . buffer ) ;
431
+ let header = self . header && initial;
432
+ let mut writer = builder. with_header ( header) . build ( & mut buffer) ;
437
433
writer. write ( & batch) ?;
438
434
drop ( writer) ;
439
- self . header = false ;
440
- Ok ( Bytes :: from ( self . buffer . drain ( ..) . collect :: < Vec < u8 > > ( ) ) )
441
- }
442
-
443
- fn duplicate ( & mut self ) -> Result < Box < dyn BatchSerializer > > {
444
- let new_self = CsvSerializer :: new ( )
445
- . with_builder ( self . builder . clone ( ) )
446
- . with_header ( self . header ) ;
447
- self . header = false ;
448
- Ok ( Box :: new ( new_self) )
435
+ Ok ( Bytes :: from ( buffer) )
449
436
}
450
437
}
451
438
@@ -488,13 +475,11 @@ impl CsvSink {
488
475
let builder_clone = builder. clone ( ) ;
489
476
let options_clone = writer_options. clone ( ) ;
490
477
let get_serializer = move || {
491
- let inner_clone = builder_clone. clone ( ) ;
492
- let serializer: Box < dyn BatchSerializer > = Box :: new (
478
+ Arc :: new (
493
479
CsvSerializer :: new ( )
494
- . with_builder ( inner_clone )
480
+ . with_builder ( builder_clone . clone ( ) )
495
481
. with_header ( options_clone. writer_options . header ( ) ) ,
496
- ) ;
497
- serializer
482
+ ) as _
498
483
} ;
499
484
500
485
stateless_multipart_put (
@@ -541,15 +526,15 @@ mod tests {
541
526
use crate :: physical_plan:: collect;
542
527
use crate :: prelude:: { CsvReadOptions , SessionConfig , SessionContext } ;
543
528
use crate :: test_util:: arrow_test_data;
529
+
544
530
use arrow:: compute:: concat_batches;
545
- use bytes:: Bytes ;
546
- use chrono:: DateTime ;
547
531
use datafusion_common:: cast:: as_string_array;
548
- use datafusion_common:: internal_err;
549
532
use datafusion_common:: stats:: Precision ;
550
- use datafusion_common:: FileType ;
551
- use datafusion_common:: GetExt ;
533
+ use datafusion_common:: { internal_err, FileType , GetExt } ;
552
534
use datafusion_expr:: { col, lit} ;
535
+
536
+ use bytes:: Bytes ;
537
+ use chrono:: DateTime ;
553
538
use futures:: StreamExt ;
554
539
use object_store:: local:: LocalFileSystem ;
555
540
use object_store:: path:: Path ;
@@ -836,8 +821,8 @@ mod tests {
836
821
. collect ( )
837
822
. await ?;
838
823
let batch = concat_batches ( & batches[ 0 ] . schema ( ) , & batches) ?;
839
- let mut serializer = CsvSerializer :: new ( ) ;
840
- let bytes = serializer. serialize ( batch) . await ?;
824
+ let serializer = CsvSerializer :: new ( ) ;
825
+ let bytes = serializer. serialize ( batch, true ) . await ?;
841
826
assert_eq ! (
842
827
"c2,c3\n 2,1\n 5,-40\n 1,29\n 1,-85\n 5,-82\n 4,-111\n 3,104\n 3,13\n 1,38\n 4,-38\n " ,
843
828
String :: from_utf8( bytes. into( ) ) . unwrap( )
@@ -860,8 +845,8 @@ mod tests {
860
845
. collect ( )
861
846
. await ?;
862
847
let batch = concat_batches ( & batches[ 0 ] . schema ( ) , & batches) ?;
863
- let mut serializer = CsvSerializer :: new ( ) . with_header ( false ) ;
864
- let bytes = serializer. serialize ( batch) . await ?;
848
+ let serializer = CsvSerializer :: new ( ) . with_header ( false ) ;
849
+ let bytes = serializer. serialize ( batch, true ) . await ?;
865
850
assert_eq ! (
866
851
"2,1\n 5,-40\n 1,29\n 1,-85\n 5,-82\n 4,-111\n 3,104\n 3,13\n 1,38\n 4,-38\n " ,
867
852
String :: from_utf8( bytes. into( ) ) . unwrap( )
0 commit comments