aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Czygan <martin.czygan@gmail.com>2021-07-04 10:52:47 +0200
committerMartin Czygan <martin.czygan@gmail.com>2021-07-04 10:52:47 +0200
commitd8404da85ed7bc2852edfdf83152a8611d95b406 (patch)
tree4e408fe55cebca3495b925283528cd59d4e8a70c
parent0935913a70d4fb12851f1c085c0e9dd6bb0cf5e8 (diff)
downloadrefcat-d8404da85ed7bc2852edfdf83152a8611d95b406.tar.gz
refcat-d8404da85ed7bc2852edfdf83152a8611d95b406.zip
zipkey: add batch test
-rw-r--r--skate/zipkey/batch.go10
-rw-r--r--skate/zipkey/batch_test.go57
2 files changed, 64 insertions, 3 deletions
diff --git a/skate/zipkey/batch.go b/skate/zipkey/batch.go
index d81897d..c31909c 100644
--- a/skate/zipkey/batch.go
+++ b/skate/zipkey/batch.go
@@ -14,6 +14,7 @@ type Batcher struct {
queue chan []*Group
wg sync.WaitGroup
err error
+ closing bool
}
// NewBatcher set ups a new Batcher.
@@ -32,6 +33,7 @@ func NewBatcher(gf groupFunc) *Batcher {
}
func (b *Batcher) Close() error {
+ b.closing = true
g := make([]*Group, len(b.batch))
copy(g, b.batch)
b.queue <- g
@@ -41,8 +43,11 @@ func (b *Batcher) Close() error {
return b.err
}
-// GroupFunc implement the groupFunc type.
+// GroupFunc implement the groupFunc type. Not thread safe.
func (b *Batcher) GroupFunc(g *Group) error {
+ if b.closing {
+ panic("cannot call GroupFunc after Close")
+ }
b.batch = append(b.batch, g)
if len(b.batch) == b.Size {
g := make([]*Group, len(b.batch))
@@ -59,8 +64,7 @@ func (b *Batcher) worker(queue chan []*Group, wg *sync.WaitGroup) {
OUTER:
for batch := range queue {
for _, g := range batch {
- err := b.gf(g)
- if err != nil {
+ if err := b.gf(g); err != nil {
b.err = err
break OUTER
}
diff --git a/skate/zipkey/batch_test.go b/skate/zipkey/batch_test.go
index 7c6a48c..38a1307 100644
--- a/skate/zipkey/batch_test.go
+++ b/skate/zipkey/batch_test.go
@@ -3,7 +3,11 @@ package zipkey
import (
"bytes"
"encoding/json"
+ "fmt"
+ "io"
+ "reflect"
"strings"
+ "sync"
"testing"
)
@@ -30,3 +34,56 @@ func TestBatcher(t *testing.T) {
t.Fatalf("got %v, want %v", got, want)
}
}
+
+func TestBatcherLarge(t *testing.T) {
+ var (
+ N = 1000000
+ numWorkers = 24
+ size = 7000
+ // We share a single writer across threads, so we need to guard each
+ // write. TODO: measure performance impact.
+ mu sync.Mutex
+ buf bytes.Buffer
+ f = func(g *Group) error {
+ var v string
+ if reflect.DeepEqual(g.G0, g.G1) {
+ v = "1"
+ } else {
+ v = "0"
+ }
+ mu.Lock()
+ defer mu.Unlock()
+ if _, err := io.WriteString(&buf, v); err != nil {
+ return err
+ }
+ return nil
+ }
+ b = NewBatcher(groupFunc(f))
+ )
+ b.Size = size
+ b.NumWorkers = numWorkers
+ for i := 0; i < N; i++ {
+ var u, v string
+ if i%2 == 0 {
+ u, v = "a", "b"
+ } else {
+ u, v = "a", "a"
+ }
+ g := &Group{
+ Key: fmt.Sprintf("%d", i),
+ G0: []string{u},
+ G1: []string{v},
+ }
+ if err := b.GroupFunc(g); err != nil {
+ t.Fatalf("unexpected err from gf: %v", err)
+ }
+ }
+ if err := b.Close(); err != nil {
+ t.Fatalf("unexpected err from close: %v", err)
+ }
+ got := buf.String()
+ count0, count1 := strings.Count(got, "0"), strings.Count(got, "1")
+ if count1 != N/2 {
+ t.Fatalf("got %v, want %v (count0=%v, buf=%s)", count1, N/2, count0, buf.String())
+ }
+}