@@ -526,88 +526,100 @@ func Copy(ctx context.Context, u *dburl.URL, stdout, stderr func() io.Writer, ro
526526 return d .Copy (ctx , db , rows , table )
527527}
528528
529- // CopyWithInsert builds a copy handler based on insert.
529+ // CopyWithInsert builds a typical copy handler based on insert.
530530func CopyWithInsert (placeholder func (int ) string ) func (ctx context.Context , db * sql.DB , rows * sql.Rows , table string ) (int64 , error ) {
531531 if placeholder == nil {
532532 placeholder = func (n int ) string { return fmt .Sprintf ("$%d" , n ) }
533533 }
534534 return func (ctx context.Context , db * sql.DB , rows * sql.Rows , table string ) (int64 , error ) {
535- columns , err := rows .Columns ()
536- if err != nil {
537- return 0 , fmt .Errorf ("failed to fetch source rows columns: %w" , err )
538- }
539- clen := len (columns )
540- query := table
541- if ! strings .HasPrefix (strings .ToLower (query ), "insert into" ) {
542- leftParen := strings .IndexRune (table , '(' )
543- if leftParen == - 1 {
544- colRows , err := db .QueryContext (ctx , "SELECT * FROM " + table + " WHERE 1=0" )
545- if err != nil {
546- return 0 , fmt .Errorf ("failed to execute query to determine target table columns: %w" , err )
547- }
548- columns , err := colRows .Columns ()
549- _ = colRows .Close ()
550- if err != nil {
551- return 0 , fmt .Errorf ("failed to fetch target table columns: %w" , err )
552- }
553- table += "(" + strings .Join (columns , ", " ) + ")"
535+ return FlexibleCopyWithInsert (ctx , db , rows , table , placeholder , true )
536+ }
537+ }
538+
539+ func FlexibleCopyWithInsert (ctx context.Context , db * sql.DB , rows * sql.Rows , table string , placeholder func (int ) string , withTransaction bool ) (int64 , error ) {
540+ columns , err := rows .Columns ()
541+ if err != nil {
542+ return 0 , fmt .Errorf ("failed to fetch source rows columns: %w" , err )
543+ }
544+ clen := len (columns )
545+ query := table
546+ if ! strings .HasPrefix (strings .ToLower (query ), "insert into" ) {
547+ leftParen := strings .IndexRune (table , '(' )
548+ if leftParen == - 1 {
549+ colRows , err := db .QueryContext (ctx , "SELECT * FROM " + table + " WHERE 1=0" )
550+ if err != nil {
551+ return 0 , fmt .Errorf ("failed to execute query to determine target table columns: %w" , err )
554552 }
555- // TODO if the db supports multiple rows per insert, create batches of 100 rows
556- placeholders := make ([] string , clen )
557- for i := 0 ; i < clen ; i ++ {
558- placeholders [ i ] = placeholder ( i + 1 )
553+ columns , err := colRows . Columns ()
554+ _ = colRows . Close ( )
555+ if err != nil {
556+ return 0 , fmt . Errorf ( "failed to fetch target table columns: %w" , err )
559557 }
560- query = "INSERT INTO " + table + " VALUES (" + strings .Join (placeholders , ", " ) + ")"
558+ table += "(" + strings .Join (columns , ", " ) + ")"
559+ }
560+ // TODO if the db supports multiple rows per insert, create batches of 100 rows
561+ placeholders := make ([]string , clen )
562+ for i := 0 ; i < clen ; i ++ {
563+ placeholders [i ] = placeholder (i + 1 )
561564 }
562- tx , err := db .BeginTx (ctx , nil )
565+ query = "INSERT INTO " + table + " VALUES (" + strings .Join (placeholders , ", " ) + ")"
566+ }
567+ var stmt * sql.Stmt
568+ var tx * sql.Tx
569+ if withTransaction {
570+ tx , err = db .BeginTx (ctx , nil )
563571 if err != nil {
564572 return 0 , fmt .Errorf ("failed to begin transaction: %w" , err )
565573 }
566- stmt , err := tx .PrepareContext (ctx , query )
574+ stmt , err = tx .PrepareContext (ctx , query )
575+ } else {
576+ stmt , err = db .PrepareContext (ctx , query )
577+ }
578+ if err != nil {
579+ return 0 , fmt .Errorf ("failed to prepare insert query: %w" , err )
580+ }
581+ defer stmt .Close ()
582+ columnTypes , err := rows .ColumnTypes ()
583+ if err != nil {
584+ return 0 , fmt .Errorf ("failed to fetch source column types: %w" , err )
585+ }
586+ values := make ([]interface {}, clen )
587+ valueRefs := make ([]reflect.Value , clen )
588+ actuals := make ([]interface {}, clen )
589+ for i := 0 ; i < len (columnTypes ); i ++ {
590+ valueRefs [i ] = reflect .New (columnTypes [i ].ScanType ())
591+ values [i ] = valueRefs [i ].Interface ()
592+ }
593+ var n int64
594+ for rows .Next () {
595+ err = rows .Scan (values ... )
567596 if err != nil {
568- return 0 , fmt .Errorf ("failed to prepare insert query : %w" , err )
597+ return n , fmt .Errorf ("failed to scan row : %w" , err )
569598 }
570- defer stmt . Close ()
571- columnTypes , err := rows . ColumnTypes ()
572- if err != nil {
573- return 0 , fmt . Errorf ( "failed to fetch source column types: %w" , err )
599+ //We can't use values... in Exec() below, because some drivers
600+ //don't accept pointer to an argument instead of the arg itself.
601+ for i := range values {
602+ actuals [ i ] = valueRefs [ i ]. Elem (). Interface ( )
574603 }
575- values := make ([]interface {}, clen )
576- valueRefs := make ([]reflect.Value , clen )
577- actuals := make ([]interface {}, clen )
578- for i := 0 ; i < len (columnTypes ); i ++ {
579- valueRefs [i ] = reflect .New (columnTypes [i ].ScanType ())
580- values [i ] = valueRefs [i ].Interface ()
604+ res , err := stmt .ExecContext (ctx , actuals ... )
605+ if err != nil {
606+ return n , fmt .Errorf ("failed to exec insert: %w" , err )
581607 }
582- var n int64
583- for rows .Next () {
584- err = rows .Scan (values ... )
585- if err != nil {
586- return n , fmt .Errorf ("failed to scan row: %w" , err )
587- }
588- //We can't use values... in Exec() below, because some drivers
589- //don't accept pointer to an argument instead of the arg itself.
590- for i := range values {
591- actuals [i ] = valueRefs [i ].Elem ().Interface ()
592- }
593- res , err := stmt .ExecContext (ctx , actuals ... )
594- if err != nil {
595- return n , fmt .Errorf ("failed to exec insert: %w" , err )
596- }
597- rn , err := res .RowsAffected ()
598- if err != nil {
599- return n , fmt .Errorf ("failed to check rows affected: %w" , err )
600- }
601- n += rn
608+ rn , err := res .RowsAffected ()
609+ if err != nil {
610+ return n , fmt .Errorf ("failed to check rows affected: %w" , err )
602611 }
603- // TODO if using batches, flush the last batch,
604- // TODO prepare another statement and count remaining rows
612+ n += rn
613+ }
614+ // TODO if using batches, flush the last batch,
615+ // TODO prepare another statement and count remaining rows
616+ if tx != nil {
605617 err = tx .Commit ()
606618 if err != nil {
607619 return n , fmt .Errorf ("failed to commit transaction: %w" , err )
608620 }
609- return n , rows .Err ()
610621 }
622+ return n , rows .Err ()
611623}
612624
613625func init () {
0 commit comments