From d8404da85ed7bc2852edfdf83152a8611d95b406 Mon Sep 17 00:00:00 2001 From: Martin Czygan Date: Sun, 4 Jul 2021 10:52:47 +0200 Subject: zipkey: add batch test --- skate/zipkey/batch.go | 10 +++++--- skate/zipkey/batch_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) (limited to 'skate') diff --git a/skate/zipkey/batch.go b/skate/zipkey/batch.go index d81897d..c31909c 100644 --- a/skate/zipkey/batch.go +++ b/skate/zipkey/batch.go @@ -14,6 +14,7 @@ type Batcher struct { queue chan []*Group wg sync.WaitGroup err error + closing bool } // NewBatcher set ups a new Batcher. @@ -32,6 +33,7 @@ func NewBatcher(gf groupFunc) *Batcher { } func (b *Batcher) Close() error { + b.closing = true g := make([]*Group, len(b.batch)) copy(g, b.batch) b.queue <- g @@ -41,8 +43,11 @@ func (b *Batcher) Close() error { return b.err } -// GroupFunc implement the groupFunc type. +// GroupFunc implement the groupFunc type. Not thread safe. func (b *Batcher) GroupFunc(g *Group) error { + if b.closing { + panic("cannot call GroupFunc after Close") + } b.batch = append(b.batch, g) if len(b.batch) == b.Size { g := make([]*Group, len(b.batch)) @@ -59,8 +64,7 @@ func (b *Batcher) worker(queue chan []*Group, wg *sync.WaitGroup) { OUTER: for batch := range queue { for _, g := range batch { - err := b.gf(g) - if err != nil { + if err := b.gf(g); err != nil { b.err = err break OUTER } diff --git a/skate/zipkey/batch_test.go b/skate/zipkey/batch_test.go index 7c6a48c..38a1307 100644 --- a/skate/zipkey/batch_test.go +++ b/skate/zipkey/batch_test.go @@ -3,7 +3,11 @@ package zipkey import ( "bytes" "encoding/json" + "fmt" + "io" + "reflect" "strings" + "sync" "testing" ) @@ -30,3 +34,56 @@ func TestBatcher(t *testing.T) { t.Fatalf("got %v, want %v", got, want) } } + +func TestBatcherLarge(t *testing.T) { + var ( + N = 1000000 + numWorkers = 24 + size = 7000 + // We share a single writer across threads, so we need to guard each + // write. TODO: measure performance impact. + mu sync.Mutex + buf bytes.Buffer + f = func(g *Group) error { + var v string + if reflect.DeepEqual(g.G0, g.G1) { + v = "1" + } else { + v = "0" + } + mu.Lock() + defer mu.Unlock() + if _, err := io.WriteString(&buf, v); err != nil { + return err + } + return nil + } + b = NewBatcher(groupFunc(f)) + ) + b.Size = size + b.NumWorkers = numWorkers + for i := 0; i < N; i++ { + var u, v string + if i%2 == 0 { + u, v = "a", "b" + } else { + u, v = "a", "a" + } + g := &Group{ + Key: fmt.Sprintf("%d", i), + G0: []string{u}, + G1: []string{v}, + } + if err := b.GroupFunc(g); err != nil { + t.Fatalf("unexpected err from gf: %v", err) + } + } + if err := b.Close(); err != nil { + t.Fatalf("unexpected err from close: %v", err) + } + got := buf.String() + count0, count1 := strings.Count(got, "0"), strings.Count(got, "1") + if count1 != N/2 { + t.Fatalf("got %v, want %v (count0=%v, buf=%s)", count1, N/2, count0, buf.String()) + } +} -- cgit v1.2.3