aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--skate/zipkey/batch.go69
-rw-r--r--skate/zipkey/batch_test.go32
2 files changed, 101 insertions, 0 deletions
diff --git a/skate/zipkey/batch.go b/skate/zipkey/batch.go
new file mode 100644
index 0000000..d81897d
--- /dev/null
+++ b/skate/zipkey/batch.go
@@ -0,0 +1,69 @@
+package zipkey
+
+import (
+ "runtime"
+ "sync"
+)
+
+// Batcher runs reducers in parallel on batches of groups.
+type Batcher struct {
+ Size int
+ NumWorkers int
+ gf groupFunc
+ batch []*Group
+ queue chan []*Group
+ wg sync.WaitGroup
+ err error
+}
+
+// NewBatcher set ups a new Batcher.
+func NewBatcher(gf groupFunc) *Batcher {
+ batcher := Batcher{
+ gf: gf,
+ Size: 1000,
+ NumWorkers: runtime.NumCPU(),
+ queue: make(chan []*Group),
+ }
+ for i := 0; i < batcher.NumWorkers; i++ {
+ batcher.wg.Add(1)
+ go batcher.worker(batcher.queue, &batcher.wg)
+ }
+ return &batcher
+}
+
+func (b *Batcher) Close() error {
+ g := make([]*Group, len(b.batch))
+ copy(g, b.batch)
+ b.queue <- g
+ b.batch = nil
+ close(b.queue)
+ b.wg.Wait()
+ return b.err
+}
+
+// GroupFunc implement the groupFunc type.
+func (b *Batcher) GroupFunc(g *Group) error {
+ b.batch = append(b.batch, g)
+ if len(b.batch) == b.Size {
+ g := make([]*Group, len(b.batch))
+ copy(g, b.batch)
+ b.queue <- g
+ b.batch = nil
+ }
+ return nil
+}
+
+// worker will wind down after a first error encountered.
+func (b *Batcher) worker(queue chan []*Group, wg *sync.WaitGroup) {
+ defer wg.Done()
+OUTER:
+ for batch := range queue {
+ for _, g := range batch {
+ err := b.gf(g)
+ if err != nil {
+ b.err = err
+ break OUTER
+ }
+ }
+ }
+}
diff --git a/skate/zipkey/batch_test.go b/skate/zipkey/batch_test.go
new file mode 100644
index 0000000..7c6a48c
--- /dev/null
+++ b/skate/zipkey/batch_test.go
@@ -0,0 +1,32 @@
+package zipkey
+
+import (
+ "bytes"
+ "encoding/json"
+ "strings"
+ "testing"
+)
+
+func TestBatcher(t *testing.T) {
+ var (
+ buf bytes.Buffer
+ enc = json.NewEncoder(&buf)
+ f = func(g *Group) error {
+ return enc.Encode(g)
+ }
+ b = NewBatcher(groupFunc(f))
+ )
+ b.GroupFunc(&Group{
+ Key: "K1",
+ G0: []string{"A"},
+ G1: []string{"B"},
+ })
+ b.Close()
+ var (
+ got = strings.TrimSpace(buf.String())
+ want = `{"Key":"K1","G0":["A"],"G1":["B"]}`
+ )
+ if got != want {
+ t.Fatalf("got %v, want %v", got, want)
+ }
+}