diff --git a/pkg/sources/s3/checkpointer.go b/pkg/sources/s3/checkpointer.go
index 4e121f72d680..a6df6586c97f 100644
--- a/pkg/sources/s3/checkpointer.go
+++ b/pkg/sources/s3/checkpointer.go
@@ -70,7 +70,7 @@ const defaultMaxObjectsPerPage = 1000
 
 // NewCheckpointer creates a new checkpointer for S3 scanning operations.
 // The progress provides the underlying mechanism for persisting scan state.
-func NewCheckpointer(ctx context.Context, progress *sources.Progress) *Checkpointer {
+func NewCheckpointer(ctx context.Context, progress *sources.Progress, isUnitScan bool) *Checkpointer {
 	ctx.Logger().Info("Creating checkpointer")
 
 	return &Checkpointer{
@@ -78,6 +78,7 @@ func NewCheckpointer(ctx context.Context, progress *sources.Progress) *Checkpoin
 		completedObjects: make([]bool, defaultMaxObjectsPerPage),
 		completionOrder:  make([]int, 0, defaultMaxObjectsPerPage),
 		progress:         progress,
+		isUnitScan:       isUnitScan,
 	}
 }
 
@@ -192,6 +193,10 @@ func (p *Checkpointer) UpdateObjectCompletion(
 	if checkpointIdx < 0 {
 		return nil // No completed objects yet
 	}
+	if checkpointIdx >= len(pageContents) {
+		// this should never happen
+		return fmt.Errorf("checkpoint index %d exceeds page contents size %d", checkpointIdx, len(pageContents))
+	}
 
 	obj := pageContents[checkpointIdx]
 	return p.updateCheckpoint(bucket, role, *obj.Key)
@@ -229,11 +234,3 @@ func (p *Checkpointer) updateCheckpoint(bucket string, role string, lastKey stri
 	)
 	return nil
 }
-
-// SetIsUnitScan sets whether the checkpointer is operating in unit scan mode.
-func (p *Checkpointer) SetIsUnitScan(isUnitScan bool) {
-	p.mu.Lock()
-	defer p.mu.Unlock()
-
-	p.isUnitScan = isUnitScan
-}
diff --git a/pkg/sources/s3/checkpointer_test.go b/pkg/sources/s3/checkpointer_test.go
index 2f58ae33f138..0d9082901235 100644
--- a/pkg/sources/s3/checkpointer_test.go
+++ b/pkg/sources/s3/checkpointer_test.go
@@ -19,7 +19,7 @@ func TestCheckpointerResumption(t *testing.T) {
 
 	// First scan - process 6 objects then interrupt.
 	initialProgress := &sources.Progress{}
-	tracker := NewCheckpointer(ctx, initialProgress)
+	tracker := NewCheckpointer(ctx, initialProgress, false)
 
 	firstPage := &s3.ListObjectsV2Output{
 		Contents: make([]s3types.Object, 12), // Total of 12 objects
@@ -42,7 +42,7 @@ func TestCheckpointerResumption(t *testing.T) {
 	assert.Equal(t, "key-5", resumeInfo.StartAfter)
 
 	// Resume scan with existing progress.
-	resumeTracker := NewCheckpointer(ctx, initialProgress)
+	resumeTracker := NewCheckpointer(ctx, initialProgress, false)
 
 	resumePage := &s3.ListObjectsV2Output{
 		Contents: firstPage.Contents[6:], // Remaining 6 objects
@@ -66,7 +66,7 @@ func TestCheckpointerResumptionWithRole(t *testing.T) {
 
 	// First scan - process 6 objects then interrupt.
 	initialProgress := &sources.Progress{}
-	tracker := NewCheckpointer(ctx, initialProgress)
+	tracker := NewCheckpointer(ctx, initialProgress, false)
 
 	role := "test-role"
 	firstPage := &s3.ListObjectsV2Output{
@@ -91,7 +91,7 @@ func TestCheckpointerResumptionWithRole(t *testing.T) {
 	assert.Equal(t, role, resumeInfo.Role)
 
 	// Resume scan with existing progress.
-	resumeTracker := NewCheckpointer(ctx, initialProgress)
+	resumeTracker := NewCheckpointer(ctx, initialProgress, false)
 
 	resumePage := &s3.ListObjectsV2Output{
 		Contents: firstPage.Contents[6:], // Remaining 6 objects
@@ -124,7 +124,7 @@ func TestCheckpointerReset(t *testing.T) {
 	ctx := context.Background()
 	progress := new(sources.Progress)
-	tracker := NewCheckpointer(ctx, progress)
+	tracker := NewCheckpointer(ctx, progress, false)
 
 	tracker.completedObjects[1] = true
 	tracker.completedObjects[2] = true
@@ -441,8 +441,7 @@ func TestCheckpointerUpdateWithRole(t *testing.T) {
 func TestCheckpointerUpdateUnitScan(t *testing.T) {
 	ctx := context.Background()
 	progress := new(sources.Progress)
-	tracker := NewCheckpointer(ctx, progress)
-	tracker.SetIsUnitScan(true)
+	tracker := NewCheckpointer(ctx, progress, true)
 
 	page := &s3.ListObjectsV2Output{
 		Contents: make([]s3types.Object, 3),
@@ -528,7 +527,7 @@ func TestComplete(t *testing.T) {
 				EncodedResumeInfo: tt.initialState.resumeInfo,
 				Message:           tt.initialState.message,
 			}
-			tracker := NewCheckpointer(ctx, progress)
+			tracker := NewCheckpointer(ctx, progress, false)
 
 			err := tracker.Complete(ctx, tt.completeMessage)
 			assert.NoError(t, err)
diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go
index f2709b14391e..2c19d939bb35 100644
--- a/pkg/sources/s3/s3.go
+++ b/pkg/sources/s3/s3.go
@@ -48,7 +48,6 @@ type Source struct {
 	concurrency int
 	conn        *sourcespb.S3
 
-	checkpointer *Checkpointer
 	sources.Progress
 	metricsCollector metricsCollector
 
@@ -95,7 +94,6 @@ func (s *Source) Init(
 	}
 	s.conn = &conn
 
-	s.checkpointer = NewCheckpointer(ctx, &s.Progress)
 	s.metricsCollector = metricsInstance
 
 	s.setMaxObjectSize(conn.GetMaxObjectSize())
@@ -299,7 +297,8 @@ func (s *Source) scanBuckets(
 	}
 
 	var totalObjectCount uint64
-	pos := determineResumePosition(ctx, s.checkpointer, bucketsToScan)
+	checkpointer := NewCheckpointer(ctx, &s.Progress, false)
+	pos := determineResumePosition(ctx, checkpointer, bucketsToScan)
 	switch {
 	case pos.isNewScan:
 		ctx.Logger().Info("Starting new scan from beginning")
@@ -340,7 +339,7 @@ func (s *Source) scanBuckets(
 			)
 		}
 
-		objectCount := s.scanBucket(ctx, client, role, bucket, sources.ChanReporter{Ch: chunksChan}, startAfter)
+		objectCount := s.scanBucket(ctx, client, role, bucket, sources.ChanReporter{Ch: chunksChan}, startAfter, checkpointer)
 		totalObjectCount += objectCount
 	}
 
@@ -359,6 +358,7 @@ func (s *Source) scanBucket(
 	bucket string,
 	reporter sources.ChunkReporter,
 	startAfter *string,
+	checkpointer *Checkpointer,
 ) uint64 {
 	s.metricsCollector.RecordBucketForRole(role)
 
@@ -412,7 +412,7 @@ func (s *Source) scanBucket(
 			errorCount:  &errorCount,
 			objectCount: &objectCount,
 		}
-		s.pageChunker(ctx, pageMetadata, processingState, reporter)
+		s.pageChunker(ctx, pageMetadata, processingState, reporter, checkpointer)
 
 		pageNumber++
 	}
@@ -458,8 +458,9 @@ func (s *Source) pageChunker(
 	metadata pageMetadata,
 	state processingState,
 	reporter sources.ChunkReporter,
+	checkpointer *Checkpointer,
 ) {
-	s.checkpointer.Reset() // Reset the checkpointer for each PAGE
+	checkpointer.Reset() // Reset the checkpointer for each PAGE
 	ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)
 	for objIdx, obj := range metadata.page.Contents {
 		ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)
@@ -471,7 +472,7 @@ func (s *Source) pageChunker(
 		if obj.StorageClass == s3types.ObjectStorageClassGlacier || obj.StorageClass == s3types.ObjectStorageClassGlacierIr {
 			ctx.Logger().V(5).Info("Skipping object in storage class",
 				"storage_class", obj.StorageClass)
 			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "storage_class", float64(*obj.Size))
-			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+			if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for glacier object")
 			}
 			continue
@@ -481,7 +482,7 @@ func (s *Source) pageChunker(
 		if *obj.Size > s.maxObjectSize {
 			ctx.Logger().V(5).Info("Skipping large file", "max_object_size", s.maxObjectSize)
 			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "size_limit", float64(*obj.Size))
-			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+			if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for large file")
 			}
 			continue
@@ -491,7 +492,7 @@ func (s *Source) pageChunker(
 		if *obj.Size == 0 {
 			ctx.Logger().V(5).Info("Skipping empty file")
 			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "empty_file", 0)
-			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+			if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for empty file")
 			}
 			continue
@@ -501,7 +502,7 @@ func (s *Source) pageChunker(
 		if common.SkipFile(*obj.Key) {
 			ctx.Logger().V(5).Info("Skipping file with incompatible extension")
 			s.metricsCollector.RecordObjectSkipped(metadata.bucket, "incompatible_extension", float64(*obj.Size))
-			if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+			if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
 				ctx.Logger().Error(err, "could not update progress for incompatible file")
 			}
 			continue
@@ -613,7 +614,7 @@ func (s *Source) pageChunker(
 			state.errorCount.Store(prefix, 0)
 		}
 		// Update progress after successful processing.
-		if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+		if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
 			ctx.Logger().Error(err, "could not update progress for scanned object")
 		}
 		s.metricsCollector.RecordObjectScanned(metadata.bucket, float64(*obj.Size))
@@ -744,7 +745,7 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
 		return fmt.Errorf("could not create s3 client for bucket %s and role %s: %w", s3unit.Bucket, s3unit.Role, err)
 	}
 
-	s.checkpointer.SetIsUnitScan(true)
+	checkpointer := NewCheckpointer(ctx, &s.Progress, true)
 
 	var startAfterPtr *string
 	startAfter := s.Progress.GetEncodedResumeInfoFor(unitID)
@@ -757,7 +758,7 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
 		startAfterPtr = &startAfter
 	}
 	defer s.Progress.ClearEncodedResumeInfoFor(unitID)
-	s.scanBucket(ctx, defaultClient, s3unit.Role, s3unit.Bucket, reporter, startAfterPtr)
+	s.scanBucket(ctx, defaultClient, s3unit.Role, s3unit.Bucket, reporter, startAfterPtr, checkpointer)
 
 	return nil
 }