From 6bb05e88506b5b09dd5d73d50235ebdb9cc34934 Mon Sep 17 00:00:00 2001 From: Martin Czygan Date: Sun, 4 Jul 2021 00:12:04 +0200 Subject: wip: batch reducers --- skate/zipkey/batch.go | 69 ++++++++++++++++++++++++++++++++++++++++++++++ skate/zipkey/batch_test.go | 32 +++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 skate/zipkey/batch.go create mode 100644 skate/zipkey/batch_test.go diff --git a/skate/zipkey/batch.go b/skate/zipkey/batch.go new file mode 100644 index 0000000..d81897d --- /dev/null +++ b/skate/zipkey/batch.go @@ -0,0 +1,69 @@ +package zipkey + +import ( + "runtime" + "sync" +) + +// Batcher runs reducers in parallel on batches of groups. +type Batcher struct { + Size int + NumWorkers int + gf groupFunc + batch []*Group + queue chan []*Group + wg sync.WaitGroup + err error +} + +// NewBatcher set ups a new Batcher. +func NewBatcher(gf groupFunc) *Batcher { + batcher := Batcher{ + gf: gf, + Size: 1000, + NumWorkers: runtime.NumCPU(), + queue: make(chan []*Group), + } + for i := 0; i < batcher.NumWorkers; i++ { + batcher.wg.Add(1) + go batcher.worker(batcher.queue, &batcher.wg) + } + return &batcher +} + +func (b *Batcher) Close() error { + g := make([]*Group, len(b.batch)) + copy(g, b.batch) + b.queue <- g + b.batch = nil + close(b.queue) + b.wg.Wait() + return b.err +} + +// GroupFunc implement the groupFunc type. +func (b *Batcher) GroupFunc(g *Group) error { + b.batch = append(b.batch, g) + if len(b.batch) == b.Size { + g := make([]*Group, len(b.batch)) + copy(g, b.batch) + b.queue <- g + b.batch = nil + } + return nil +} + +// worker will wind down after a first error encountered. +func (b *Batcher) worker(queue chan []*Group, wg *sync.WaitGroup) { + defer wg.Done() +OUTER: + for batch := range queue { + for _, g := range batch { + err := b.gf(g) + if err != nil { + b.err = err + break OUTER + } + } + } +} diff --git a/skate/zipkey/batch_test.go b/skate/zipkey/batch_test.go new file mode 100644 index 0000000..7c6a48c --- /dev/null +++ b/skate/zipkey/batch_test.go @@ -0,0 +1,32 @@ +package zipkey + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +func TestBatcher(t *testing.T) { + var ( + buf bytes.Buffer + enc = json.NewEncoder(&buf) + f = func(g *Group) error { + return enc.Encode(g) + } + b = NewBatcher(groupFunc(f)) + ) + b.GroupFunc(&Group{ + Key: "K1", + G0: []string{"A"}, + G1: []string{"B"}, + }) + b.Close() + var ( + got = strings.TrimSpace(buf.String()) + want = `{"Key":"K1","G0":["A"],"G1":["B"]}` + ) + if got != want { + t.Fatalf("got %v, want %v", got, want) + } +} -- cgit v1.2.3