15 changes: 6 additions & 9 deletions pkg/sources/s3/checkpointer.go
@@ -70,14 +70,15 @@ const defaultMaxObjectsPerPage = 1000

// NewCheckpointer creates a new checkpointer for S3 scanning operations.
// The progress provides the underlying mechanism for persisting scan state.
- func NewCheckpointer(ctx context.Context, progress *sources.Progress) *Checkpointer {
+ func NewCheckpointer(ctx context.Context, progress *sources.Progress, isUnitScan bool) *Checkpointer {
ctx.Logger().Info("Creating checkpointer")

return &Checkpointer{
// We are resuming if we have completed objects from a previous scan.
completedObjects: make([]bool, defaultMaxObjectsPerPage),
completionOrder: make([]int, 0, defaultMaxObjectsPerPage),
progress: progress,
+ isUnitScan:       isUnitScan,
}
}

@@ -192,6 +193,10 @@ func (p *Checkpointer) UpdateObjectCompletion(
if checkpointIdx < 0 {
return nil // No completed objects yet
}
+ if checkpointIdx >= len(pageContents) {
+ 	// this should never happen
+ 	return fmt.Errorf("checkpoint index %d exceeds page contents size %d", checkpointIdx, len(pageContents))
+ }
obj := pageContents[checkpointIdx]

return p.updateCheckpoint(bucket, role, *obj.Key)
@@ -229,11 +234,3 @@ func (p *Checkpointer) updateCheckpoint(bucket string, role string, lastKey string) error {
)
return nil
}

- // SetIsUnitScan sets whether the checkpointer is operating in unit scan mode.
- func (p *Checkpointer) SetIsUnitScan(isUnitScan bool) {
- 	p.mu.Lock()
- 	defer p.mu.Unlock()
-
- 	p.isUnitScan = isUnitScan
- }
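With this change, unit-scan mode is fixed at construction time rather than toggled afterwards through the now-removed SetIsUnitScan setter. A minimal sketch of how the revised constructor might be called, assuming the repository's usual import paths (illustrative, not code from this PR):

package main

import (
	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources/s3"
)

func main() {
	// Illustrative sketch; import paths and usage are assumptions,
	// not taken verbatim from the PR.
	ctx := context.Background()
	progress := &sources.Progress{}

	// Bucket-enumeration scans construct with isUnitScan=false.
	bucketCP := s3.NewCheckpointer(ctx, progress, false)

	// Unit scans construct with isUnitScan=true; no setter call needed.
	unitCP := s3.NewCheckpointer(ctx, progress, true)

	_, _ = bucketCP, unitCP
}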
15 changes: 7 additions & 8 deletions pkg/sources/s3/checkpointer_test.go
@@ -19,7 +19,7 @@ func TestCheckpointerResumption(t *testing.T) {

// First scan - process 6 objects then interrupt.
initialProgress := &sources.Progress{}
- tracker := NewCheckpointer(ctx, initialProgress)
+ tracker := NewCheckpointer(ctx, initialProgress, false)

firstPage := &s3.ListObjectsV2Output{
Contents: make([]s3types.Object, 12), // Total of 12 objects
@@ -42,7 +42,7 @@ func TestCheckpointerResumption(t *testing.T) {
assert.Equal(t, "key-5", resumeInfo.StartAfter)

// Resume scan with existing progress.
- resumeTracker := NewCheckpointer(ctx, initialProgress)
+ resumeTracker := NewCheckpointer(ctx, initialProgress, false)

resumePage := &s3.ListObjectsV2Output{
Contents: firstPage.Contents[6:], // Remaining 6 objects
@@ -66,7 +66,7 @@ func TestCheckpointerResumptionWithRole(t *testing.T) {

// First scan - process 6 objects then interrupt.
initialProgress := &sources.Progress{}
- tracker := NewCheckpointer(ctx, initialProgress)
+ tracker := NewCheckpointer(ctx, initialProgress, false)
role := "test-role"

firstPage := &s3.ListObjectsV2Output{
@@ -91,7 +91,7 @@ func TestCheckpointerResumptionWithRole(t *testing.T) {
assert.Equal(t, role, resumeInfo.Role)

// Resume scan with existing progress.
- resumeTracker := NewCheckpointer(ctx, initialProgress)
+ resumeTracker := NewCheckpointer(ctx, initialProgress, false)

resumePage := &s3.ListObjectsV2Output{
Contents: firstPage.Contents[6:], // Remaining 6 objects
@@ -124,7 +124,7 @@ func TestCheckpointerReset(t *testing.T) {

ctx := context.Background()
progress := new(sources.Progress)
- tracker := NewCheckpointer(ctx, progress)
+ tracker := NewCheckpointer(ctx, progress, false)

tracker.completedObjects[1] = true
tracker.completedObjects[2] = true
@@ -441,8 +441,7 @@ func TestCheckpointerUpdateWithRole(t *testing.T) {
func TestCheckpointerUpdateUnitScan(t *testing.T) {
ctx := context.Background()
progress := new(sources.Progress)
- tracker := NewCheckpointer(ctx, progress)
- tracker.SetIsUnitScan(true)
+ tracker := NewCheckpointer(ctx, progress, true)

page := &s3.ListObjectsV2Output{
Contents: make([]s3types.Object, 3),
@@ -528,7 +527,7 @@ func TestComplete(t *testing.T) {
EncodedResumeInfo: tt.initialState.resumeInfo,
Message: tt.initialState.message,
}
- tracker := NewCheckpointer(ctx, progress)
+ tracker := NewCheckpointer(ctx, progress, false)

err := tracker.Complete(ctx, tt.completeMessage)
assert.NoError(t, err)
27 changes: 14 additions & 13 deletions pkg/sources/s3/s3.go
@@ -48,7 +48,6 @@ type Source struct {
concurrency int
conn *sourcespb.S3

- checkpointer *Checkpointer
sources.Progress
metricsCollector metricsCollector

@@ -95,7 +94,6 @@ func (s *Source) Init(
}
s.conn = &conn

- s.checkpointer = NewCheckpointer(ctx, &s.Progress)
s.metricsCollector = metricsInstance

s.setMaxObjectSize(conn.GetMaxObjectSize())
@@ -299,7 +297,8 @@ func (s *Source) scanBuckets(
}
var totalObjectCount uint64

- pos := determineResumePosition(ctx, s.checkpointer, bucketsToScan)
+ checkpointer := NewCheckpointer(ctx, &s.Progress, false)
+ pos := determineResumePosition(ctx, checkpointer, bucketsToScan)
switch {
case pos.isNewScan:
ctx.Logger().Info("Starting new scan from beginning")
@@ -340,7 +339,7 @@
)
}

- objectCount := s.scanBucket(ctx, client, role, bucket, sources.ChanReporter{Ch: chunksChan}, startAfter)
+ objectCount := s.scanBucket(ctx, client, role, bucket, sources.ChanReporter{Ch: chunksChan}, startAfter, checkpointer)
totalObjectCount += objectCount
}

@@ -359,6 +358,7 @@ func (s *Source) scanBucket(
bucket string,
reporter sources.ChunkReporter,
startAfter *string,
+ checkpointer *Checkpointer,
) uint64 {
s.metricsCollector.RecordBucketForRole(role)

@@ -412,7 +412,7 @@
errorCount: &errorCount,
objectCount: &objectCount,
}
- s.pageChunker(ctx, pageMetadata, processingState, reporter)
+ s.pageChunker(ctx, pageMetadata, processingState, reporter, checkpointer)

pageNumber++
}
@@ -458,8 +458,9 @@ func (s *Source) pageChunker(
metadata pageMetadata,
state processingState,
reporter sources.ChunkReporter,
+ checkpointer *Checkpointer,
) {
- s.checkpointer.Reset() // Reset the checkpointer for each PAGE
+ checkpointer.Reset() // Reset the checkpointer for each PAGE
ctx = context.WithValues(ctx, "bucket", metadata.bucket, "page_number", metadata.pageNumber)
for objIdx, obj := range metadata.page.Contents {
ctx = context.WithValues(ctx, "key", *obj.Key, "size", *obj.Size)
@@ -471,7 +472,7 @@
if obj.StorageClass == s3types.ObjectStorageClassGlacier || obj.StorageClass == s3types.ObjectStorageClassGlacierIr {
ctx.Logger().V(5).Info("Skipping object in storage class", "storage_class", obj.StorageClass)
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "storage_class", float64(*obj.Size))
- if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+ if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for glacier object")
}
continue
@@ -481,7 +482,7 @@
if *obj.Size > s.maxObjectSize {
ctx.Logger().V(5).Info("Skipping large file", "max_object_size", s.maxObjectSize)
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "size_limit", float64(*obj.Size))
- if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+ if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for large file")
}
continue
@@ -491,7 +492,7 @@
if *obj.Size == 0 {
ctx.Logger().V(5).Info("Skipping empty file")
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "empty_file", 0)
- if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+ if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for empty file")
}
continue
@@ -501,7 +502,7 @@
if common.SkipFile(*obj.Key) {
ctx.Logger().V(5).Info("Skipping file with incompatible extension")
s.metricsCollector.RecordObjectSkipped(metadata.bucket, "incompatible_extension", float64(*obj.Size))
- if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+ if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for incompatible file")
}
continue
@@ -613,7 +614,7 @@ func (s *Source) pageChunker(
state.errorCount.Store(prefix, 0)
}
// Update progress after successful processing.
- if err := s.checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
+ if err := checkpointer.UpdateObjectCompletion(ctx, objIdx, metadata.bucket, metadata.role, metadata.page.Contents); err != nil {
ctx.Logger().Error(err, "could not update progress for scanned object")
}
s.metricsCollector.RecordObjectScanned(metadata.bucket, float64(*obj.Size))
@@ -744,7 +745,7 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporter sources.ChunkReporter) error {
return fmt.Errorf("could not create s3 client for bucket %s and role %s: %w", s3unit.Bucket, s3unit.Role, err)
}

- s.checkpointer.SetIsUnitScan(true)
+ checkpointer := NewCheckpointer(ctx, &s.Progress, true)

var startAfterPtr *string
startAfter := s.Progress.GetEncodedResumeInfoFor(unitID)
@@ -757,7 +758,7 @@
startAfterPtr = &startAfter
}
defer s.Progress.ClearEncodedResumeInfoFor(unitID)
- s.scanBucket(ctx, defaultClient, s3unit.Role, s3unit.Bucket, reporter, startAfterPtr)
+ s.scanBucket(ctx, defaultClient, s3unit.Role, s3unit.Bucket, reporter, startAfterPtr, checkpointer)
return nil
}
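A plausible motivation for threading the checkpointer through as a parameter: each call to scanBuckets or ChunkUnit now builds its own Checkpointer instead of mutating a shared field on Source, so concurrent unit scans no longer contend on a single checkpointer. A hedged sketch of that usage pattern, under the same import-path assumptions as above (illustrative only):

package main

import (
	"sync"

	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
	"github.com/trufflesecurity/trufflehog/v3/pkg/sources/s3"
)

func main() {
	ctx := context.Background()
	progress := &sources.Progress{}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Illustrative: each simulated unit scan gets its own
			// checkpointer, mirroring ChunkUnit above; only the shared
			// Progress is passed in.
			cp := s3.NewCheckpointer(ctx, progress, true)
			_ = cp // a real scan would hand cp to scanBucket
		}()
	}
	wg.Wait()
}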
